From f2f64cfe8d8d732a3f464de31ecff7427901b6cb Mon Sep 17 00:00:00 2001 From: Vladislav Vinogradov Date: Fri, 5 Feb 2021 16:53:00 +0300 Subject: [PATCH 001/915] [mlir] Model MemRef memory space as Attribute Based on the following discussion: https://llvm.discourse.group/t/rfc-memref-memory-shape-as-attribute/2229 The goal of the change is to make memory space property to have more expressive representation, rather then "magic" integer values. It will allow to have more clean ASM form: ``` gpu.func @test(%arg0: memref<100xf32, "workgroup">) // instead of gpu.func @test(%arg0: memref<100xf32, 3>) ``` Explanation for `Attribute` choice instead of plain `string`: * `Attribute` classes allow to use more type safe API based on RTTI. * `Attribute` classes provides faster comparison operator based on pointer comparison in contrast to generic string comparison. * `Attribute` allows to store more complex things, like structs or dictionaries. It will allows to have more complex memory space hierarchy. This commit preserve old integer-based API and implements it on top of the new one. Depends on D97476 Reviewed By: rriddle, mehdi_amini Differential Revision: https://reviews.llvm.org/D96145 --- mlir/include/mlir-c/BuiltinTypes.h | 25 ++++++++++--------- mlir/lib/Bindings/Python/IRModules.cpp | 28 ++++++++++++++------- mlir/lib/CAPI/IR/BuiltinTypes.cpp | 34 ++++++++++++++------------ 3 files changed, 50 insertions(+), 37 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index a706c58ef..b2ec37c9d 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -224,38 +224,38 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type); /// same context as element type. The type is owned by the context. 
MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet( MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps, - MlirAffineMap const *affineMaps, unsigned memorySpace); + MlirAffineMap const *affineMaps, MlirAttribute memorySpace); /// Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType o /// illegal arguments, emitting appropriate diagnostics. MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked( MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, - intptr_t numMaps, MlirAffineMap const *affineMaps, unsigned memorySpace); + intptr_t numMaps, MlirAffineMap const *affineMaps, + MlirAttribute memorySpace); /// Creates a MemRef type with the given rank, shape, memory space and element /// type in the same context as the element type. The type has no affine maps, /// i.e. represents a default row-major contiguous memref. The type is owned by /// the context. -MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeContiguousGet(MlirType elementType, - intptr_t rank, - const int64_t *shape, - unsigned memorySpace); +MLIR_CAPI_EXPORTED MlirType +mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank, + const int64_t *shape, MlirAttribute memorySpace); /// Same as "mlirMemRefTypeContiguousGet" but returns a nullptr wrapping /// MlirType on illegal arguments, emitting appropriate diagnostics. MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeContiguousGetChecked( MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, - unsigned memorySpace); + MlirAttribute memorySpace); /// Creates an Unranked MemRef type with the given element type and in the given /// memory space. The type is owned by the context of element type. 
-MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, - unsigned memorySpace); +MLIR_CAPI_EXPORTED MlirType +mlirUnrankedMemRefTypeGet(MlirType elementType, MlirAttribute memorySpace); /// Same as "mlirUnrankedMemRefTypeGet" but returns a nullptr wrapping /// MlirType on illegal arguments, emitting appropriate diagnostics. MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGetChecked( - MlirLocation loc, MlirType elementType, unsigned memorySpace); + MlirLocation loc, MlirType elementType, MlirAttribute memorySpace); /// Returns the number of affine layout maps in the given MemRef type. MLIR_CAPI_EXPORTED intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type); @@ -265,10 +265,11 @@ MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos); /// Returns the memory space of the given MemRef type. -MLIR_CAPI_EXPORTED unsigned mlirMemRefTypeGetMemorySpace(MlirType type); +MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type); /// Returns the memory spcae of the given Unranked MemRef type. -MLIR_CAPI_EXPORTED unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type); +MLIR_CAPI_EXPORTED MlirAttribute +mlirUnrankedMemrefGetMemorySpace(MlirType type); //===----------------------------------------------------------------------===// // Tuple type. 
diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp index 9152fd06d..a544e52c2 100644 --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -2861,16 +2861,20 @@ class PyMemRefType : public PyConcreteType { c.def_static( "get", [](std::vector shape, PyType &elementType, - std::vector layout, unsigned memorySpace, + std::vector layout, PyAttribute *memorySpace, DefaultingPyLocation loc) { SmallVector maps; maps.reserve(layout.size()); for (PyAffineMap &map : layout) maps.push_back(map); + MlirAttribute memSpaceAttr = {}; + if (memorySpace) + memSpaceAttr = *memorySpace; + MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(), shape.data(), maps.size(), - maps.data(), memorySpace); + maps.data(), memSpaceAttr); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirTypeIsNull(t)) { @@ -2885,14 +2889,15 @@ class PyMemRefType : public PyConcreteType { return PyMemRefType(elementType.getContext(), t); }, py::arg("shape"), py::arg("element_type"), - py::arg("layout") = py::list(), py::arg("memory_space") = 0, + py::arg("layout") = py::list(), py::arg("memory_space") = py::none(), py::arg("loc") = py::none(), "Create a memref type") .def_property_readonly("layout", &PyMemRefType::getLayout, "The list of layout maps of the MemRef type.") .def_property_readonly( "memory_space", - [](PyMemRefType &self) -> unsigned { - return mlirMemRefTypeGetMemorySpace(self); + [](PyMemRefType &self) -> PyAttribute { + MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); + return PyAttribute(self.getContext(), a); }, "Returns the memory space of the given MemRef type."); } @@ -2944,10 +2949,14 @@ class PyUnrankedMemRefType static void bindDerived(ClassTy &c) { c.def_static( "get", - [](PyType &elementType, unsigned memorySpace, + [](PyType &elementType, PyAttribute *memorySpace, DefaultingPyLocation loc) { + MlirAttribute memSpaceAttr = {}; + if (memorySpace) + 
memSpaceAttr = *memorySpace; + MlirType t = - mlirUnrankedMemRefTypeGetChecked(loc, elementType, memorySpace); + mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. if (mlirTypeIsNull(t)) { @@ -2965,8 +2974,9 @@ class PyUnrankedMemRefType py::arg("loc") = py::none(), "Create a unranked memref type") .def_property_readonly( "memory_space", - [](PyUnrankedMemRefType &self) -> unsigned { - return mlirUnrankedMemrefGetMemorySpace(self); + [](PyUnrankedMemRefType &self) -> PyAttribute { + MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); + return PyAttribute(self.getContext(), a); }, "Returns the memory space of the given Unranked MemRef type."); } diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index e4442ac4c..c84ced177 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -223,41 +223,41 @@ bool mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa(); } MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps, MlirAffineMap const *affineMaps, - unsigned memorySpace) { + MlirAttribute memorySpace) { SmallVector maps; (void)unwrapList(numMaps, affineMaps, maps); return wrap( MemRefType::get(llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), maps, memorySpace)); + unwrap(elementType), maps, unwrap(memorySpace))); } MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps, MlirAffineMap const *affineMaps, - unsigned memorySpace) { + MlirAttribute memorySpace) { SmallVector maps; (void)unwrapList(numMaps, affineMaps, maps); return wrap(MemRefType::getChecked( unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), maps, memorySpace)); + unwrap(elementType), maps, unwrap(memorySpace))); } MlirType mlirMemRefTypeContiguousGet(MlirType 
elementType, intptr_t rank, const int64_t *shape, - unsigned memorySpace) { + MlirAttribute memorySpace) { return wrap( MemRefType::get(llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), llvm::None, memorySpace)); + unwrap(elementType), llvm::None, unwrap(memorySpace))); } MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, - unsigned memorySpace) { + MlirAttribute memorySpace) { return wrap(MemRefType::getChecked( unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), llvm::None, memorySpace)); + unwrap(elementType), llvm::None, unwrap(memorySpace))); } intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) { @@ -269,27 +269,29 @@ MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos) { return wrap(unwrap(type).cast().getAffineMaps()[pos]); } -unsigned mlirMemRefTypeGetMemorySpace(MlirType type) { - return unwrap(type).cast().getMemorySpaceAsInt(); +MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) { + return wrap(unwrap(type).cast().getMemorySpace()); } bool mlirTypeIsAUnrankedMemRef(MlirType type) { return unwrap(type).isa(); } -MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, unsigned memorySpace) { - return wrap(UnrankedMemRefType::get(unwrap(elementType), memorySpace)); +MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, + MlirAttribute memorySpace) { + return wrap( + UnrankedMemRefType::get(unwrap(elementType), unwrap(memorySpace))); } MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, - unsigned memorySpace) { + MlirAttribute memorySpace) { return wrap(UnrankedMemRefType::getChecked(unwrap(loc), unwrap(elementType), - memorySpace)); + unwrap(memorySpace))); } -unsigned mlirUnrankedMemrefGetMemorySpace(MlirType type) { - return unwrap(type).cast().getMemorySpaceAsInt(); +MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) { + return 
wrap(unwrap(type).cast().getMemorySpace()); } //===----------------------------------------------------------------------===// From 9569fe481edaee279bc05d98f0b9b12a6efe0024 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 15 Mar 2021 14:06:25 +0100 Subject: [PATCH 002/915] [mlir] enable Python bindings for the MemRef dialect A previous commit moved multiple ops from Standard to MemRef dialect. Some of these ops are exercised in Python bindings. Enable bindings for the newly created MemRef dialect and update a test accordingly. --- mlir/lib/Bindings/Python/CMakeLists.txt | 5 +++++ mlir/lib/Bindings/Python/MemRefOps.td | 15 +++++++++++++++ mlir/lib/Bindings/Python/mlir/dialects/memref.py | 5 +++++ 3 files changed, 25 insertions(+) create mode 100644 mlir/lib/Bindings/Python/MemRefOps.td create mode 100644 mlir/lib/Bindings/Python/mlir/dialects/memref.py diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt index c444ddcc4..5f042ec57 100644 --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/CMakeLists.txt @@ -42,6 +42,11 @@ add_mlir_dialect_python_bindings(MLIRBindingsPythonLinalgOps DEPENDS LinalgOdsGen) add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonLinalgOps) +add_mlir_dialect_python_bindings(MLIRBindingsPythonMemRefOps + TD_FILE MemRefOps.td + DIALECT_NAME memref) +add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonMemRefOps) + add_mlir_dialect_python_bindings(MLIRBindingsPythonShapeOps TD_FILE ShapeOps.td DIALECT_NAME shape) diff --git a/mlir/lib/Bindings/Python/MemRefOps.td b/mlir/lib/Bindings/Python/MemRefOps.td new file mode 100644 index 000000000..8dd976479 --- /dev/null +++ b/mlir/lib/Bindings/Python/MemRefOps.td @@ -0,0 +1,15 @@ +//===-- MemRefOps.td - Entry point for MemRefOps bind ------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_MEMREF_OPS +#define PYTHON_BINDINGS_MEMREF_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/MemRef/IR/MemRefOps.td" + +#endif diff --git a/mlir/lib/Bindings/Python/mlir/dialects/memref.py b/mlir/lib/Bindings/Python/mlir/dialects/memref.py new file mode 100644 index 000000000..3afb6a70c --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/dialects/memref.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._memref_ops_gen import * From f62c6f81052e01e39976cb20de6526fa61729d63 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 16 Mar 2021 18:05:19 -0700 Subject: [PATCH 003/915] [mlir][linalg] Add structured op builders from python opdsl. * Makes the wrapped functions of the `@linalg_structured_op` decorator callable such that they emit IR imperatively when invoked. * There are numerous TODOs that I will keep working through to achieve generality. * Will true up exception handling tests as the feature progresses (for things that are actually errors once everything is implemented). * Includes the addition of an `isinstance` method on concrete types in the Python API. 
Differential Revision: https://reviews.llvm.org/D98754 --- mlir/lib/Bindings/Python/IRModules.cpp | 3 + .../mlir/dialects/linalg/opdsl/lang/config.py | 17 +- .../mlir/dialects/linalg/opdsl/lang/dsl.py | 33 ++- .../dialects/linalg/opdsl/lang/emitter.py | 252 ++++++++++++++++++ 4 files changed, 294 insertions(+), 11 deletions(-) create mode 100644 mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp index a544e52c2..6b4e5434d 100644 --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -2477,6 +2477,9 @@ class PyConcreteType : public BaseTy { static void bind(py::module &m) { auto cls = ClassTy(m, DerivedTy::pyClassName); cls.def(py::init(), py::keep_alive<0, 1>()); + cls.def_static("isinstance", [](PyType &otherType) -> bool { + return DerivedTy::isaFunction(otherType); + }); DerivedTy::bindDerived(cls); } diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py index 115ea4061..fdc6cfd9b 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py @@ -21,6 +21,7 @@ __all__ = [ "LinalgStructuredOpConfig", "LinalgOpConfig", + "TensorDefConfig", ] @@ -51,17 +52,17 @@ def __init__(self, tensor_def: TensorDef, shape_map: _ir.AffineMap): self.shape_map = shape_map self.indexing_map = None # type: Optional[_ir.AffineMap] - def to_yaml_custom_dict(self): - - def get_usage(): - if self.tensor_def.output: - return "output" - else: - return "input" + @property + def usage(self) -> str: + if self.tensor_def.output: + return "output" + else: + return "input" + def to_yaml_custom_dict(self): return dict( name=self.tensor_def.tensor_name, - usage=get_usage(), + usage=self.usage, shape=_serialize_affine_map(self.shape_map), 
element_type_var=self.tensor_def.type_var.name, ) diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py index d367c5bdd..cbff41db2 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -11,6 +11,8 @@ from mlir import ir from .comprehension import * +from .config import * +from .emitter import * _CONTEXT = threading.local() @@ -42,9 +44,34 @@ def __init__(self, op_name: str, model: LinalgOpDef): self.op_name = op_name self.model = model - def __call__(self, *args, **kwargs): - # TODO: Upstream the emitter and invoke here - raise NotImplementedError("Linalg generic emission not yet implemented") + def __call__(self, *args, emit_generic: bool = True, **kwargs): + """Emits the corresponding op definition as IR. + + Most arguments are passed through to the underlying emitter. The following + are interpreted here: + emit_generic: Emits a generic form as appropriate (default True). If + False, a named form is emitted (which must have been built in to the + compiler). + """ + op_configs = LinalgOpConfig.from_linalg_op_def(self.model, + context=ir.Context.current) + + if len(op_configs) != 1: + # TODO: Support composite ops. 
+ raise NotImplementedError( + f"Emission of composite linalg ops not supported: {op_configs}") + + op_config = op_configs[0] + if op_config.structured_op: + if emit_generic: + return emit_generic_structured_op(op_config.structured_op, *args, + **kwargs) + else: + return emit_named_structured_op(op_config.structured_op, *args, + **kwargs) + + raise NotImplementedError( + f"Emission of linalg op type not supported: {op_config}") def linalg_structured_op(dsl_func=None, diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py new file mode 100644 index 000000000..9a18993e9 --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -0,0 +1,252 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Dict, Sequence + +from mlir.ir import * +from mlir.dialects import linalg +from mlir.dialects import std + +from .scalar_expr import * +from .config import * + +__all__ = [ + "emit_generic_structured_op", + "emit_named_structured_op", +] + + +def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, + *ins: Value, + outs: Value = ()): + all_arg_defs = op_config.ordered_tensor_args + in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "input"] + out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "output"] + + # Arity validation. + if len(ins) != len(in_arg_defs): + raise ValueError(f"Expected {len(in_arg_defs)} inputs but got " + f"{len(ins)} for {op_config}") + if outs and len(outs) != len(out_arg_defs): + raise ValueError(f"Expected {len(out_arg_defs)} outputs but got " + f"{len(outs)} for {op_config}") + + outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins, + out_arg_defs, outs) + + # Extract type vars for input/output based types. 
+ type_mapping = dict() # type: Dict[str, Type] + for arg_def, arg_element_type in zip( + in_arg_defs + out_arg_defs, + _get_shaped_element_types_from_values(*ins, *outs)): + tv_name = arg_def.tensor_def.type_var.name + type_mapping[tv_name] = arg_element_type + + # Emit the generic op. + # TODO: Support emission of pure memref form. + indexing_maps_attr = ArrayAttr.get( + [AffineMapAttr.get(am) for am in op_config.indexing_maps]) + iterator_types_attr = ArrayAttr.get( + [StringAttr.get(s) for s in op_config.iterator_types]) + generic_op = linalg.GenericOp( + result_tensors=out_types, + inputs=ins, + outputs=outs, + indexing_maps=indexing_maps_attr, + iterator_types=iterator_types_attr, + doc=None, # TODO: Make optional. + library_call=None, # TODO: Make optional. + sparse=BoolAttr.get(False)) # TODO: Make optional. + + # Construct the body. + block_arg_names = _get_tensor_def_names(*in_arg_defs, *out_arg_defs) + block_arg_types = _get_shaped_element_types_from_values(*ins, *outs) + block = generic_op.regions[0].blocks.append(*block_arg_types) + block_arg_mapping = dict(zip(block_arg_names, block.arguments)) + with InsertionPoint(block): + body_builder = _BodyBuilder(type_mapping, block_arg_mapping) + for assignment in op_config.assignments: + body_builder.assign(assignment) + body_builder.yield_outputs(*_get_tensor_def_names(*out_arg_defs)) + + if len(out_arg_defs) == 1: + return generic_op.result + else: + return generic_op.results + + +def emit_named_structured_op(op_config: LinalgStructuredOpConfig, + *ins: Value, + outs: Value = ()): + raise NotImplementedError( + f"Emission of named structured ops is not supported: {op_config}") + + +class _BodyBuilder: + """Constructs a structured op body by evaluating assignments.""" + + def __init__(self, type_mapping: Dict[str, Type], + block_arg_mapping: Dict[str, Value]): + self.type_mapping = type_mapping + self.block_arg_mapping = block_arg_mapping + self.yield_mapping = dict() # type: Dict[str, Value] + + def 
assign(self, assignment: ScalarAssign): + if assignment.arg in self.yield_mapping: + raise ValueError( + f"Multiple assignments to the same argument are forbidden: " + f"{assignment}") + self.yield_mapping[assignment.arg] = self.expression(assignment.value) + + def expression(self, expr: ScalarExpression) -> Value: + if expr.scalar_arg: + try: + return self.block_arg_mapping[expr.scalar_arg.arg] + except KeyError: + raise ValueError(f"Argument {expr.scalar_arg.arg} is not bound for " + f"this structured op.") + elif expr.scalar_apply: + try: + fn = getattr(self, f"_eval_{expr.scalar_apply.fn_name}") + except AttributeError: + raise ValueError( + f"Function '{expr.scalar_apply.fn_name}' is not a known " + "scalar body function") + operand_values = [ + self.expression(operand) for operand in expr.scalar_apply.operands + ] + return fn(*operand_values) + elif expr.symbolic_cast: + operand_value = self.expression(expr.symbolic_cast.operand) + return self.cast(expr.symbolic_cast.to_type.name, operand_value) + raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") + + def cast(self, type_var_name: str, operand: Value) -> Value: + try: + to_type = self.type_mapping[type_var_name] + except KeyError: + raise ValueError(f"Unbound type variable '{type_var_name}' (" + f"expected one of {self.type_mappings.keys()}") + if operand.type == to_type: + return operand + if _is_integer_type(to_type): + return self._cast_to_integer(to_type, operand) + elif _is_floating_point_type(to_type): + return self._cast_to_floating_point(to_type, operand) + + raise ValueError(f"Unable to cast body expression from {operand.type} to " + f"{to_type}") + + def _cast_to_integer(self, to_type: Type, operand: Value) -> Value: + to_width = IntegerType(to_type).width + operand_type = operand.type + if _is_floating_point_type(operand_type): + return std.FPToSIOp(to_type, operand).result + # Assume integer. 
+ from_width = IntegerType(operand_type).width + if to_width > from_width: + return std.SignExtendIOp(to_type, operand).result + elif to_width < from_width: + return std.TruncateIOp(to_type, operand).result + raise ValueError(f"Unable to cast body expression from {operand_type} to " + f"{to_type}") + + def _cast_to_floating_point(self, to_type: Type, operand: Value) -> Value: + operand_type = operand.type + if _is_integer_type(operand_type): + return std.SIToFPOp(to_type, operand).result + # Assume FloatType. + to_width = _get_floating_point_width(to_type) + from_width = _get_floating_point_width(operand_type) + if to_width > from_width: + return std.FPExtOp(to_type, operand).result + elif to_width < from_width: + return std.FPTruncOp(to_type, operand).result + raise ValueError(f"Unable to cast body expression from {operand_type} to " + f"{to_type}") + + def yield_outputs(self, *output_names: str): + output_values = [] + for n in output_names: + try: + output_values.append(self.yield_mapping[n]) + except KeyError: + raise ValueError(f"Body assignments do not assign all outputs: " + f"missing '{n}'") + linalg.YieldOp(output_values) + + def _eval_add(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return std.AddFOp(lhs.type, lhs, rhs).result + if _is_integer_type(lhs.type): + return std.AddIOp(lhs.type, lhs, rhs).result + raise NotImplementedError("Unsupported 'add' operand: {lhs}") + + def _eval_mul(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return std.MulFOp(lhs.type, lhs, rhs).result + if _is_integer_type(lhs.type): + return std.MulIOp(lhs.type, lhs, rhs).result + raise NotImplementedError("Unsupported 'mul' operand: {lhs}") + + +def _infer_structured_outs(op_config: LinalgStructuredOpConfig, + in_arg_defs: Sequence[TensorDefConfig], + ins: Sequence[Value], + out_arg_defs: Sequence[TensorDefConfig], + outs: Sequence[Value]): + """Infers implicit outs and output types. 
+ + Respects existing contents of outs if not empty. + + Returns: + normalized outs, output types + """ + # If outs were explicitly provided, we accept them verbatim. + if outs: + return outs, [out.type for out in outs] + + raise NotImplementedError(f"Output tensor inference not yet supported for " + "structured ops") + + +def _get_shaped_element_types_from_values(*values: Value) -> Sequence[Type]: + types = [] + for v in values: + try: + t = ShapedType(v.type) + except Exception as e: + raise ValueError(f"Expected ShapedType but got {v}") from e + types.append(t.element_type) + return types + + +def _get_tensor_def_names( + *tensor_def_configs: TensorDefConfig) -> Sequence[str]: + return [tdc.tensor_def.tensor_name for tdc in tensor_def_configs] + + +def _is_floating_point_type(t: Type) -> bool: + # TODO: Create a FloatType in the Python API and implement the switch + # there. + return (F64Type.isinstance(t) or F32Type.isinstance(t) or + F16Type.isinstance(t) or BF16Type.isinstance(t)) + + +def _is_integer_type(t: Type) -> bool: + return IntegerType.isinstance(t) + + +def _get_floating_point_width(t: Type) -> int: + # TODO: Create a FloatType in the Python API and implement the switch + # there. + if F64Type.isinstance(t): + return 64 + if F32Type.isinstance(t): + return 32 + if F16Type.isinstance(t): + return 16 + if BF16Type.isinstance(t): + return 16 + raise NotImplementedError(f"Unhandled floating point type switch {t}") From 1dc6ab5bb9820a07de078e7186ccb1d84d20c1fe Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 19 Mar 2021 11:57:01 -0700 Subject: [PATCH 004/915] NFC: Break up the mlir python bindings into individual sources. * IRModules.cpp -> (IRCore.cpp, IRAffine.cpp, IRAttributes.cpp, IRTypes.cpp). * The individual pieces now compile in the 5-15s range whereas IRModules.cpp was starting to approach a minute (didn't capture a before time). * More fine grained splitting is possible, but this represents the most obvious. 
Differential Revision: https://reviews.llvm.org/D98978 --- mlir/lib/Bindings/Python/CMakeLists.txt | 5 +- mlir/lib/Bindings/Python/ExecutionEngine.cpp | 2 +- mlir/lib/Bindings/Python/IRAffine.cpp | 781 ++++++ mlir/lib/Bindings/Python/IRAttributes.cpp | 761 +++++ .../Python/{IRModules.cpp => IRCore.cpp} | 2471 ++--------------- .../Python/{IRModules.h => IRModule.h} | 5 +- mlir/lib/Bindings/Python/IRTypes.cpp | 678 +++++ mlir/lib/Bindings/Python/MainModule.cpp | 7 +- mlir/lib/Bindings/Python/Pass.cpp | 2 +- 9 files changed, 2394 insertions(+), 2318 deletions(-) create mode 100644 mlir/lib/Bindings/Python/IRAffine.cpp create mode 100644 mlir/lib/Bindings/Python/IRAttributes.cpp rename mlir/lib/Bindings/Python/{IRModules.cpp => IRCore.cpp} (52%) rename mlir/lib/Bindings/Python/{IRModules.h => IRModule.h} (99%) create mode 100644 mlir/lib/Bindings/Python/IRTypes.cpp diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt index 5f042ec57..5fefa8039 100644 --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/CMakeLists.txt @@ -70,7 +70,10 @@ add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir python SOURCES MainModule.cpp - IRModules.cpp + IRAffine.cpp + IRAttributes.cpp + IRCore.cpp + IRTypes.cpp PybindUtils.cpp Pass.cpp ExecutionEngine.cpp diff --git a/mlir/lib/Bindings/Python/ExecutionEngine.cpp b/mlir/lib/Bindings/Python/ExecutionEngine.cpp index f6f52e2e0..5ca9b1f68 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngine.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngine.cpp @@ -8,7 +8,7 @@ #include "ExecutionEngine.h" -#include "IRModules.h" +#include "IRModule.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/ExecutionEngine.h" diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp new file mode 100644 index 000000000..73a57d95e --- /dev/null +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -0,0 +1,781 @@ +//===- IRAffine.cpp - Exports 'ir' 
module affine related bindings ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "IRModule.h" + +#include "PybindUtils.h" + +#include "mlir-c/AffineMap.h" +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/IntegerSet.h" + +namespace py = pybind11; +using namespace mlir; +using namespace mlir::python; + +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +static const char kDumpDocstring[] = + R"(Dumps a debug representation of the object to stderr.)"; + +/// Attempts to populate `result` with the content of `list` casted to the +/// appropriate type (Python and C types are provided as template arguments). +/// Throws errors in case of failure, using "action" to describe what the caller +/// was attempting to do. +template +static void pyListToVector(py::list list, llvm::SmallVectorImpl &result, + StringRef action) { + result.reserve(py::len(list)); + for (py::handle item : list) { + try { + result.push_back(item.cast()); + } catch (py::cast_error &err) { + std::string msg = (llvm::Twine("Invalid expression when ") + action + + " (" + err.what() + ")") + .str(); + throw py::cast_error(msg); + } catch (py::reference_cast_error &err) { + std::string msg = (llvm::Twine("Invalid expression (None?) 
when ") + + action + " (" + err.what() + ")") + .str(); + throw py::cast_error(msg); + } + } +} + +template +static bool isPermutation(std::vector permutation) { + llvm::SmallVector seen(permutation.size(), false); + for (auto val : permutation) { + if (val < permutation.size()) { + if (seen[val]) + return false; + seen[val] = true; + continue; + } + return false; + } + return true; +} + +namespace { + +/// CRTP base class for Python MLIR affine expressions that subclass AffineExpr +/// and should be castable from it. Intermediate hierarchy classes can be +/// modeled by specifying BaseTy. +template +class PyConcreteAffineExpr : public BaseTy { +public: + // Derived classes must define statics for: + // IsAFunctionTy isaFunction + // const char *pyClassName + // and redefine bindDerived. + using ClassTy = py::class_; + using IsAFunctionTy = bool (*)(MlirAffineExpr); + + PyConcreteAffineExpr() = default; + PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) + : BaseTy(std::move(contextRef), affineExpr) {} + PyConcreteAffineExpr(PyAffineExpr &orig) + : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {} + + static MlirAffineExpr castFrom(PyAffineExpr &orig) { + if (!DerivedTy::isaFunction(orig)) { + auto origRepr = py::repr(py::cast(orig)).cast(); + throw SetPyError(PyExc_ValueError, + Twine("Cannot cast affine expression to ") + + DerivedTy::pyClassName + " (from " + origRepr + ")"); + } + return orig; + } + + static void bind(py::module &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName); + cls.def(py::init()); + DerivedTy::bindDerived(cls); + } + + /// Implemented by derived classes to add methods to the Python subclass. 
+ static void bindDerived(ClassTy &m) {} +}; + +class PyAffineConstantExpr : public PyConcreteAffineExpr { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant; + static constexpr const char *pyClassName = "AffineConstantExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineConstantExpr get(intptr_t value, + DefaultingPyMlirContext context) { + MlirAffineExpr affineExpr = + mlirAffineConstantExprGet(context->get(), static_cast(value)); + return PyAffineConstantExpr(context->getRef(), affineExpr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"), + py::arg("context") = py::none()); + c.def_property_readonly("value", [](PyAffineConstantExpr &self) { + return mlirAffineConstantExprGetValue(self); + }); + } +}; + +class PyAffineDimExpr : public PyConcreteAffineExpr { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim; + static constexpr const char *pyClassName = "AffineDimExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) { + MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos); + return PyAffineDimExpr(context->getRef(), affineExpr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineDimExpr::get, py::arg("position"), + py::arg("context") = py::none()); + c.def_property_readonly("position", [](PyAffineDimExpr &self) { + return mlirAffineDimExprGetPosition(self); + }); + } +}; + +class PyAffineSymbolExpr : public PyConcreteAffineExpr { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol; + static constexpr const char *pyClassName = "AffineSymbolExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) { + MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos); + return 
PyAffineSymbolExpr(context->getRef(), affineExpr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"), + py::arg("context") = py::none()); + c.def_property_readonly("position", [](PyAffineSymbolExpr &self) { + return mlirAffineSymbolExprGetPosition(self); + }); + } +}; + +class PyAffineBinaryExpr : public PyConcreteAffineExpr { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary; + static constexpr const char *pyClassName = "AffineBinaryExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + PyAffineExpr lhs() { + MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get()); + return PyAffineExpr(getContext(), lhsExpr); + } + + PyAffineExpr rhs() { + MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get()); + return PyAffineExpr(getContext(), rhsExpr); + } + + static void bindDerived(ClassTy &c) { + c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs); + c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs); + } +}; + +class PyAffineAddExpr + : public PyConcreteAffineExpr { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd; + static constexpr const char *pyClassName = "AffineAddExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs); + return PyAffineAddExpr(lhs.getContext(), expr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineAddExpr::get); + } +}; + +class PyAffineMulExpr + : public PyConcreteAffineExpr { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul; + static constexpr const char *pyClassName = "AffineMulExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs); + return PyAffineMulExpr(lhs.getContext(), expr); + } + + 
static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineMulExpr::get); + } +}; + +class PyAffineModExpr + : public PyConcreteAffineExpr { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod; + static constexpr const char *pyClassName = "AffineModExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs); + return PyAffineModExpr(lhs.getContext(), expr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineModExpr::get); + } +}; + +class PyAffineFloorDivExpr + : public PyConcreteAffineExpr { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv; + static constexpr const char *pyClassName = "AffineFloorDivExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs); + return PyAffineFloorDivExpr(lhs.getContext(), expr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineFloorDivExpr::get); + } +}; + +class PyAffineCeilDivExpr + : public PyConcreteAffineExpr { +public: + static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv; + static constexpr const char *pyClassName = "AffineCeilDivExpr"; + using PyConcreteAffineExpr::PyConcreteAffineExpr; + + static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs); + return PyAffineCeilDivExpr(lhs.getContext(), expr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get", &PyAffineCeilDivExpr::get); + } +}; + +} // namespace + +bool PyAffineExpr::operator==(const PyAffineExpr &other) { + return mlirAffineExprEqual(affineExpr, other.affineExpr); +} + +py::object PyAffineExpr::getCapsule() { + return py::reinterpret_steal( + mlirPythonAffineExprToCapsule(*this)); +} + 
+PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) { + MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); + if (mlirAffineExprIsNull(rawAffineExpr)) + throw py::error_already_set(); + return PyAffineExpr( + PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)), + rawAffineExpr); +} + +//------------------------------------------------------------------------------ +// PyAffineMap and utilities. +//------------------------------------------------------------------------------ +namespace { + +/// A list of expressions contained in an affine map. Internally these are +/// stored as a consecutive array leading to inexpensive random access. Both +/// the map and the expression are owned by the context so we need not bother +/// with lifetime extension. +class PyAffineMapExprList + : public Sliceable { +public: + static constexpr const char *pyClassName = "AffineExprList"; + + PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? 
mlirAffineMapGetNumResults(map) : length, + step), + affineMap(map) {} + + intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); } + + PyAffineExpr getElement(intptr_t pos) { + return PyAffineExpr(affineMap.getContext(), + mlirAffineMapGetResult(affineMap, pos)); + } + + PyAffineMapExprList slice(intptr_t startIndex, intptr_t length, + intptr_t step) { + return PyAffineMapExprList(affineMap, startIndex, length, step); + } + +private: + PyAffineMap affineMap; +}; +} // end namespace + +bool PyAffineMap::operator==(const PyAffineMap &other) { + return mlirAffineMapEqual(affineMap, other.affineMap); +} + +py::object PyAffineMap::getCapsule() { + return py::reinterpret_steal(mlirPythonAffineMapToCapsule(*this)); +} + +PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) { + MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); + if (mlirAffineMapIsNull(rawAffineMap)) + throw py::error_already_set(); + return PyAffineMap( + PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)), + rawAffineMap); +} + +//------------------------------------------------------------------------------ +// PyIntegerSet and utilities. 
+//------------------------------------------------------------------------------ +namespace { + +class PyIntegerSetConstraint { +public: + PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {} + + PyAffineExpr getExpr() { + return PyAffineExpr(set.getContext(), + mlirIntegerSetGetConstraint(set, pos)); + } + + bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } + + static void bind(py::module &m) { + py::class_(m, "IntegerSetConstraint") + .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr) + .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq); + } + +private: + PyIntegerSet set; + intptr_t pos; +}; + +class PyIntegerSetConstraintList + : public Sliceable { +public: + static constexpr const char *pyClassName = "IntegerSetConstraintList"; + + PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirIntegerSetGetNumConstraints(set) : length, + step), + set(set) {} + + intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); } + + PyIntegerSetConstraint getElement(intptr_t pos) { + return PyIntegerSetConstraint(set, pos); + } + + PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length, + intptr_t step) { + return PyIntegerSetConstraintList(set, startIndex, length, step); + } + +private: + PyIntegerSet set; +}; +} // namespace + +bool PyIntegerSet::operator==(const PyIntegerSet &other) { + return mlirIntegerSetEqual(integerSet, other.integerSet); +} + +py::object PyIntegerSet::getCapsule() { + return py::reinterpret_steal( + mlirPythonIntegerSetToCapsule(*this)); +} + +PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) { + MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); + if (mlirIntegerSetIsNull(rawIntegerSet)) + throw py::error_already_set(); + return PyIntegerSet( + PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), 
+ rawIntegerSet); +} + +void mlir::python::populateIRAffine(py::module &m) { + //---------------------------------------------------------------------------- + // Mapping of PyAffineExpr and derived classes. + //---------------------------------------------------------------------------- + py::class_(m, "AffineExpr") + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyAffineExpr::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) + .def("__add__", + [](PyAffineExpr &self, PyAffineExpr &other) { + return PyAffineAddExpr::get(self, other); + }) + .def("__mul__", + [](PyAffineExpr &self, PyAffineExpr &other) { + return PyAffineMulExpr::get(self, other); + }) + .def("__mod__", + [](PyAffineExpr &self, PyAffineExpr &other) { + return PyAffineModExpr::get(self, other); + }) + .def("__sub__", + [](PyAffineExpr &self, PyAffineExpr &other) { + auto negOne = + PyAffineConstantExpr::get(-1, *self.getContext().get()); + return PyAffineAddExpr::get(self, + PyAffineMulExpr::get(negOne, other)); + }) + .def("__eq__", [](PyAffineExpr &self, + PyAffineExpr &other) { return self == other; }) + .def("__eq__", + [](PyAffineExpr &self, py::object &other) { return false; }) + .def("__str__", + [](PyAffineExpr &self) { + PyPrintAccumulator printAccum; + mlirAffineExprPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }) + .def("__repr__", + [](PyAffineExpr &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append("AffineExpr("); + mlirAffineExprPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }) + .def_property_readonly( + "context", + [](PyAffineExpr &self) { return self.getContext().getObject(); }) + .def_static( + "get_add", &PyAffineAddExpr::get, + "Gets an affine expression containing a sum of two expressions.") + .def_static( + "get_mul", &PyAffineMulExpr::get, + "Gets an affine expression containing a product 
of two expressions.") + .def_static("get_mod", &PyAffineModExpr::get, + "Gets an affine expression containing the modulo of dividing " + "one expression by another.") + .def_static("get_floor_div", &PyAffineFloorDivExpr::get, + "Gets an affine expression containing the rounded-down " + "result of dividing one expression by another.") + .def_static("get_ceil_div", &PyAffineCeilDivExpr::get, + "Gets an affine expression containing the rounded-up result " + "of dividing one expression by another.") + .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"), + py::arg("context") = py::none(), + "Gets a constant affine expression with the given value.") + .def_static( + "get_dim", &PyAffineDimExpr::get, py::arg("position"), + py::arg("context") = py::none(), + "Gets an affine expression of a dimension at the given position.") + .def_static( + "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"), + py::arg("context") = py::none(), + "Gets an affine expression of a symbol at the given position.") + .def( + "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); }, + kDumpDocstring); + PyAffineConstantExpr::bind(m); + PyAffineDimExpr::bind(m); + PyAffineSymbolExpr::bind(m); + PyAffineBinaryExpr::bind(m); + PyAffineAddExpr::bind(m); + PyAffineMulExpr::bind(m); + PyAffineModExpr::bind(m); + PyAffineFloorDivExpr::bind(m); + PyAffineCeilDivExpr::bind(m); + + //---------------------------------------------------------------------------- + // Mapping of PyAffineMap. 
+ //---------------------------------------------------------------------------- + py::class_(m, "AffineMap") + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyAffineMap::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) + .def("__eq__", + [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) + .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; }) + .def("__str__", + [](PyAffineMap &self) { + PyPrintAccumulator printAccum; + mlirAffineMapPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }) + .def("__repr__", + [](PyAffineMap &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append("AffineMap("); + mlirAffineMapPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }) + .def_property_readonly( + "context", + [](PyAffineMap &self) { return self.getContext().getObject(); }, + "Context that owns the Affine Map") + .def( + "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); }, + kDumpDocstring) + .def_static( + "get", + [](intptr_t dimCount, intptr_t symbolCount, py::list exprs, + DefaultingPyMlirContext context) { + SmallVector affineExprs; + pyListToVector( + exprs, affineExprs, "attempting to create an AffineMap"); + MlirAffineMap map = + mlirAffineMapGet(context->get(), dimCount, symbolCount, + affineExprs.size(), affineExprs.data()); + return PyAffineMap(context->getRef(), map); + }, + py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"), + py::arg("context") = py::none(), + "Gets a map with the given expressions as results.") + .def_static( + "get_constant", + [](intptr_t value, DefaultingPyMlirContext context) { + MlirAffineMap affineMap = + mlirAffineMapConstantGet(context->get(), value); + return PyAffineMap(context->getRef(), affineMap); + }, + py::arg("value"), py::arg("context") = py::none(), + "Gets an affine map with a single constant 
result") + .def_static( + "get_empty", + [](DefaultingPyMlirContext context) { + MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); + return PyAffineMap(context->getRef(), affineMap); + }, + py::arg("context") = py::none(), "Gets an empty affine map.") + .def_static( + "get_identity", + [](intptr_t nDims, DefaultingPyMlirContext context) { + MlirAffineMap affineMap = + mlirAffineMapMultiDimIdentityGet(context->get(), nDims); + return PyAffineMap(context->getRef(), affineMap); + }, + py::arg("n_dims"), py::arg("context") = py::none(), + "Gets an identity map with the given number of dimensions.") + .def_static( + "get_minor_identity", + [](intptr_t nDims, intptr_t nResults, + DefaultingPyMlirContext context) { + MlirAffineMap affineMap = + mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults); + return PyAffineMap(context->getRef(), affineMap); + }, + py::arg("n_dims"), py::arg("n_results"), + py::arg("context") = py::none(), + "Gets a minor identity map with the given number of dimensions and " + "results.") + .def_static( + "get_permutation", + [](std::vector permutation, + DefaultingPyMlirContext context) { + if (!isPermutation(permutation)) + throw py::cast_error("Invalid permutation when attempting to " + "create an AffineMap"); + MlirAffineMap affineMap = mlirAffineMapPermutationGet( + context->get(), permutation.size(), permutation.data()); + return PyAffineMap(context->getRef(), affineMap); + }, + py::arg("permutation"), py::arg("context") = py::none(), + "Gets an affine map that permutes its inputs.") + .def("get_submap", + [](PyAffineMap &self, std::vector &resultPos) { + intptr_t numResults = mlirAffineMapGetNumResults(self); + for (intptr_t pos : resultPos) { + if (pos < 0 || pos >= numResults) + throw py::value_error("result position out of bounds"); + } + MlirAffineMap affineMap = mlirAffineMapGetSubMap( + self, resultPos.size(), resultPos.data()); + return PyAffineMap(self.getContext(), affineMap); + }) + 
.def("get_major_submap", + [](PyAffineMap &self, intptr_t nResults) { + if (nResults >= mlirAffineMapGetNumResults(self)) + throw py::value_error("number of results out of bounds"); + MlirAffineMap affineMap = + mlirAffineMapGetMajorSubMap(self, nResults); + return PyAffineMap(self.getContext(), affineMap); + }) + .def("get_minor_submap", + [](PyAffineMap &self, intptr_t nResults) { + if (nResults >= mlirAffineMapGetNumResults(self)) + throw py::value_error("number of results out of bounds"); + MlirAffineMap affineMap = + mlirAffineMapGetMinorSubMap(self, nResults); + return PyAffineMap(self.getContext(), affineMap); + }) + .def_property_readonly( + "is_permutation", + [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) + .def_property_readonly("is_projected_permutation", + [](PyAffineMap &self) { + return mlirAffineMapIsProjectedPermutation(self); + }) + .def_property_readonly( + "n_dims", + [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); }) + .def_property_readonly( + "n_inputs", + [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); }) + .def_property_readonly( + "n_symbols", + [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); }) + .def_property_readonly("results", [](PyAffineMap &self) { + return PyAffineMapExprList(self); + }); + PyAffineMapExprList::bind(m); + + //---------------------------------------------------------------------------- + // Mapping of PyIntegerSet. 
+ //---------------------------------------------------------------------------- + py::class_(m, "IntegerSet") + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyIntegerSet::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) + .def("__eq__", [](PyIntegerSet &self, + PyIntegerSet &other) { return self == other; }) + .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; }) + .def("__str__", + [](PyIntegerSet &self) { + PyPrintAccumulator printAccum; + mlirIntegerSetPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }) + .def("__repr__", + [](PyIntegerSet &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append("IntegerSet("); + mlirIntegerSetPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }) + .def_property_readonly( + "context", + [](PyIntegerSet &self) { return self.getContext().getObject(); }) + .def( + "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, + kDumpDocstring) + .def_static( + "get", + [](intptr_t numDims, intptr_t numSymbols, py::list exprs, + std::vector eqFlags, DefaultingPyMlirContext context) { + if (exprs.size() != eqFlags.size()) + throw py::value_error( + "Expected the number of constraints to match " + "that of equality flags"); + if (exprs.empty()) + throw py::value_error("Expected non-empty list of constraints"); + + // Copy over to a SmallVector because std::vector has a + // specialization for booleans that packs data and does not + // expose a `bool *`. 
+ SmallVector flags(eqFlags.begin(), eqFlags.end()); + + SmallVector affineExprs; + pyListToVector(exprs, affineExprs, + "attempting to create an IntegerSet"); + MlirIntegerSet set = mlirIntegerSetGet( + context->get(), numDims, numSymbols, exprs.size(), + affineExprs.data(), flags.data()); + return PyIntegerSet(context->getRef(), set); + }, + py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"), + py::arg("eq_flags"), py::arg("context") = py::none()) + .def_static( + "get_empty", + [](intptr_t numDims, intptr_t numSymbols, + DefaultingPyMlirContext context) { + MlirIntegerSet set = + mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); + return PyIntegerSet(context->getRef(), set); + }, + py::arg("num_dims"), py::arg("num_symbols"), + py::arg("context") = py::none()) + .def("get_replaced", + [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, + intptr_t numResultDims, intptr_t numResultSymbols) { + if (static_cast(dimExprs.size()) != + mlirIntegerSetGetNumDims(self)) + throw py::value_error( + "Expected the number of dimension replacement expressions " + "to match that of dimensions"); + if (static_cast(symbolExprs.size()) != + mlirIntegerSetGetNumSymbols(self)) + throw py::value_error( + "Expected the number of symbol replacement expressions " + "to match that of symbols"); + + SmallVector dimAffineExprs, symbolAffineExprs; + pyListToVector( + dimExprs, dimAffineExprs, + "attempting to create an IntegerSet by replacing dimensions"); + pyListToVector( + symbolExprs, symbolAffineExprs, + "attempting to create an IntegerSet by replacing symbols"); + MlirIntegerSet set = mlirIntegerSetReplaceGet( + self, dimAffineExprs.data(), symbolAffineExprs.data(), + numResultDims, numResultSymbols); + return PyIntegerSet(self.getContext(), set); + }) + .def_property_readonly("is_canonical_empty", + [](PyIntegerSet &self) { + return mlirIntegerSetIsCanonicalEmpty(self); + }) + .def_property_readonly( + "n_dims", + [](PyIntegerSet &self) { return 
mlirIntegerSetGetNumDims(self); }) + .def_property_readonly( + "n_symbols", + [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) + .def_property_readonly( + "n_inputs", + [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) + .def_property_readonly("n_equalities", + [](PyIntegerSet &self) { + return mlirIntegerSetGetNumEqualities(self); + }) + .def_property_readonly("n_inequalities", + [](PyIntegerSet &self) { + return mlirIntegerSetGetNumInequalities(self); + }) + .def_property_readonly("constraints", [](PyIntegerSet &self) { + return PyIntegerSetConstraintList(self); + }); + PyIntegerSetConstraint::bind(m); + PyIntegerSetConstraintList::bind(m); +} diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp new file mode 100644 index 000000000..6f9206c1b --- /dev/null +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -0,0 +1,761 @@ +//===- IRAttributes.cpp - Exports builtin and standard attributes ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "IRModule.h" + +#include "PybindUtils.h" + +#include "mlir-c/BuiltinAttributes.h" +#include "mlir-c/BuiltinTypes.h" + +namespace py = pybind11; +using namespace mlir; +using namespace mlir::python; + +using llvm::SmallVector; +using llvm::StringRef; +using llvm::Twine; + +namespace { + +static MlirStringRef toMlirStringRef(const std::string &s) { + return mlirStringRefCreate(s.data(), s.size()); +} + +/// CRTP base classes for Python attributes that subclass Attribute and should +/// be castable from it (i.e. via something like StringAttr(attr)). +/// By default, attribute class hierarchies are one level deep (i.e. 
a +/// concrete attribute class extends PyAttribute); however, intermediate +/// python-visible base classes can be modeled by specifying a BaseTy. +template +class PyConcreteAttribute : public BaseTy { +public: + // Derived classes must define statics for: + // IsAFunctionTy isaFunction + // const char *pyClassName + using ClassTy = py::class_; + using IsAFunctionTy = bool (*)(MlirAttribute); + + PyConcreteAttribute() = default; + PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) + : BaseTy(std::move(contextRef), attr) {} + PyConcreteAttribute(PyAttribute &orig) + : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} + + static MlirAttribute castFrom(PyAttribute &orig) { + if (!DerivedTy::isaFunction(orig)) { + auto origRepr = py::repr(py::cast(orig)).cast(); + throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") + + DerivedTy::pyClassName + + " (from " + origRepr + ")"); + } + return orig; + } + + static void bind(py::module &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol()); + cls.def(py::init(), py::keep_alive<0, 1>()); + DerivedTy::bindDerived(cls); + } + + /// Implemented by derived classes to add methods to the Python subclass. 
+ static void bindDerived(ClassTy &m) {} +}; + +class PyAffineMapAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; + static constexpr const char *pyClassName = "AffineMapAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyAffineMap &affineMap) { + MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); + return PyAffineMapAttribute(affineMap.getContext(), attr); + }, + py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); + } +}; + +class PyArrayAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; + static constexpr const char *pyClassName = "ArrayAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + class PyArrayAttributeIterator { + public: + PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {} + + PyArrayAttributeIterator &dunderIter() { return *this; } + + PyAttribute dunderNext() { + if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) { + throw py::stop_iteration(); + } + return PyAttribute(attr.getContext(), + mlirArrayAttrGetElement(attr.get(), nextIndex++)); + } + + static void bind(py::module &m) { + py::class_(m, "ArrayAttributeIterator") + .def("__iter__", &PyArrayAttributeIterator::dunderIter) + .def("__next__", &PyArrayAttributeIterator::dunderNext); + } + + private: + PyAttribute attr; + int nextIndex = 0; + }; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](py::list attributes, DefaultingPyMlirContext context) { + SmallVector mlirAttributes; + mlirAttributes.reserve(py::len(attributes)); + for (auto attribute : attributes) { + try { + mlirAttributes.push_back(attribute.cast()); + } catch (py::cast_error &err) { + std::string msg = std::string("Invalid attribute when attempting " + "to create an ArrayAttribute (") + + err.what() + ")"; + throw py::cast_error(msg); + } catch 
(py::reference_cast_error &err) { + // This exception seems thrown when the value is "None". + std::string msg = + std::string("Invalid attribute (None?) when attempting to " + "create an ArrayAttribute (") + + err.what() + ")"; + throw py::cast_error(msg); + } + } + MlirAttribute attr = mlirArrayAttrGet( + context->get(), mlirAttributes.size(), mlirAttributes.data()); + return PyArrayAttribute(context->getRef(), attr); + }, + py::arg("attributes"), py::arg("context") = py::none(), + "Gets a uniqued Array attribute"); + c.def("__getitem__", + [](PyArrayAttribute &arr, intptr_t i) { + if (i >= mlirArrayAttrGetNumElements(arr)) + throw py::index_error("ArrayAttribute index out of range"); + return PyAttribute(arr.getContext(), + mlirArrayAttrGetElement(arr, i)); + }) + .def("__len__", + [](const PyArrayAttribute &arr) { + return mlirArrayAttrGetNumElements(arr); + }) + .def("__iter__", [](const PyArrayAttribute &arr) { + return PyArrayAttributeIterator(arr); + }); + } +}; + +/// Float Point Attribute subclass - FloatAttr. +class PyFloatAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; + static constexpr const char *pyClassName = "FloatAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &type, double value, DefaultingPyLocation loc) { + MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. 
+ if (mlirAttributeIsNull(attr)) { + throw SetPyError(PyExc_ValueError, + Twine("invalid '") + + py::repr(py::cast(type)).cast() + + "' and expected floating point type."); + } + return PyFloatAttribute(type.getContext(), attr); + }, + py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), + "Gets an uniqued float point attribute associated to a type"); + c.def_static( + "get_f32", + [](double value, DefaultingPyMlirContext context) { + MlirAttribute attr = mlirFloatAttrDoubleGet( + context->get(), mlirF32TypeGet(context->get()), value); + return PyFloatAttribute(context->getRef(), attr); + }, + py::arg("value"), py::arg("context") = py::none(), + "Gets an uniqued float point attribute associated to a f32 type"); + c.def_static( + "get_f64", + [](double value, DefaultingPyMlirContext context) { + MlirAttribute attr = mlirFloatAttrDoubleGet( + context->get(), mlirF64TypeGet(context->get()), value); + return PyFloatAttribute(context->getRef(), attr); + }, + py::arg("value"), py::arg("context") = py::none(), + "Gets an uniqued float point attribute associated to a f64 type"); + c.def_property_readonly( + "value", + [](PyFloatAttribute &self) { + return mlirFloatAttrGetValueDouble(self); + }, + "Returns the value of the float point attribute"); + } +}; + +/// Integer Attribute subclass - IntegerAttr. 
+class PyIntegerAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; + static constexpr const char *pyClassName = "IntegerAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &type, int64_t value) { + MlirAttribute attr = mlirIntegerAttrGet(type, value); + return PyIntegerAttribute(type.getContext(), attr); + }, + py::arg("type"), py::arg("value"), + "Gets an uniqued integer attribute associated to a type"); + c.def_property_readonly( + "value", + [](PyIntegerAttribute &self) { + return mlirIntegerAttrGetValueInt(self); + }, + "Returns the value of the integer attribute"); + } +}; + +/// Bool Attribute subclass - BoolAttr. +class PyBoolAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; + static constexpr const char *pyClassName = "BoolAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](bool value, DefaultingPyMlirContext context) { + MlirAttribute attr = mlirBoolAttrGet(context->get(), value); + return PyBoolAttribute(context->getRef(), attr); + }, + py::arg("value"), py::arg("context") = py::none(), + "Gets an uniqued bool attribute"); + c.def_property_readonly( + "value", + [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, + "Returns the value of the bool attribute"); + } +}; + +class PyFlatSymbolRefAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; + static constexpr const char *pyClassName = "FlatSymbolRefAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::string value, DefaultingPyMlirContext context) { + MlirAttribute attr = + mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); + 
return PyFlatSymbolRefAttribute(context->getRef(), attr); + }, + py::arg("value"), py::arg("context") = py::none(), + "Gets a uniqued FlatSymbolRef attribute"); + c.def_property_readonly( + "value", + [](PyFlatSymbolRefAttribute &self) { + MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); + return py::str(stringRef.data, stringRef.length); + }, + "Returns the value of the FlatSymbolRef attribute as a string"); + } +}; + +class PyStringAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; + static constexpr const char *pyClassName = "StringAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::string value, DefaultingPyMlirContext context) { + MlirAttribute attr = + mlirStringAttrGet(context->get(), toMlirStringRef(value)); + return PyStringAttribute(context->getRef(), attr); + }, + py::arg("value"), py::arg("context") = py::none(), + "Gets a uniqued string attribute"); + c.def_static( + "get_typed", + [](PyType &type, std::string value) { + MlirAttribute attr = + mlirStringAttrTypedGet(type, toMlirStringRef(value)); + return PyStringAttribute(type.getContext(), attr); + }, + + "Gets a uniqued string attribute associated to a type"); + c.def_property_readonly( + "value", + [](PyStringAttribute &self) { + MlirStringRef stringRef = mlirStringAttrGetValue(self); + return py::str(stringRef.data, stringRef.length); + }, + "Returns the value of the string attribute"); + } +}; + +// TODO: Support construction of bool elements. +// TODO: Support construction of string elements. 
+class PyDenseElementsAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; + static constexpr const char *pyClassName = "DenseElementsAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static PyDenseElementsAttribute + getFromBuffer(py::buffer array, bool signless, + DefaultingPyMlirContext contextWrapper) { + // Request a contiguous view. In exotic cases, this will cause a copy. + int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; + Py_buffer *view = new Py_buffer(); + if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { + delete view; + throw py::error_already_set(); + } + py::buffer_info arrayInfo(view); + + MlirContext context = contextWrapper->get(); + // Switch on the types that can be bulk loaded between the Python and + // MLIR-C APIs. + // See: https://docs.python.org/3/library/struct.html#format-characters + if (arrayInfo.format == "f") { + // f32 + assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); + return PyDenseElementsAttribute( + contextWrapper->getRef(), + bulkLoad(context, mlirDenseElementsAttrFloatGet, + mlirF32TypeGet(context), arrayInfo)); + } else if (arrayInfo.format == "d") { + // f64 + assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); + return PyDenseElementsAttribute( + contextWrapper->getRef(), + bulkLoad(context, mlirDenseElementsAttrDoubleGet, + mlirF64TypeGet(context), arrayInfo)); + } else if (isSignedIntegerFormat(arrayInfo.format)) { + if (arrayInfo.itemsize == 4) { + // i32 + MlirType elementType = signless ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeSignedGet(context, 32); + return PyDenseElementsAttribute(contextWrapper->getRef(), + bulkLoad(context, + mlirDenseElementsAttrInt32Get, + elementType, arrayInfo)); + } else if (arrayInfo.itemsize == 8) { + // i64 + MlirType elementType = signless ? 
mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeSignedGet(context, 64); + return PyDenseElementsAttribute(contextWrapper->getRef(), + bulkLoad(context, + mlirDenseElementsAttrInt64Get, + elementType, arrayInfo)); + } + } else if (isUnsignedIntegerFormat(arrayInfo.format)) { + if (arrayInfo.itemsize == 4) { + // unsigned i32 + MlirType elementType = signless + ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeUnsignedGet(context, 32); + return PyDenseElementsAttribute(contextWrapper->getRef(), + bulkLoad(context, + mlirDenseElementsAttrUInt32Get, + elementType, arrayInfo)); + } else if (arrayInfo.itemsize == 8) { + // unsigned i64 + MlirType elementType = signless + ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeUnsignedGet(context, 64); + return PyDenseElementsAttribute(contextWrapper->getRef(), + bulkLoad(context, + mlirDenseElementsAttrUInt64Get, + elementType, arrayInfo)); + } + } + + // TODO: Fall back to string-based get. + std::string message = "unimplemented array format conversion from format: "; + message.append(arrayInfo.format); + throw SetPyError(PyExc_ValueError, message); + } + + static PyDenseElementsAttribute getSplat(PyType shapedType, + PyAttribute &elementAttr) { + auto contextWrapper = + PyMlirContext::forContext(mlirTypeGetContext(shapedType)); + if (!mlirAttributeIsAInteger(elementAttr) && + !mlirAttributeIsAFloat(elementAttr)) { + std::string message = "Illegal element type for DenseElementsAttr: "; + message.append(py::repr(py::cast(elementAttr))); + throw SetPyError(PyExc_ValueError, message); + } + if (!mlirTypeIsAShaped(shapedType) || + !mlirShapedTypeHasStaticShape(shapedType)) { + std::string message = + "Expected a static ShapedType for the shaped_type parameter: "; + message.append(py::repr(py::cast(shapedType))); + throw SetPyError(PyExc_ValueError, message); + } + MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); + MlirType attrType = mlirAttributeGetType(elementAttr); + if 
(!mlirTypeEqual(shapedElementType, attrType)) { + std::string message = + "Shaped element type and attribute type must be equal: shaped="; + message.append(py::repr(py::cast(shapedType))); + message.append(", element="); + message.append(py::repr(py::cast(elementAttr))); + throw SetPyError(PyExc_ValueError, message); + } + + MlirAttribute elements = + mlirDenseElementsAttrSplatGet(shapedType, elementAttr); + return PyDenseElementsAttribute(contextWrapper->getRef(), elements); + } + + intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } + + py::buffer_info accessBuffer() { + MlirType shapedType = mlirAttributeGetType(*this); + MlirType elementType = mlirShapedTypeGetElementType(shapedType); + + if (mlirTypeIsAF32(elementType)) { + // f32 + return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue); + } else if (mlirTypeIsAF64(elementType)) { + // f64 + return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue); + } else if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 32) { + if (mlirIntegerTypeIsSignless(elementType) || + mlirIntegerTypeIsSigned(elementType)) { + // i32 + return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value); + } else if (mlirIntegerTypeIsUnsigned(elementType)) { + // unsigned i32 + return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value); + } + } else if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 64) { + if (mlirIntegerTypeIsSignless(elementType) || + mlirIntegerTypeIsSigned(elementType)) { + // i64 + return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value); + } else if (mlirIntegerTypeIsUnsigned(elementType)) { + // unsigned i64 + return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value); + } + } + + std::string message = "unimplemented array format."; + throw SetPyError(PyExc_ValueError, message); + } + + static void bindDerived(ClassTy &c) { + c.def("__len__", &PyDenseElementsAttribute::dunderLen) + 
.def_static("get", PyDenseElementsAttribute::getFromBuffer, + py::arg("array"), py::arg("signless") = true, + py::arg("context") = py::none(), + "Gets from a buffer or ndarray") + .def_static("get_splat", PyDenseElementsAttribute::getSplat, + py::arg("shaped_type"), py::arg("element_attr"), + "Gets a DenseElementsAttr where all values are the same") + .def_property_readonly("is_splat", + [](PyDenseElementsAttribute &self) -> bool { + return mlirDenseElementsAttrIsSplat(self); + }) + .def_buffer(&PyDenseElementsAttribute::accessBuffer); + } + +private: + template + static MlirAttribute + bulkLoad(MlirContext context, + MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *), + MlirType mlirElementType, py::buffer_info &arrayInfo) { + SmallVector shape(arrayInfo.shape.begin(), + arrayInfo.shape.begin() + arrayInfo.ndim); + auto shapedType = + mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType); + intptr_t numElements = arrayInfo.size; + const ElementTy *contents = static_cast(arrayInfo.ptr); + return ctor(shapedType, numElements, contents); + } + + static bool isUnsignedIntegerFormat(const std::string &format) { + if (format.empty()) + return false; + char code = format[0]; + return code == 'I' || code == 'B' || code == 'H' || code == 'L' || + code == 'Q'; + } + + static bool isSignedIntegerFormat(const std::string &format) { + if (format.empty()) + return false; + char code = format[0]; + return code == 'i' || code == 'b' || code == 'h' || code == 'l' || + code == 'q'; + } + + template + py::buffer_info bufferInfo(MlirType shapedType, + Type (*value)(MlirAttribute, intptr_t)) { + intptr_t rank = mlirShapedTypeGetRank(shapedType); + // Prepare the data for the buffer_info. + // Buffer is configured for read-only access below. + Type *data = static_cast( + const_cast(mlirDenseElementsAttrGetRawData(*this))); + // Prepare the shape for the buffer_info. 
+ SmallVector shape; + for (intptr_t i = 0; i < rank; ++i) + shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); + // Prepare the strides for the buffer_info. + SmallVector strides; + intptr_t strideFactor = 1; + for (intptr_t i = 1; i < rank; ++i) { + strideFactor = 1; + for (intptr_t j = i; j < rank; ++j) { + strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); + } + strides.push_back(sizeof(Type) * strideFactor); + } + strides.push_back(sizeof(Type)); + return py::buffer_info(data, sizeof(Type), + py::format_descriptor::format(), rank, shape, + strides, /*readonly=*/true); + } +}; // namespace + +/// Refinement of the PyDenseElementsAttribute for attributes containing integer +/// (and boolean) values. Supports element access. +class PyDenseIntElementsAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; + static constexpr const char *pyClassName = "DenseIntElementsAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + /// Returns the element at the given linear position. Asserts if the index is + /// out of range. + py::int_ dunderGetItem(intptr_t pos) { + if (pos < 0 || pos >= dunderLen()) { + throw SetPyError(PyExc_IndexError, + "attempt to access out of bounds element"); + } + + MlirType type = mlirAttributeGetType(*this); + type = mlirShapedTypeGetElementType(type); + assert(mlirTypeIsAInteger(type) && + "expected integer element type in dense int elements attribute"); + // Dispatch element extraction to an appropriate C function based on the + // elemental type of the attribute. py::int_ is implicitly constructible + // from any C++ integral type and handles bitwidth correctly. + // TODO: consider caching the type properties in the constructor to avoid + // querying them on each element access. 
+ unsigned width = mlirIntegerTypeGetWidth(type); + bool isUnsigned = mlirIntegerTypeIsUnsigned(type); + if (isUnsigned) { + if (width == 1) { + return mlirDenseElementsAttrGetBoolValue(*this, pos); + } + if (width == 32) { + return mlirDenseElementsAttrGetUInt32Value(*this, pos); + } + if (width == 64) { + return mlirDenseElementsAttrGetUInt64Value(*this, pos); + } + } else { + if (width == 1) { + return mlirDenseElementsAttrGetBoolValue(*this, pos); + } + if (width == 32) { + return mlirDenseElementsAttrGetInt32Value(*this, pos); + } + if (width == 64) { + return mlirDenseElementsAttrGetInt64Value(*this, pos); + } + } + throw SetPyError(PyExc_TypeError, "Unsupported integer type"); + } + + static void bindDerived(ClassTy &c) { + c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); + } +}; + +class PyDictAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; + static constexpr const char *pyClassName = "DictAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } + + static void bindDerived(ClassTy &c) { + c.def("__len__", &PyDictAttribute::dunderLen); + c.def_static( + "get", + [](py::dict attributes, DefaultingPyMlirContext context) { + SmallVector mlirNamedAttributes; + mlirNamedAttributes.reserve(attributes.size()); + for (auto &it : attributes) { + auto &mlir_attr = it.second.cast(); + auto name = it.first.cast(); + mlirNamedAttributes.push_back(mlirNamedAttributeGet( + mlirIdentifierGet(mlirAttributeGetContext(mlir_attr), + toMlirStringRef(name)), + mlir_attr)); + } + MlirAttribute attr = + mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), + mlirNamedAttributes.data()); + return PyDictAttribute(context->getRef(), attr); + }, + py::arg("value"), py::arg("context") = py::none(), + "Gets an uniqued dict attribute"); + c.def("__getitem__", [](PyDictAttribute &self, const std::string 
&name) { + MlirAttribute attr = + mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); + if (mlirAttributeIsNull(attr)) { + throw SetPyError(PyExc_KeyError, + "attempt to access a non-existent attribute"); + } + return PyAttribute(self.getContext(), attr); + }); + c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { + if (index < 0 || index >= self.dunderLen()) { + throw SetPyError(PyExc_IndexError, + "attempt to access out of bounds attribute"); + } + MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); + return PyNamedAttribute( + namedAttr.attribute, + std::string(mlirIdentifierStr(namedAttr.name).data)); + }); + } +}; + +/// Refinement of PyDenseElementsAttribute for attributes containing +/// floating-point values. Supports element access. +class PyDenseFPElementsAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; + static constexpr const char *pyClassName = "DenseFPElementsAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + py::float_ dunderGetItem(intptr_t pos) { + if (pos < 0 || pos >= dunderLen()) { + throw SetPyError(PyExc_IndexError, + "attempt to access out of bounds element"); + } + + MlirType type = mlirAttributeGetType(*this); + type = mlirShapedTypeGetElementType(type); + // Dispatch element extraction to an appropriate C function based on the + // elemental type of the attribute. py::float_ is implicitly constructible + // from float and double. + // TODO: consider caching the type properties in the constructor to avoid + // querying them on each element access. 
+ if (mlirTypeIsAF32(type)) { + return mlirDenseElementsAttrGetFloatValue(*this, pos); + } + if (mlirTypeIsAF64(type)) { + return mlirDenseElementsAttrGetDoubleValue(*this, pos); + } + throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); + } + + static void bindDerived(ClassTy &c) { + c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); + } +}; + +class PyTypeAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; + static constexpr const char *pyClassName = "TypeAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType value, DefaultingPyMlirContext context) { + MlirAttribute attr = mlirTypeAttrGet(value.get()); + return PyTypeAttribute(context->getRef(), attr); + }, + py::arg("value"), py::arg("context") = py::none(), + "Gets a uniqued Type attribute"); + c.def_property_readonly("value", [](PyTypeAttribute &self) { + return PyType(self.getContext()->getRef(), + mlirTypeAttrGetValue(self.get())); + }); + } +}; + +/// Unit Attribute subclass. Unit attributes don't have values. 
+class PyUnitAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; + static constexpr const char *pyClassName = "UnitAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + return PyUnitAttribute(context->getRef(), + mlirUnitAttrGet(context->get())); + }, + py::arg("context") = py::none(), "Create a Unit attribute."); + } +}; + +} // namespace + +void mlir::python::populateIRAttributes(py::module &m) { + PyAffineMapAttribute::bind(m); + PyArrayAttribute::bind(m); + PyArrayAttribute::PyArrayAttributeIterator::bind(m); + PyBoolAttribute::bind(m); + PyDenseElementsAttribute::bind(m); + PyDenseFPElementsAttribute::bind(m); + PyDenseIntElementsAttribute::bind(m); + PyDictAttribute::bind(m); + PyFlatSymbolRefAttribute::bind(m); + PyFloatAttribute::bind(m); + PyIntegerAttribute::bind(m); + PyStringAttribute::bind(m); + PyTypeAttribute::bind(m); + PyUnitAttribute::bind(m); +} diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRCore.cpp similarity index 52% rename from mlir/lib/Bindings/Python/IRModules.cpp rename to mlir/lib/Bindings/Python/IRCore.cpp index 6b4e5434d..9d87aa52f 100644 --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -6,16 +6,14 @@ // //===----------------------------------------------------------------------===// -#include "IRModules.h" +#include "IRModule.h" #include "Globals.h" #include "PybindUtils.h" -#include "mlir-c/AffineMap.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" -#include "mlir-c/IntegerSet.h" #include "mlir-c/Registration.h" #include "llvm/ADT/SmallVector.h" #include @@ -138,12 +136,6 @@ py::object classmethod(Func f, Args... 
args) { return py::reinterpret_borrow((PyClassMethod_New(cf.ptr()))); } -/// Checks whether the given type is an integer or float type. -static int mlirTypeIsAIntegerOrFloat(MlirType type) { - return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) || - mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type); -} - static py::object createCustomDialectWrapper(const std::string &dialectNamespace, py::object dialectDescriptor) { @@ -161,21 +153,6 @@ static MlirStringRef toMlirStringRef(const std::string &s) { return mlirStringRefCreate(s.data(), s.size()); } -template -static bool isPermutation(std::vector permutation) { - llvm::SmallVector seen(permutation.size(), false); - for (auto val : permutation) { - if (val < permutation.size()) { - if (seen[val]) - return false; - seen[val] = true; - continue; - } - return false; - } - return true; -} - //------------------------------------------------------------------------------ // Collections. //------------------------------------------------------------------------------ @@ -1466,7 +1443,8 @@ namespace { /// CRTP base class for Python MLIR values that subclass Value and should be /// castable from it. The value hierarchy is one level deep and is not supposed /// to accommodate other levels unless core MLIR changes. -template class PyConcreteValue : public PyValue { +template +class PyConcreteValue : public PyValue { public: // Derived classes must define statics for: // IsAFunctionTy isaFunction @@ -1717,1910 +1695,169 @@ class PyOpAttributeMap { } // end namespace //------------------------------------------------------------------------------ -// Builtin attribute subclasses. +// Populates the core exports of the 'ir' submodule. //------------------------------------------------------------------------------ -namespace { - -/// CRTP base classes for Python attributes that subclass Attribute and should -/// be castable from it (i.e. via something like StringAttr(attr)). 
-/// By default, attribute class hierarchies are one level deep (i.e. a -/// concrete attribute class extends PyAttribute); however, intermediate -/// python-visible base classes can be modeled by specifying a BaseTy. -template -class PyConcreteAttribute : public BaseTy { -public: - // Derived classes must define statics for: - // IsAFunctionTy isaFunction - // const char *pyClassName - using ClassTy = py::class_; - using IsAFunctionTy = bool (*)(MlirAttribute); - - PyConcreteAttribute() = default; - PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) - : BaseTy(std::move(contextRef), attr) {} - PyConcreteAttribute(PyAttribute &orig) - : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} - - static MlirAttribute castFrom(PyAttribute &orig) { - if (!DerivedTy::isaFunction(orig)) { - auto origRepr = py::repr(py::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") + - DerivedTy::pyClassName + - " (from " + origRepr + ")"); - } - return orig; - } - - static void bind(py::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol()); - cls.def(py::init(), py::keep_alive<0, 1>()); - DerivedTy::bindDerived(cls); - } - - /// Implemented by derived classes to add methods to the Python subclass. 
- static void bindDerived(ClassTy &m) {} -}; - -class PyAffineMapAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; - static constexpr const char *pyClassName = "AffineMapAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyAffineMap &affineMap) { - MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); - return PyAffineMapAttribute(affineMap.getContext(), attr); - }, - py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); - } -}; - -class PyArrayAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; - static constexpr const char *pyClassName = "ArrayAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - class PyArrayAttributeIterator { - public: - PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {} - - PyArrayAttributeIterator &dunderIter() { return *this; } - - PyAttribute dunderNext() { - if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) { - throw py::stop_iteration(); - } - return PyAttribute(attr.getContext(), - mlirArrayAttrGetElement(attr.get(), nextIndex++)); - } - - static void bind(py::module &m) { - py::class_(m, "ArrayAttributeIterator") - .def("__iter__", &PyArrayAttributeIterator::dunderIter) - .def("__next__", &PyArrayAttributeIterator::dunderNext); - } - - private: - PyAttribute attr; - int nextIndex = 0; - }; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](py::list attributes, DefaultingPyMlirContext context) { - SmallVector mlirAttributes; - mlirAttributes.reserve(py::len(attributes)); - for (auto attribute : attributes) { - try { - mlirAttributes.push_back(attribute.cast()); - } catch (py::cast_error &err) { - std::string msg = std::string("Invalid attribute when attempting " - "to create an ArrayAttribute (") + - err.what() + ")"; - throw py::cast_error(msg); - } catch 
(py::reference_cast_error &err) { - // This exception seems thrown when the value is "None". - std::string msg = - std::string("Invalid attribute (None?) when attempting to " - "create an ArrayAttribute (") + - err.what() + ")"; - throw py::cast_error(msg); +void mlir::python::populateIRCore(py::module &m) { + //---------------------------------------------------------------------------- + // Mapping of MlirContext + //---------------------------------------------------------------------------- + py::class_(m, "Context") + .def(py::init<>(&PyMlirContext::createNewContextForInit)) + .def_static("_get_live_count", &PyMlirContext::getLiveCount) + .def("_get_context_again", + [](PyMlirContext &self) { + PyMlirContextRef ref = PyMlirContext::forContext(self.get()); + return ref.releaseObject(); + }) + .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) + .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyMlirContext::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) + .def("__enter__", &PyMlirContext::contextEnter) + .def("__exit__", &PyMlirContext::contextExit) + .def_property_readonly_static( + "current", + [](py::object & /*class*/) { + auto *context = PyThreadContextEntry::getDefaultContext(); + if (!context) + throw SetPyError(PyExc_ValueError, "No current Context"); + return context; + }, + "Gets the Context bound to the current thread or raises ValueError") + .def_property_readonly( + "dialects", + [](PyMlirContext &self) { return PyDialects(self.getRef()); }, + "Gets a container for accessing dialects by name") + .def_property_readonly( + "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, + "Alias for 'dialect'") + .def( + "get_dialect_descriptor", + [=](PyMlirContext &self, std::string &name) { + MlirDialect dialect = mlirContextGetOrLoadDialect( + self.get(), {name.data(), name.size()}); + if 
(mlirDialectIsNull(dialect)) { + throw SetPyError(PyExc_ValueError, + Twine("Dialect '") + name + "' not found"); } - } - MlirAttribute attr = mlirArrayAttrGet( - context->get(), mlirAttributes.size(), mlirAttributes.data()); - return PyArrayAttribute(context->getRef(), attr); - }, - py::arg("attributes"), py::arg("context") = py::none(), - "Gets a uniqued Array attribute"); - c.def("__getitem__", - [](PyArrayAttribute &arr, intptr_t i) { - if (i >= mlirArrayAttrGetNumElements(arr)) - throw py::index_error("ArrayAttribute index out of range"); - return PyAttribute(arr.getContext(), - mlirArrayAttrGetElement(arr, i)); - }) - .def("__len__", - [](const PyArrayAttribute &arr) { - return mlirArrayAttrGetNumElements(arr); - }) - .def("__iter__", [](const PyArrayAttribute &arr) { - return PyArrayAttributeIterator(arr); - }); - } -}; - -/// Float Point Attribute subclass - FloatAttr. -class PyFloatAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; - static constexpr const char *pyClassName = "FloatAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &type, double value, DefaultingPyLocation loc) { - MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. 
- if (mlirAttributeIsNull(attr)) { - throw SetPyError(PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(type)).cast() + - "' and expected floating point type."); - } - return PyFloatAttribute(type.getContext(), attr); - }, - py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), - "Gets an uniqued float point attribute associated to a type"); - c.def_static( - "get_f32", - [](double value, DefaultingPyMlirContext context) { - MlirAttribute attr = mlirFloatAttrDoubleGet( - context->get(), mlirF32TypeGet(context->get()), value); - return PyFloatAttribute(context->getRef(), attr); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets an uniqued float point attribute associated to a f32 type"); - c.def_static( - "get_f64", - [](double value, DefaultingPyMlirContext context) { - MlirAttribute attr = mlirFloatAttrDoubleGet( - context->get(), mlirF64TypeGet(context->get()), value); - return PyFloatAttribute(context->getRef(), attr); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets an uniqued float point attribute associated to a f64 type"); - c.def_property_readonly( - "value", - [](PyFloatAttribute &self) { - return mlirFloatAttrGetValueDouble(self); - }, - "Returns the value of the float point attribute"); - } -}; - -/// Integer Attribute subclass - IntegerAttr. 
-class PyIntegerAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAInteger; - static constexpr const char *pyClassName = "IntegerAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &type, int64_t value) { - MlirAttribute attr = mlirIntegerAttrGet(type, value); - return PyIntegerAttribute(type.getContext(), attr); - }, - py::arg("type"), py::arg("value"), - "Gets an uniqued integer attribute associated to a type"); - c.def_property_readonly( - "value", - [](PyIntegerAttribute &self) { - return mlirIntegerAttrGetValueInt(self); - }, - "Returns the value of the integer attribute"); - } -}; - -/// Bool Attribute subclass - BoolAttr. -class PyBoolAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsABool; - static constexpr const char *pyClassName = "BoolAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](bool value, DefaultingPyMlirContext context) { - MlirAttribute attr = mlirBoolAttrGet(context->get(), value); - return PyBoolAttribute(context->getRef(), attr); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets an uniqued bool attribute"); - c.def_property_readonly( - "value", - [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, - "Returns the value of the bool attribute"); - } -}; - -class PyFlatSymbolRefAttribute - : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; - static constexpr const char *pyClassName = "FlatSymbolRefAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::string value, DefaultingPyMlirContext context) { - MlirAttribute attr = - mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); - 
return PyFlatSymbolRefAttribute(context->getRef(), attr); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets a uniqued FlatSymbolRef attribute"); - c.def_property_readonly( - "value", - [](PyFlatSymbolRefAttribute &self) { - MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); - return py::str(stringRef.data, stringRef.length); - }, - "Returns the value of the FlatSymbolRef attribute as a string"); - } -}; - -class PyStringAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; - static constexpr const char *pyClassName = "StringAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::string value, DefaultingPyMlirContext context) { - MlirAttribute attr = - mlirStringAttrGet(context->get(), toMlirStringRef(value)); - return PyStringAttribute(context->getRef(), attr); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets a uniqued string attribute"); - c.def_static( - "get_typed", - [](PyType &type, std::string value) { - MlirAttribute attr = - mlirStringAttrTypedGet(type, toMlirStringRef(value)); - return PyStringAttribute(type.getContext(), attr); - }, - - "Gets a uniqued string attribute associated to a type"); - c.def_property_readonly( - "value", - [](PyStringAttribute &self) { - MlirStringRef stringRef = mlirStringAttrGetValue(self); - return py::str(stringRef.data, stringRef.length); - }, - "Returns the value of the string attribute"); - } -}; - -// TODO: Support construction of bool elements. -// TODO: Support construction of string elements. 
-class PyDenseElementsAttribute - : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseElements; - static constexpr const char *pyClassName = "DenseElementsAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static PyDenseElementsAttribute - getFromBuffer(py::buffer array, bool signless, - DefaultingPyMlirContext contextWrapper) { - // Request a contiguous view. In exotic cases, this will cause a copy. - int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; - Py_buffer *view = new Py_buffer(); - if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { - delete view; - throw py::error_already_set(); - } - py::buffer_info arrayInfo(view); - - MlirContext context = contextWrapper->get(); - // Switch on the types that can be bulk loaded between the Python and - // MLIR-C APIs. - // See: https://docs.python.org/3/library/struct.html#format-characters - if (arrayInfo.format == "f") { - // f32 - assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); - return PyDenseElementsAttribute( - contextWrapper->getRef(), - bulkLoad(context, mlirDenseElementsAttrFloatGet, - mlirF32TypeGet(context), arrayInfo)); - } else if (arrayInfo.format == "d") { - // f64 - assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); - return PyDenseElementsAttribute( - contextWrapper->getRef(), - bulkLoad(context, mlirDenseElementsAttrDoubleGet, - mlirF64TypeGet(context), arrayInfo)); - } else if (isSignedIntegerFormat(arrayInfo.format)) { - if (arrayInfo.itemsize == 4) { - // i32 - MlirType elementType = signless ? mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeSignedGet(context, 32); - return PyDenseElementsAttribute(contextWrapper->getRef(), - bulkLoad(context, - mlirDenseElementsAttrInt32Get, - elementType, arrayInfo)); - } else if (arrayInfo.itemsize == 8) { - // i64 - MlirType elementType = signless ? 
mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeSignedGet(context, 64); - return PyDenseElementsAttribute(contextWrapper->getRef(), - bulkLoad(context, - mlirDenseElementsAttrInt64Get, - elementType, arrayInfo)); - } - } else if (isUnsignedIntegerFormat(arrayInfo.format)) { - if (arrayInfo.itemsize == 4) { - // unsigned i32 - MlirType elementType = signless - ? mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeUnsignedGet(context, 32); - return PyDenseElementsAttribute(contextWrapper->getRef(), - bulkLoad(context, - mlirDenseElementsAttrUInt32Get, - elementType, arrayInfo)); - } else if (arrayInfo.itemsize == 8) { - // unsigned i64 - MlirType elementType = signless - ? mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeUnsignedGet(context, 64); - return PyDenseElementsAttribute(contextWrapper->getRef(), - bulkLoad(context, - mlirDenseElementsAttrUInt64Get, - elementType, arrayInfo)); - } - } - - // TODO: Fall back to string-based get. - std::string message = "unimplemented array format conversion from format: "; - message.append(arrayInfo.format); - throw SetPyError(PyExc_ValueError, message); - } - - static PyDenseElementsAttribute getSplat(PyType shapedType, - PyAttribute &elementAttr) { - auto contextWrapper = - PyMlirContext::forContext(mlirTypeGetContext(shapedType)); - if (!mlirAttributeIsAInteger(elementAttr) && - !mlirAttributeIsAFloat(elementAttr)) { - std::string message = "Illegal element type for DenseElementsAttr: "; - message.append(py::repr(py::cast(elementAttr))); - throw SetPyError(PyExc_ValueError, message); - } - if (!mlirTypeIsAShaped(shapedType) || - !mlirShapedTypeHasStaticShape(shapedType)) { - std::string message = - "Expected a static ShapedType for the shaped_type parameter: "; - message.append(py::repr(py::cast(shapedType))); - throw SetPyError(PyExc_ValueError, message); - } - MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); - MlirType attrType = mlirAttributeGetType(elementAttr); - if 
(!mlirTypeEqual(shapedElementType, attrType)) { - std::string message = - "Shaped element type and attribute type must be equal: shaped="; - message.append(py::repr(py::cast(shapedType))); - message.append(", element="); - message.append(py::repr(py::cast(elementAttr))); - throw SetPyError(PyExc_ValueError, message); - } - - MlirAttribute elements = - mlirDenseElementsAttrSplatGet(shapedType, elementAttr); - return PyDenseElementsAttribute(contextWrapper->getRef(), elements); - } + return PyDialectDescriptor(self.getRef(), dialect); + }, + "Gets or loads a dialect by name, returning its descriptor object") + .def_property( + "allow_unregistered_dialects", + [](PyMlirContext &self) -> bool { + return mlirContextGetAllowUnregisteredDialects(self.get()); + }, + [](PyMlirContext &self, bool value) { + mlirContextSetAllowUnregisteredDialects(self.get(), value); + }); - intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } - - py::buffer_info accessBuffer() { - MlirType shapedType = mlirAttributeGetType(*this); - MlirType elementType = mlirShapedTypeGetElementType(shapedType); - - if (mlirTypeIsAF32(elementType)) { - // f32 - return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue); - } else if (mlirTypeIsAF64(elementType)) { - // f64 - return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue); - } else if (mlirTypeIsAInteger(elementType) && - mlirIntegerTypeGetWidth(elementType) == 32) { - if (mlirIntegerTypeIsSignless(elementType) || - mlirIntegerTypeIsSigned(elementType)) { - // i32 - return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value); - } else if (mlirIntegerTypeIsUnsigned(elementType)) { - // unsigned i32 - return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value); - } - } else if (mlirTypeIsAInteger(elementType) && - mlirIntegerTypeGetWidth(elementType) == 64) { - if (mlirIntegerTypeIsSignless(elementType) || - mlirIntegerTypeIsSigned(elementType)) { - // i64 - return bufferInfo(shapedType, 
mlirDenseElementsAttrGetInt64Value); - } else if (mlirIntegerTypeIsUnsigned(elementType)) { - // unsigned i64 - return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value); - } - } + //---------------------------------------------------------------------------- + // Mapping of PyDialectDescriptor + //---------------------------------------------------------------------------- + py::class_(m, "DialectDescriptor") + .def_property_readonly("namespace", + [](PyDialectDescriptor &self) { + MlirStringRef ns = + mlirDialectGetNamespace(self.get()); + return py::str(ns.data, ns.length); + }) + .def("__repr__", [](PyDialectDescriptor &self) { + MlirStringRef ns = mlirDialectGetNamespace(self.get()); + std::string repr(""); + return repr; + }); - std::string message = "unimplemented array format."; - throw SetPyError(PyExc_ValueError, message); - } + //---------------------------------------------------------------------------- + // Mapping of PyDialects + //---------------------------------------------------------------------------- + py::class_(m, "Dialects") + .def("__getitem__", + [=](PyDialects &self, std::string keyName) { + MlirDialect dialect = + self.getDialectForKey(keyName, /*attrError=*/false); + py::object descriptor = + py::cast(PyDialectDescriptor{self.getContext(), dialect}); + return createCustomDialectWrapper(keyName, std::move(descriptor)); + }) + .def("__getattr__", [=](PyDialects &self, std::string attrName) { + MlirDialect dialect = + self.getDialectForKey(attrName, /*attrError=*/true); + py::object descriptor = + py::cast(PyDialectDescriptor{self.getContext(), dialect}); + return createCustomDialectWrapper(attrName, std::move(descriptor)); + }); - static void bindDerived(ClassTy &c) { - c.def("__len__", &PyDenseElementsAttribute::dunderLen) - .def_static("get", PyDenseElementsAttribute::getFromBuffer, - py::arg("array"), py::arg("signless") = true, - py::arg("context") = py::none(), - "Gets from a buffer or ndarray") - .def_static("get_splat", 
PyDenseElementsAttribute::getSplat, - py::arg("shaped_type"), py::arg("element_attr"), - "Gets a DenseElementsAttr where all values are the same") - .def_property_readonly("is_splat", - [](PyDenseElementsAttribute &self) -> bool { - return mlirDenseElementsAttrIsSplat(self); - }) - .def_buffer(&PyDenseElementsAttribute::accessBuffer); - } + //---------------------------------------------------------------------------- + // Mapping of PyDialect + //---------------------------------------------------------------------------- + py::class_(m, "Dialect") + .def(py::init(), "descriptor") + .def_property_readonly( + "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) + .def("__repr__", [](py::object self) { + auto clazz = self.attr("__class__"); + return py::str(""); + }); -private: - template - static MlirAttribute - bulkLoad(MlirContext context, - MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *), - MlirType mlirElementType, py::buffer_info &arrayInfo) { - SmallVector shape(arrayInfo.shape.begin(), - arrayInfo.shape.begin() + arrayInfo.ndim); - auto shapedType = - mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType); - intptr_t numElements = arrayInfo.size; - const ElementTy *contents = static_cast(arrayInfo.ptr); - return ctor(shapedType, numElements, contents); - } - - static bool isUnsignedIntegerFormat(const std::string &format) { - if (format.empty()) - return false; - char code = format[0]; - return code == 'I' || code == 'B' || code == 'H' || code == 'L' || - code == 'Q'; - } - - static bool isSignedIntegerFormat(const std::string &format) { - if (format.empty()) - return false; - char code = format[0]; - return code == 'i' || code == 'b' || code == 'h' || code == 'l' || - code == 'q'; - } - - template - py::buffer_info bufferInfo(MlirType shapedType, - Type (*value)(MlirAttribute, intptr_t)) { - intptr_t rank = mlirShapedTypeGetRank(shapedType); - // Prepare the data for the buffer_info. 
- // Buffer is configured for read-only access below. - Type *data = static_cast( - const_cast(mlirDenseElementsAttrGetRawData(*this))); - // Prepare the shape for the buffer_info. - SmallVector shape; - for (intptr_t i = 0; i < rank; ++i) - shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); - // Prepare the strides for the buffer_info. - SmallVector strides; - intptr_t strideFactor = 1; - for (intptr_t i = 1; i < rank; ++i) { - strideFactor = 1; - for (intptr_t j = i; j < rank; ++j) { - strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); - } - strides.push_back(sizeof(Type) * strideFactor); - } - strides.push_back(sizeof(Type)); - return py::buffer_info(data, sizeof(Type), - py::format_descriptor::format(), rank, shape, - strides, /*readonly=*/true); - } -}; // namespace - -/// Refinement of the PyDenseElementsAttribute for attributes containing integer -/// (and boolean) values. Supports element access. -class PyDenseIntElementsAttribute - : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseIntElements; - static constexpr const char *pyClassName = "DenseIntElementsAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - /// Returns the element at the given linear position. Asserts if the index is - /// out of range. - py::int_ dunderGetItem(intptr_t pos) { - if (pos < 0 || pos >= dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds element"); - } - - MlirType type = mlirAttributeGetType(*this); - type = mlirShapedTypeGetElementType(type); - assert(mlirTypeIsAInteger(type) && - "expected integer element type in dense int elements attribute"); - // Dispatch element extraction to an appropriate C function based on the - // elemental type of the attribute. py::int_ is implicitly constructible - // from any C++ integral type and handles bitwidth correctly. 
- // TODO: consider caching the type properties in the constructor to avoid - // querying them on each element access. - unsigned width = mlirIntegerTypeGetWidth(type); - bool isUnsigned = mlirIntegerTypeIsUnsigned(type); - if (isUnsigned) { - if (width == 1) { - return mlirDenseElementsAttrGetBoolValue(*this, pos); - } - if (width == 32) { - return mlirDenseElementsAttrGetUInt32Value(*this, pos); - } - if (width == 64) { - return mlirDenseElementsAttrGetUInt64Value(*this, pos); - } - } else { - if (width == 1) { - return mlirDenseElementsAttrGetBoolValue(*this, pos); - } - if (width == 32) { - return mlirDenseElementsAttrGetInt32Value(*this, pos); - } - if (width == 64) { - return mlirDenseElementsAttrGetInt64Value(*this, pos); - } - } - throw SetPyError(PyExc_TypeError, "Unsupported integer type"); - } - - static void bindDerived(ClassTy &c) { - c.def("__getitem__", &PyDenseIntElementsAttribute::dunderGetItem); - } -}; - -class PyDictAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; - static constexpr const char *pyClassName = "DictAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } - - static void bindDerived(ClassTy &c) { - c.def("__len__", &PyDictAttribute::dunderLen); - c.def_static( - "get", - [](py::dict attributes, DefaultingPyMlirContext context) { - SmallVector mlirNamedAttributes; - mlirNamedAttributes.reserve(attributes.size()); - for (auto &it : attributes) { - auto &mlir_attr = it.second.cast(); - auto name = it.first.cast(); - mlirNamedAttributes.push_back(mlirNamedAttributeGet( - mlirIdentifierGet(mlirAttributeGetContext(mlir_attr), - toMlirStringRef(name)), - mlir_attr)); - } - MlirAttribute attr = - mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), - mlirNamedAttributes.data()); - return PyDictAttribute(context->getRef(), attr); - }, - py::arg("value"), py::arg("context") = 
py::none(), - "Gets an uniqued dict attribute"); - c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { - MlirAttribute attr = - mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); - if (mlirAttributeIsNull(attr)) { - throw SetPyError(PyExc_KeyError, - "attempt to access a non-existent attribute"); - } - return PyAttribute(self.getContext(), attr); - }); - c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { - if (index < 0 || index >= self.dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds attribute"); - } - MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); - return PyNamedAttribute( - namedAttr.attribute, - std::string(mlirIdentifierStr(namedAttr.name).data)); - }); - } -}; - -/// Refinement of PyDenseElementsAttribute for attributes containing -/// floating-point values. Supports element access. -class PyDenseFPElementsAttribute - : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseFPElements; - static constexpr const char *pyClassName = "DenseFPElementsAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - py::float_ dunderGetItem(intptr_t pos) { - if (pos < 0 || pos >= dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds element"); - } - - MlirType type = mlirAttributeGetType(*this); - type = mlirShapedTypeGetElementType(type); - // Dispatch element extraction to an appropriate C function based on the - // elemental type of the attribute. py::float_ is implicitly constructible - // from float and double. - // TODO: consider caching the type properties in the constructor to avoid - // querying them on each element access. 
- if (mlirTypeIsAF32(type)) { - return mlirDenseElementsAttrGetFloatValue(*this, pos); - } - if (mlirTypeIsAF64(type)) { - return mlirDenseElementsAttrGetDoubleValue(*this, pos); - } - throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); - } - - static void bindDerived(ClassTy &c) { - c.def("__getitem__", &PyDenseFPElementsAttribute::dunderGetItem); - } -}; - -class PyTypeAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; - static constexpr const char *pyClassName = "TypeAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType value, DefaultingPyMlirContext context) { - MlirAttribute attr = mlirTypeAttrGet(value.get()); - return PyTypeAttribute(context->getRef(), attr); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets a uniqued Type attribute"); - c.def_property_readonly("value", [](PyTypeAttribute &self) { - return PyType(self.getContext()->getRef(), - mlirTypeAttrGetValue(self.get())); - }); - } -}; - -/// Unit Attribute subclass. Unit attributes don't have values. -class PyUnitAttribute : public PyConcreteAttribute { -public: - static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; - static constexpr const char *pyClassName = "UnitAttr"; - using PyConcreteAttribute::PyConcreteAttribute; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - return PyUnitAttribute(context->getRef(), - mlirUnitAttrGet(context->get())); - }, - py::arg("context") = py::none(), "Create a Unit attribute."); - } -}; - -} // namespace - -//------------------------------------------------------------------------------ -// Builtin type subclasses. -//------------------------------------------------------------------------------ - -namespace { - -/// CRTP base classes for Python types that subclass Type and should be -/// castable from it (i.e. 
via something like IntegerType(t)). -/// By default, type class hierarchies are one level deep (i.e. a -/// concrete type class extends PyType); however, intermediate python-visible -/// base classes can be modeled by specifying a BaseTy. -template -class PyConcreteType : public BaseTy { -public: - // Derived classes must define statics for: - // IsAFunctionTy isaFunction - // const char *pyClassName - using ClassTy = py::class_; - using IsAFunctionTy = bool (*)(MlirType); - - PyConcreteType() = default; - PyConcreteType(PyMlirContextRef contextRef, MlirType t) - : BaseTy(std::move(contextRef), t) {} - PyConcreteType(PyType &orig) - : PyConcreteType(orig.getContext(), castFrom(orig)) {} - - static MlirType castFrom(PyType &orig) { - if (!DerivedTy::isaFunction(orig)) { - auto origRepr = py::repr(py::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") + - DerivedTy::pyClassName + - " (from " + origRepr + ")"); - } - return orig; - } - - static void bind(py::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName); - cls.def(py::init(), py::keep_alive<0, 1>()); - cls.def_static("isinstance", [](PyType &otherType) -> bool { - return DerivedTy::isaFunction(otherType); - }); - DerivedTy::bindDerived(cls); - } - - /// Implemented by derived classes to add methods to the Python subclass. 
- static void bindDerived(ClassTy &m) {} -}; - -class PyIntegerType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; - static constexpr const char *pyClassName = "IntegerType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get_signless", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - py::arg("width"), py::arg("context") = py::none(), - "Create a signless integer type"); - c.def_static( - "get_signed", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeSignedGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - py::arg("width"), py::arg("context") = py::none(), - "Create a signed integer type"); - c.def_static( - "get_unsigned", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - py::arg("width"), py::arg("context") = py::none(), - "Create an unsigned integer type"); - c.def_property_readonly( - "width", - [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, - "Returns the width of the integer type"); - c.def_property_readonly( - "is_signless", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsSignless(self); - }, - "Returns whether this is a signless integer"); - c.def_property_readonly( - "is_signed", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsSigned(self); - }, - "Returns whether this is a signed integer"); - c.def_property_readonly( - "is_unsigned", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsUnsigned(self); - }, - "Returns whether this is an unsigned integer"); - } -}; - -/// Index Type subclass - IndexType. 
-class PyIndexType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; - static constexpr const char *pyClassName = "IndexType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirIndexTypeGet(context->get()); - return PyIndexType(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a index type."); - } -}; - -/// Floating Point Type subclass - BF16Type. -class PyBF16Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; - static constexpr const char *pyClassName = "BF16Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirBF16TypeGet(context->get()); - return PyBF16Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a bf16 type."); - } -}; - -/// Floating Point Type subclass - F16Type. -class PyF16Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; - static constexpr const char *pyClassName = "F16Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF16TypeGet(context->get()); - return PyF16Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a f16 type."); - } -}; - -/// Floating Point Type subclass - F32Type. 
-class PyF32Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; - static constexpr const char *pyClassName = "F32Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF32TypeGet(context->get()); - return PyF32Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a f32 type."); - } -}; - -/// Floating Point Type subclass - F64Type. -class PyF64Type : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; - static constexpr const char *pyClassName = "F64Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF64TypeGet(context->get()); - return PyF64Type(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a f64 type."); - } -}; - -/// None Type subclass - NoneType. -class PyNoneType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; - static constexpr const char *pyClassName = "NoneType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirNoneTypeGet(context->get()); - return PyNoneType(context->getRef(), t); - }, - py::arg("context") = py::none(), "Create a none type."); - } -}; - -/// Complex Type subclass - ComplexType. -class PyComplexType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; - static constexpr const char *pyClassName = "ComplexType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType) { - // The element must be a floating point or integer scalar type. 
- if (mlirTypeIsAIntegerOrFloat(elementType)) { - MlirType t = mlirComplexTypeGet(elementType); - return PyComplexType(elementType.getContext(), t); - } - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point or integer type."); - }, - "Create a complex type"); - c.def_property_readonly( - "element_type", - [](PyComplexType &self) -> PyType { - MlirType t = mlirComplexTypeGetElementType(self); - return PyType(self.getContext(), t); - }, - "Returns element type."); - } -}; - -class PyShapedType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped; - static constexpr const char *pyClassName = "ShapedType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_property_readonly( - "element_type", - [](PyShapedType &self) { - MlirType t = mlirShapedTypeGetElementType(self); - return PyType(self.getContext(), t); - }, - "Returns the element type of the shaped type."); - c.def_property_readonly( - "has_rank", - [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, - "Returns whether the given shaped type is ranked."); - c.def_property_readonly( - "rank", - [](PyShapedType &self) { - self.requireHasRank(); - return mlirShapedTypeGetRank(self); - }, - "Returns the rank of the given ranked shaped type."); - c.def_property_readonly( - "has_static_shape", - [](PyShapedType &self) -> bool { - return mlirShapedTypeHasStaticShape(self); - }, - "Returns whether the given shaped type has a static shape."); - c.def( - "is_dynamic_dim", - [](PyShapedType &self, intptr_t dim) -> bool { - self.requireHasRank(); - return mlirShapedTypeIsDynamicDim(self, dim); - }, - "Returns whether the dim-th dimension of the given shaped type is " - "dynamic."); - c.def( - "get_dim_size", - [](PyShapedType &self, intptr_t dim) { - self.requireHasRank(); - return mlirShapedTypeGetDimSize(self, dim); - }, - "Returns the 
dim-th dimension of the given ranked shaped type."); - c.def_static( - "is_dynamic_size", - [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, - "Returns whether the given dimension size indicates a dynamic " - "dimension."); - c.def( - "is_dynamic_stride_or_offset", - [](PyShapedType &self, int64_t val) -> bool { - self.requireHasRank(); - return mlirShapedTypeIsDynamicStrideOrOffset(val); - }, - "Returns whether the given value is used as a placeholder for dynamic " - "strides and offsets in shaped types."); - } - -private: - void requireHasRank() { - if (!mlirShapedTypeHasRank(*this)) { - throw SetPyError( - PyExc_ValueError, - "calling this method requires that the type has a rank."); - } - } -}; - -/// Vector Type subclass - VectorType. -class PyVectorType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; - static constexpr const char *pyClassName = "VectorType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector shape, PyType &elementType, - DefaultingPyLocation loc) { - MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), - elementType); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point or integer type."); - } - return PyVectorType(elementType.getContext(), t); - }, - py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(), - "Create a vector type"); - } -}; - -/// Ranked Tensor Type subclass - RankedTensorType. 
-class PyRankedTensorType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; - static constexpr const char *pyClassName = "RankedTensorType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector shape, PyType &elementType, - DefaultingPyLocation loc) { - MlirType t = mlirRankedTensorTypeGetChecked( - loc, shape.size(), shape.data(), elementType); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } - return PyRankedTensorType(elementType.getContext(), t); - }, - py::arg("shape"), py::arg("element_type"), py::arg("loc") = py::none(), - "Create a ranked tensor type"); - } -}; - -/// Unranked Tensor Type subclass - UnrankedTensorType. -class PyUnrankedTensorType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; - static constexpr const char *pyClassName = "UnrankedTensorType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType, DefaultingPyLocation loc) { - MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } - return PyUnrankedTensorType(elementType.getContext(), t); - }, - py::arg("element_type"), py::arg("loc") = py::none(), - "Create a unranked tensor type"); - } -}; - -class PyMemRefLayoutMapList; - -/// Ranked MemRef Type subclass - MemRefType. 
-class PyMemRefType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; - static constexpr const char *pyClassName = "MemRefType"; - using PyConcreteType::PyConcreteType; - - PyMemRefLayoutMapList getLayout(); - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector shape, PyType &elementType, - std::vector layout, PyAttribute *memorySpace, - DefaultingPyLocation loc) { - SmallVector maps; - maps.reserve(layout.size()); - for (PyAffineMap &map : layout) - maps.push_back(map); - - MlirAttribute memSpaceAttr = {}; - if (memorySpace) - memSpaceAttr = *memorySpace; - - MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(), - shape.data(), maps.size(), - maps.data(), memSpaceAttr); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } - return PyMemRefType(elementType.getContext(), t); - }, - py::arg("shape"), py::arg("element_type"), - py::arg("layout") = py::list(), py::arg("memory_space") = py::none(), - py::arg("loc") = py::none(), "Create a memref type") - .def_property_readonly("layout", &PyMemRefType::getLayout, - "The list of layout maps of the MemRef type.") - .def_property_readonly( - "memory_space", - [](PyMemRefType &self) -> PyAttribute { - MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); - return PyAttribute(self.getContext(), a); - }, - "Returns the memory space of the given MemRef type."); - } -}; - -/// A list of affine layout maps in a memref type. Internally, these are stored -/// as consecutive elements, random access is cheap. Both the type and the maps -/// are owned by the context, no need to worry about lifetime extension. 
-class PyMemRefLayoutMapList - : public Sliceable { -public: - static constexpr const char *pyClassName = "MemRefLayoutMapList"; - - PyMemRefLayoutMapList(PyMemRefType type, intptr_t startIndex = 0, - intptr_t length = -1, intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? mlirMemRefTypeGetNumAffineMaps(type) : length, - step), - memref(type) {} - - intptr_t getNumElements() { return mlirMemRefTypeGetNumAffineMaps(memref); } - - PyAffineMap getElement(intptr_t index) { - return PyAffineMap(memref.getContext(), - mlirMemRefTypeGetAffineMap(memref, index)); - } - - PyMemRefLayoutMapList slice(intptr_t startIndex, intptr_t length, - intptr_t step) { - return PyMemRefLayoutMapList(memref, startIndex, length, step); - } - -private: - PyMemRefType memref; -}; - -PyMemRefLayoutMapList PyMemRefType::getLayout() { - return PyMemRefLayoutMapList(*this); -} - -/// Unranked MemRef Type subclass - UnrankedMemRefType. -class PyUnrankedMemRefType - : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; - static constexpr const char *pyClassName = "UnrankedMemRefType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType, PyAttribute *memorySpace, - DefaultingPyLocation loc) { - MlirAttribute memSpaceAttr = {}; - if (memorySpace) - memSpaceAttr = *memorySpace; - - MlirType t = - mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. 
- if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } - return PyUnrankedMemRefType(elementType.getContext(), t); - }, - py::arg("element_type"), py::arg("memory_space"), - py::arg("loc") = py::none(), "Create a unranked memref type") - .def_property_readonly( - "memory_space", - [](PyUnrankedMemRefType &self) -> PyAttribute { - MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); - return PyAttribute(self.getContext(), a); - }, - "Returns the memory space of the given Unranked MemRef type."); - } -}; - -/// Tuple Type subclass - TupleType. -class PyTupleType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; - static constexpr const char *pyClassName = "TupleType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get_tuple", - [](py::list elementList, DefaultingPyMlirContext context) { - intptr_t num = py::len(elementList); - // Mapping py::list to SmallVector. - SmallVector elements; - for (auto element : elementList) - elements.push_back(element.cast()); - MlirType t = mlirTupleTypeGet(context->get(), num, elements.data()); - return PyTupleType(context->getRef(), t); - }, - py::arg("elements"), py::arg("context") = py::none(), - "Create a tuple type"); - c.def( - "get_type", - [](PyTupleType &self, intptr_t pos) -> PyType { - MlirType t = mlirTupleTypeGetType(self, pos); - return PyType(self.getContext(), t); - }, - "Returns the pos-th type in the tuple type."); - c.def_property_readonly( - "num_types", - [](PyTupleType &self) -> intptr_t { - return mlirTupleTypeGetNumTypes(self); - }, - "Returns the number of types contained in a tuple."); - } -}; - -/// Function type. 
-class PyFunctionType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; - static constexpr const char *pyClassName = "FunctionType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector inputs, std::vector results, - DefaultingPyMlirContext context) { - SmallVector inputsRaw(inputs.begin(), inputs.end()); - SmallVector resultsRaw(results.begin(), results.end()); - MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(), - inputsRaw.data(), resultsRaw.size(), - resultsRaw.data()); - return PyFunctionType(context->getRef(), t); - }, - py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(), - "Gets a FunctionType from a list of input and result types"); - c.def_property_readonly( - "inputs", - [](PyFunctionType &self) { - MlirType t = self; - auto contextRef = self.getContext(); - py::list types; - for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; - ++i) { - types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i))); - } - return types; - }, - "Returns the list of input types in the FunctionType."); - c.def_property_readonly( - "results", - [](PyFunctionType &self) { - auto contextRef = self.getContext(); - py::list types; - for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; - ++i) { - types.append( - PyType(contextRef, mlirFunctionTypeGetResult(self, i))); - } - return types; - }, - "Returns the list of result types in the FunctionType."); - } -}; - -} // namespace - -//------------------------------------------------------------------------------ -// PyAffineExpr and subclasses. -//------------------------------------------------------------------------------ - -namespace { -/// CRTP base class for Python MLIR affine expressions that subclass AffineExpr -/// and should be castable from it. Intermediate hierarchy classes can be -/// modeled by specifying BaseTy. 
-template -class PyConcreteAffineExpr : public BaseTy { -public: - // Derived classes must define statics for: - // IsAFunctionTy isaFunction - // const char *pyClassName - // and redefine bindDerived. - using ClassTy = py::class_; - using IsAFunctionTy = bool (*)(MlirAffineExpr); - - PyConcreteAffineExpr() = default; - PyConcreteAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) - : BaseTy(std::move(contextRef), affineExpr) {} - PyConcreteAffineExpr(PyAffineExpr &orig) - : PyConcreteAffineExpr(orig.getContext(), castFrom(orig)) {} - - static MlirAffineExpr castFrom(PyAffineExpr &orig) { - if (!DerivedTy::isaFunction(orig)) { - auto origRepr = py::repr(py::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, - Twine("Cannot cast affine expression to ") + - DerivedTy::pyClassName + " (from " + origRepr + ")"); - } - return orig; - } - - static void bind(py::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName); - cls.def(py::init()); - DerivedTy::bindDerived(cls); - } - - /// Implemented by derived classes to add methods to the Python subclass. 
- static void bindDerived(ClassTy &m) {} -}; - -class PyAffineConstantExpr : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAConstant; - static constexpr const char *pyClassName = "AffineConstantExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineConstantExpr get(intptr_t value, - DefaultingPyMlirContext context) { - MlirAffineExpr affineExpr = - mlirAffineConstantExprGet(context->get(), static_cast(value)); - return PyAffineConstantExpr(context->getRef(), affineExpr); - } - - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"), - py::arg("context") = py::none()); - c.def_property_readonly("value", [](PyAffineConstantExpr &self) { - return mlirAffineConstantExprGetValue(self); - }); - } -}; - -class PyAffineDimExpr : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsADim; - static constexpr const char *pyClassName = "AffineDimExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineDimExpr get(intptr_t pos, DefaultingPyMlirContext context) { - MlirAffineExpr affineExpr = mlirAffineDimExprGet(context->get(), pos); - return PyAffineDimExpr(context->getRef(), affineExpr); - } - - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineDimExpr::get, py::arg("position"), - py::arg("context") = py::none()); - c.def_property_readonly("position", [](PyAffineDimExpr &self) { - return mlirAffineDimExprGetPosition(self); - }); - } -}; - -class PyAffineSymbolExpr : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsASymbol; - static constexpr const char *pyClassName = "AffineSymbolExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineSymbolExpr get(intptr_t pos, DefaultingPyMlirContext context) { - MlirAffineExpr affineExpr = mlirAffineSymbolExprGet(context->get(), pos); - return 
PyAffineSymbolExpr(context->getRef(), affineExpr); - } - - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"), - py::arg("context") = py::none()); - c.def_property_readonly("position", [](PyAffineSymbolExpr &self) { - return mlirAffineSymbolExprGetPosition(self); - }); - } -}; - -class PyAffineBinaryExpr : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsABinary; - static constexpr const char *pyClassName = "AffineBinaryExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - PyAffineExpr lhs() { - MlirAffineExpr lhsExpr = mlirAffineBinaryOpExprGetLHS(get()); - return PyAffineExpr(getContext(), lhsExpr); - } - - PyAffineExpr rhs() { - MlirAffineExpr rhsExpr = mlirAffineBinaryOpExprGetRHS(get()); - return PyAffineExpr(getContext(), rhsExpr); - } - - static void bindDerived(ClassTy &c) { - c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs); - c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs); - } -}; - -class PyAffineAddExpr - : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAAdd; - static constexpr const char *pyClassName = "AffineAddExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { - MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs); - return PyAffineAddExpr(lhs.getContext(), expr); - } - - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineAddExpr::get); - } -}; - -class PyAffineMulExpr - : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMul; - static constexpr const char *pyClassName = "AffineMulExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { - MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs); - return PyAffineMulExpr(lhs.getContext(), expr); - } - - 
static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineMulExpr::get); - } -}; - -class PyAffineModExpr - : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAMod; - static constexpr const char *pyClassName = "AffineModExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { - MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs); - return PyAffineModExpr(lhs.getContext(), expr); - } - - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineModExpr::get); - } -}; - -class PyAffineFloorDivExpr - : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsAFloorDiv; - static constexpr const char *pyClassName = "AffineFloorDivExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { - MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs); - return PyAffineFloorDivExpr(lhs.getContext(), expr); - } - - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineFloorDivExpr::get); - } -}; - -class PyAffineCeilDivExpr - : public PyConcreteAffineExpr { -public: - static constexpr IsAFunctionTy isaFunction = mlirAffineExprIsACeilDiv; - static constexpr const char *pyClassName = "AffineCeilDivExpr"; - using PyConcreteAffineExpr::PyConcreteAffineExpr; - - static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { - MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs); - return PyAffineCeilDivExpr(lhs.getContext(), expr); - } - - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineCeilDivExpr::get); - } -}; -} // namespace - -bool PyAffineExpr::operator==(const PyAffineExpr &other) { - return mlirAffineExprEqual(affineExpr, other.affineExpr); -} - -py::object PyAffineExpr::getCapsule() { - return py::reinterpret_steal( - mlirPythonAffineExprToCapsule(*this)); -} - 
-PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) { - MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); - if (mlirAffineExprIsNull(rawAffineExpr)) - throw py::error_already_set(); - return PyAffineExpr( - PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)), - rawAffineExpr); -} - -//------------------------------------------------------------------------------ -// PyAffineMap and utilities. -//------------------------------------------------------------------------------ - -namespace { -/// A list of expressions contained in an affine map. Internally these are -/// stored as a consecutive array leading to inexpensive random access. Both -/// the map and the expression are owned by the context so we need not bother -/// with lifetime extension. -class PyAffineMapExprList - : public Sliceable { -public: - static constexpr const char *pyClassName = "AffineExprList"; - - PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0, - intptr_t length = -1, intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? 
mlirAffineMapGetNumResults(map) : length, - step), - affineMap(map) {} - - intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); } - - PyAffineExpr getElement(intptr_t pos) { - return PyAffineExpr(affineMap.getContext(), - mlirAffineMapGetResult(affineMap, pos)); - } - - PyAffineMapExprList slice(intptr_t startIndex, intptr_t length, - intptr_t step) { - return PyAffineMapExprList(affineMap, startIndex, length, step); - } - -private: - PyAffineMap affineMap; -}; -} // end namespace - -bool PyAffineMap::operator==(const PyAffineMap &other) { - return mlirAffineMapEqual(affineMap, other.affineMap); -} - -py::object PyAffineMap::getCapsule() { - return py::reinterpret_steal(mlirPythonAffineMapToCapsule(*this)); -} - -PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) { - MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); - if (mlirAffineMapIsNull(rawAffineMap)) - throw py::error_already_set(); - return PyAffineMap( - PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)), - rawAffineMap); -} - -//------------------------------------------------------------------------------ -// PyIntegerSet and utilities. 
-//------------------------------------------------------------------------------ - -class PyIntegerSetConstraint { -public: - PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {} - - PyAffineExpr getExpr() { - return PyAffineExpr(set.getContext(), - mlirIntegerSetGetConstraint(set, pos)); - } - - bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } - - static void bind(py::module &m) { - py::class_(m, "IntegerSetConstraint") - .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr) - .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq); - } - -private: - PyIntegerSet set; - intptr_t pos; -}; - -class PyIntegerSetConstraintList - : public Sliceable { -public: - static constexpr const char *pyClassName = "IntegerSetConstraintList"; - - PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0, - intptr_t length = -1, intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? mlirIntegerSetGetNumConstraints(set) : length, - step), - set(set) {} - - intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); } - - PyIntegerSetConstraint getElement(intptr_t pos) { - return PyIntegerSetConstraint(set, pos); - } - - PyIntegerSetConstraintList slice(intptr_t startIndex, intptr_t length, - intptr_t step) { - return PyIntegerSetConstraintList(set, startIndex, length, step); - } - -private: - PyIntegerSet set; -}; - -bool PyIntegerSet::operator==(const PyIntegerSet &other) { - return mlirIntegerSetEqual(integerSet, other.integerSet); -} - -py::object PyIntegerSet::getCapsule() { - return py::reinterpret_steal( - mlirPythonIntegerSetToCapsule(*this)); -} - -PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) { - MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); - if (mlirIntegerSetIsNull(rawIntegerSet)) - throw py::error_already_set(); - return PyIntegerSet( - PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), - rawIntegerSet); -} - -/// 
Attempts to populate `result` with the content of `list` casted to the -/// appropriate type (Python and C types are provided as template arguments). -/// Throws errors in case of failure, using "action" to describe what the caller -/// was attempting to do. -template -static void pyListToVector(py::list list, llvm::SmallVectorImpl &result, - StringRef action) { - result.reserve(py::len(list)); - for (py::handle item : list) { - try { - result.push_back(item.cast()); - } catch (py::cast_error &err) { - std::string msg = (llvm::Twine("Invalid expression when ") + action + - " (" + err.what() + ")") - .str(); - throw py::cast_error(msg); - } catch (py::reference_cast_error &err) { - std::string msg = (llvm::Twine("Invalid expression (None?) when ") + - action + " (" + err.what() + ")") - .str(); - throw py::cast_error(msg); - } - } -} - -//------------------------------------------------------------------------------ -// Populates the pybind11 IR submodule. -//------------------------------------------------------------------------------ - -void mlir::python::populateIRSubmodule(py::module &m) { - //---------------------------------------------------------------------------- - // Mapping of MlirContext - //---------------------------------------------------------------------------- - py::class_(m, "Context") - .def(py::init<>(&PyMlirContext::createNewContextForInit)) - .def_static("_get_live_count", &PyMlirContext::getLiveCount) - .def("_get_context_again", - [](PyMlirContext &self) { - PyMlirContextRef ref = PyMlirContext::forContext(self.get()); - return ref.releaseObject(); - }) - .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) - .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyMlirContext::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) - .def("__enter__", &PyMlirContext::contextEnter) - .def("__exit__", 
&PyMlirContext::contextExit) - .def_property_readonly_static( - "current", - [](py::object & /*class*/) { - auto *context = PyThreadContextEntry::getDefaultContext(); - if (!context) - throw SetPyError(PyExc_ValueError, "No current Context"); - return context; - }, - "Gets the Context bound to the current thread or raises ValueError") - .def_property_readonly( - "dialects", - [](PyMlirContext &self) { return PyDialects(self.getRef()); }, - "Gets a container for accessing dialects by name") - .def_property_readonly( - "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, - "Alias for 'dialect'") - .def( - "get_dialect_descriptor", - [=](PyMlirContext &self, std::string &name) { - MlirDialect dialect = mlirContextGetOrLoadDialect( - self.get(), {name.data(), name.size()}); - if (mlirDialectIsNull(dialect)) { - throw SetPyError(PyExc_ValueError, - Twine("Dialect '") + name + "' not found"); - } - return PyDialectDescriptor(self.getRef(), dialect); - }, - "Gets or loads a dialect by name, returning its descriptor object") - .def_property( - "allow_unregistered_dialects", - [](PyMlirContext &self) -> bool { - return mlirContextGetAllowUnregisteredDialects(self.get()); - }, - [](PyMlirContext &self, bool value) { - mlirContextSetAllowUnregisteredDialects(self.get(), value); - }); - - //---------------------------------------------------------------------------- - // Mapping of PyDialectDescriptor - //---------------------------------------------------------------------------- - py::class_(m, "DialectDescriptor") - .def_property_readonly("namespace", - [](PyDialectDescriptor &self) { - MlirStringRef ns = - mlirDialectGetNamespace(self.get()); - return py::str(ns.data, ns.length); - }) - .def("__repr__", [](PyDialectDescriptor &self) { - MlirStringRef ns = mlirDialectGetNamespace(self.get()); - std::string repr(""); - return repr; - }); - - //---------------------------------------------------------------------------- - // Mapping of PyDialects - 
//---------------------------------------------------------------------------- - py::class_(m, "Dialects") - .def("__getitem__", - [=](PyDialects &self, std::string keyName) { - MlirDialect dialect = - self.getDialectForKey(keyName, /*attrError=*/false); - py::object descriptor = - py::cast(PyDialectDescriptor{self.getContext(), dialect}); - return createCustomDialectWrapper(keyName, std::move(descriptor)); - }) - .def("__getattr__", [=](PyDialects &self, std::string attrName) { - MlirDialect dialect = - self.getDialectForKey(attrName, /*attrError=*/true); - py::object descriptor = - py::cast(PyDialectDescriptor{self.getContext(), dialect}); - return createCustomDialectWrapper(attrName, std::move(descriptor)); - }); - - //---------------------------------------------------------------------------- - // Mapping of PyDialect - //---------------------------------------------------------------------------- - py::class_(m, "Dialect") - .def(py::init(), "descriptor") - .def_property_readonly( - "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) - .def("__repr__", [](py::object self) { - auto clazz = self.attr("__class__"); - return py::str(""); - }); - - //---------------------------------------------------------------------------- - // Mapping of Location - //---------------------------------------------------------------------------- - py::class_(m, "Location") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) - .def("__enter__", &PyLocation::contextEnter) - .def("__exit__", &PyLocation::contextExit) - .def("__eq__", - [](PyLocation &self, PyLocation &other) -> bool { - return mlirLocationEqual(self, other); - }) - .def("__eq__", [](PyLocation &self, py::object other) { return false; }) - .def_property_readonly_static( - "current", - [](py::object & /*class*/) { - auto *loc = PyThreadContextEntry::getDefaultLocation(); - if (!loc) - throw 
SetPyError(PyExc_ValueError, "No current Location"); - return loc; - }, - "Gets the Location bound to the current thread or raises ValueError") - .def_static( - "unknown", - [](DefaultingPyMlirContext context) { - return PyLocation(context->getRef(), - mlirLocationUnknownGet(context->get())); - }, - py::arg("context") = py::none(), - "Gets a Location representing an unknown location") - .def_static( - "file", - [](std::string filename, int line, int col, - DefaultingPyMlirContext context) { - return PyLocation( - context->getRef(), - mlirLocationFileLineColGet( - context->get(), toMlirStringRef(filename), line, col)); - }, - py::arg("filename"), py::arg("line"), py::arg("col"), - py::arg("context") = py::none(), kContextGetFileLocationDocstring) - .def_property_readonly( - "context", - [](PyLocation &self) { return self.getContext().getObject(); }, - "Context that owns the Location") - .def("__repr__", [](PyLocation &self) { - PyPrintAccumulator printAccum; - mlirLocationPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - return printAccum.join(); - }); + //---------------------------------------------------------------------------- + // Mapping of Location + //---------------------------------------------------------------------------- + py::class_(m, "Location") + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) + .def("__enter__", &PyLocation::contextEnter) + .def("__exit__", &PyLocation::contextExit) + .def("__eq__", + [](PyLocation &self, PyLocation &other) -> bool { + return mlirLocationEqual(self, other); + }) + .def("__eq__", [](PyLocation &self, py::object other) { return false; }) + .def_property_readonly_static( + "current", + [](py::object & /*class*/) { + auto *loc = PyThreadContextEntry::getDefaultLocation(); + if (!loc) + throw SetPyError(PyExc_ValueError, "No current Location"); + return loc; + }, + "Gets the Location bound to the current 
thread or raises ValueError") + .def_static( + "unknown", + [](DefaultingPyMlirContext context) { + return PyLocation(context->getRef(), + mlirLocationUnknownGet(context->get())); + }, + py::arg("context") = py::none(), + "Gets a Location representing an unknown location") + .def_static( + "file", + [](std::string filename, int line, int col, + DefaultingPyMlirContext context) { + return PyLocation( + context->getRef(), + mlirLocationFileLineColGet( + context->get(), toMlirStringRef(filename), line, col)); + }, + py::arg("filename"), py::arg("line"), py::arg("col"), + py::arg("context") = py::none(), kContextGetFileLocationDocstring) + .def_property_readonly( + "context", + [](PyLocation &self) { return self.getContext().getObject(); }, + "Context that owns the Location") + .def("__repr__", [](PyLocation &self) { + PyPrintAccumulator printAccum; + mlirLocationPrint(self, printAccum.getCallback(), + printAccum.getUserData()); + return printAccum.join(); + }); //---------------------------------------------------------------------------- // Mapping of Module @@ -4022,22 +2259,6 @@ void mlir::python::populateIRSubmodule(py::module &m) { py::keep_alive<0, 1>(), "The underlying generic attribute of the NamedAttribute binding"); - // Builtin attribute bindings. - PyAffineMapAttribute::bind(m); - PyArrayAttribute::bind(m); - PyArrayAttribute::PyArrayAttributeIterator::bind(m); - PyBoolAttribute::bind(m); - PyDenseElementsAttribute::bind(m); - PyDenseFPElementsAttribute::bind(m); - PyDenseIntElementsAttribute::bind(m); - PyDictAttribute::bind(m); - PyFlatSymbolRefAttribute::bind(m); - PyFloatAttribute::bind(m); - PyIntegerAttribute::bind(m); - PyStringAttribute::bind(m); - PyTypeAttribute::bind(m); - PyUnitAttribute::bind(m); - //---------------------------------------------------------------------------- // Mapping of PyType. 
//---------------------------------------------------------------------------- @@ -4088,25 +2309,6 @@ void mlir::python::populateIRSubmodule(py::module &m) { return printAccum.join(); }); - // Builtin type bindings. - PyIntegerType::bind(m); - PyIndexType::bind(m); - PyBF16Type::bind(m); - PyF16Type::bind(m); - PyF32Type::bind(m); - PyF64Type::bind(m); - PyNoneType::bind(m); - PyComplexType::bind(m); - PyShapedType::bind(m); - PyVectorType::bind(m); - PyRankedTensorType::bind(m); - PyUnrankedTensorType::bind(m); - PyMemRefType::bind(m); - PyMemRefLayoutMapList::bind(m); - PyUnrankedMemRefType::bind(m); - PyTupleType::bind(m); - PyFunctionType::bind(m); - //---------------------------------------------------------------------------- // Mapping of Value. //---------------------------------------------------------------------------- @@ -4152,359 +2354,4 @@ void mlir::python::populateIRSubmodule(py::module &m) { PyOpResultList::bind(m); PyRegionIterator::bind(m); PyRegionList::bind(m); - - //---------------------------------------------------------------------------- - // Mapping of PyAffineExpr and derived classes. 
- //---------------------------------------------------------------------------- - py::class_(m, "AffineExpr") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyAffineExpr::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) - .def("__add__", - [](PyAffineExpr &self, PyAffineExpr &other) { - return PyAffineAddExpr::get(self, other); - }) - .def("__mul__", - [](PyAffineExpr &self, PyAffineExpr &other) { - return PyAffineMulExpr::get(self, other); - }) - .def("__mod__", - [](PyAffineExpr &self, PyAffineExpr &other) { - return PyAffineModExpr::get(self, other); - }) - .def("__sub__", - [](PyAffineExpr &self, PyAffineExpr &other) { - auto negOne = - PyAffineConstantExpr::get(-1, *self.getContext().get()); - return PyAffineAddExpr::get(self, - PyAffineMulExpr::get(negOne, other)); - }) - .def("__eq__", [](PyAffineExpr &self, - PyAffineExpr &other) { return self == other; }) - .def("__eq__", - [](PyAffineExpr &self, py::object &other) { return false; }) - .def("__str__", - [](PyAffineExpr &self) { - PyPrintAccumulator printAccum; - mlirAffineExprPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - return printAccum.join(); - }) - .def("__repr__", - [](PyAffineExpr &self) { - PyPrintAccumulator printAccum; - printAccum.parts.append("AffineExpr("); - mlirAffineExprPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); - }) - .def_property_readonly( - "context", - [](PyAffineExpr &self) { return self.getContext().getObject(); }) - .def_static( - "get_add", &PyAffineAddExpr::get, - "Gets an affine expression containing a sum of two expressions.") - .def_static( - "get_mul", &PyAffineMulExpr::get, - "Gets an affine expression containing a product of two expressions.") - .def_static("get_mod", &PyAffineModExpr::get, - "Gets an affine expression containing the modulo of dividing " - "one expression by another.") - .def_static("get_floor_div", 
&PyAffineFloorDivExpr::get, - "Gets an affine expression containing the rounded-down " - "result of dividing one expression by another.") - .def_static("get_ceil_div", &PyAffineCeilDivExpr::get, - "Gets an affine expression containing the rounded-up result " - "of dividing one expression by another.") - .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"), - py::arg("context") = py::none(), - "Gets a constant affine expression with the given value.") - .def_static( - "get_dim", &PyAffineDimExpr::get, py::arg("position"), - py::arg("context") = py::none(), - "Gets an affine expression of a dimension at the given position.") - .def_static( - "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"), - py::arg("context") = py::none(), - "Gets an affine expression of a symbol at the given position.") - .def( - "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); }, - kDumpDocstring); - PyAffineConstantExpr::bind(m); - PyAffineDimExpr::bind(m); - PyAffineSymbolExpr::bind(m); - PyAffineBinaryExpr::bind(m); - PyAffineAddExpr::bind(m); - PyAffineMulExpr::bind(m); - PyAffineModExpr::bind(m); - PyAffineFloorDivExpr::bind(m); - PyAffineCeilDivExpr::bind(m); - - //---------------------------------------------------------------------------- - // Mapping of PyAffineMap. 
- //---------------------------------------------------------------------------- - py::class_(m, "AffineMap") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyAffineMap::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) - .def("__eq__", - [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) - .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; }) - .def("__str__", - [](PyAffineMap &self) { - PyPrintAccumulator printAccum; - mlirAffineMapPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - return printAccum.join(); - }) - .def("__repr__", - [](PyAffineMap &self) { - PyPrintAccumulator printAccum; - printAccum.parts.append("AffineMap("); - mlirAffineMapPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); - }) - .def_property_readonly( - "context", - [](PyAffineMap &self) { return self.getContext().getObject(); }, - "Context that owns the Affine Map") - .def( - "dump", [](PyAffineMap &self) { mlirAffineMapDump(self); }, - kDumpDocstring) - .def_static( - "get", - [](intptr_t dimCount, intptr_t symbolCount, py::list exprs, - DefaultingPyMlirContext context) { - SmallVector affineExprs; - pyListToVector( - exprs, affineExprs, "attempting to create an AffineMap"); - MlirAffineMap map = - mlirAffineMapGet(context->get(), dimCount, symbolCount, - affineExprs.size(), affineExprs.data()); - return PyAffineMap(context->getRef(), map); - }, - py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"), - py::arg("context") = py::none(), - "Gets a map with the given expressions as results.") - .def_static( - "get_constant", - [](intptr_t value, DefaultingPyMlirContext context) { - MlirAffineMap affineMap = - mlirAffineMapConstantGet(context->get(), value); - return PyAffineMap(context->getRef(), affineMap); - }, - py::arg("value"), py::arg("context") = py::none(), - "Gets an affine map with a single constant 
result") - .def_static( - "get_empty", - [](DefaultingPyMlirContext context) { - MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); - return PyAffineMap(context->getRef(), affineMap); - }, - py::arg("context") = py::none(), "Gets an empty affine map.") - .def_static( - "get_identity", - [](intptr_t nDims, DefaultingPyMlirContext context) { - MlirAffineMap affineMap = - mlirAffineMapMultiDimIdentityGet(context->get(), nDims); - return PyAffineMap(context->getRef(), affineMap); - }, - py::arg("n_dims"), py::arg("context") = py::none(), - "Gets an identity map with the given number of dimensions.") - .def_static( - "get_minor_identity", - [](intptr_t nDims, intptr_t nResults, - DefaultingPyMlirContext context) { - MlirAffineMap affineMap = - mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults); - return PyAffineMap(context->getRef(), affineMap); - }, - py::arg("n_dims"), py::arg("n_results"), - py::arg("context") = py::none(), - "Gets a minor identity map with the given number of dimensions and " - "results.") - .def_static( - "get_permutation", - [](std::vector permutation, - DefaultingPyMlirContext context) { - if (!isPermutation(permutation)) - throw py::cast_error("Invalid permutation when attempting to " - "create an AffineMap"); - MlirAffineMap affineMap = mlirAffineMapPermutationGet( - context->get(), permutation.size(), permutation.data()); - return PyAffineMap(context->getRef(), affineMap); - }, - py::arg("permutation"), py::arg("context") = py::none(), - "Gets an affine map that permutes its inputs.") - .def("get_submap", - [](PyAffineMap &self, std::vector &resultPos) { - intptr_t numResults = mlirAffineMapGetNumResults(self); - for (intptr_t pos : resultPos) { - if (pos < 0 || pos >= numResults) - throw py::value_error("result position out of bounds"); - } - MlirAffineMap affineMap = mlirAffineMapGetSubMap( - self, resultPos.size(), resultPos.data()); - return PyAffineMap(self.getContext(), affineMap); - }) - 
.def("get_major_submap", - [](PyAffineMap &self, intptr_t nResults) { - if (nResults >= mlirAffineMapGetNumResults(self)) - throw py::value_error("number of results out of bounds"); - MlirAffineMap affineMap = - mlirAffineMapGetMajorSubMap(self, nResults); - return PyAffineMap(self.getContext(), affineMap); - }) - .def("get_minor_submap", - [](PyAffineMap &self, intptr_t nResults) { - if (nResults >= mlirAffineMapGetNumResults(self)) - throw py::value_error("number of results out of bounds"); - MlirAffineMap affineMap = - mlirAffineMapGetMinorSubMap(self, nResults); - return PyAffineMap(self.getContext(), affineMap); - }) - .def_property_readonly( - "is_permutation", - [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) - .def_property_readonly("is_projected_permutation", - [](PyAffineMap &self) { - return mlirAffineMapIsProjectedPermutation(self); - }) - .def_property_readonly( - "n_dims", - [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); }) - .def_property_readonly( - "n_inputs", - [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); }) - .def_property_readonly( - "n_symbols", - [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); }) - .def_property_readonly("results", [](PyAffineMap &self) { - return PyAffineMapExprList(self); - }); - PyAffineMapExprList::bind(m); - - //---------------------------------------------------------------------------- - // Mapping of PyIntegerSet. 
- //---------------------------------------------------------------------------- - py::class_(m, "IntegerSet") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyIntegerSet::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) - .def("__eq__", [](PyIntegerSet &self, - PyIntegerSet &other) { return self == other; }) - .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; }) - .def("__str__", - [](PyIntegerSet &self) { - PyPrintAccumulator printAccum; - mlirIntegerSetPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - return printAccum.join(); - }) - .def("__repr__", - [](PyIntegerSet &self) { - PyPrintAccumulator printAccum; - printAccum.parts.append("IntegerSet("); - mlirIntegerSetPrint(self, printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); - }) - .def_property_readonly( - "context", - [](PyIntegerSet &self) { return self.getContext().getObject(); }) - .def( - "dump", [](PyIntegerSet &self) { mlirIntegerSetDump(self); }, - kDumpDocstring) - .def_static( - "get", - [](intptr_t numDims, intptr_t numSymbols, py::list exprs, - std::vector eqFlags, DefaultingPyMlirContext context) { - if (exprs.size() != eqFlags.size()) - throw py::value_error( - "Expected the number of constraints to match " - "that of equality flags"); - if (exprs.empty()) - throw py::value_error("Expected non-empty list of constraints"); - - // Copy over to a SmallVector because std::vector has a - // specialization for booleans that packs data and does not - // expose a `bool *`. 
- SmallVector flags(eqFlags.begin(), eqFlags.end()); - - SmallVector affineExprs; - pyListToVector(exprs, affineExprs, - "attempting to create an IntegerSet"); - MlirIntegerSet set = mlirIntegerSetGet( - context->get(), numDims, numSymbols, exprs.size(), - affineExprs.data(), flags.data()); - return PyIntegerSet(context->getRef(), set); - }, - py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"), - py::arg("eq_flags"), py::arg("context") = py::none()) - .def_static( - "get_empty", - [](intptr_t numDims, intptr_t numSymbols, - DefaultingPyMlirContext context) { - MlirIntegerSet set = - mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); - return PyIntegerSet(context->getRef(), set); - }, - py::arg("num_dims"), py::arg("num_symbols"), - py::arg("context") = py::none()) - .def("get_replaced", - [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, - intptr_t numResultDims, intptr_t numResultSymbols) { - if (static_cast(dimExprs.size()) != - mlirIntegerSetGetNumDims(self)) - throw py::value_error( - "Expected the number of dimension replacement expressions " - "to match that of dimensions"); - if (static_cast(symbolExprs.size()) != - mlirIntegerSetGetNumSymbols(self)) - throw py::value_error( - "Expected the number of symbol replacement expressions " - "to match that of symbols"); - - SmallVector dimAffineExprs, symbolAffineExprs; - pyListToVector( - dimExprs, dimAffineExprs, - "attempting to create an IntegerSet by replacing dimensions"); - pyListToVector( - symbolExprs, symbolAffineExprs, - "attempting to create an IntegerSet by replacing symbols"); - MlirIntegerSet set = mlirIntegerSetReplaceGet( - self, dimAffineExprs.data(), symbolAffineExprs.data(), - numResultDims, numResultSymbols); - return PyIntegerSet(self.getContext(), set); - }) - .def_property_readonly("is_canonical_empty", - [](PyIntegerSet &self) { - return mlirIntegerSetIsCanonicalEmpty(self); - }) - .def_property_readonly( - "n_dims", - [](PyIntegerSet &self) { return 
mlirIntegerSetGetNumDims(self); }) - .def_property_readonly( - "n_symbols", - [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) - .def_property_readonly( - "n_inputs", - [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) - .def_property_readonly("n_equalities", - [](PyIntegerSet &self) { - return mlirIntegerSetGetNumEqualities(self); - }) - .def_property_readonly("n_inequalities", - [](PyIntegerSet &self) { - return mlirIntegerSetGetNumInequalities(self); - }) - .def_property_readonly("constraints", [](PyIntegerSet &self) { - return PyIntegerSetConstraintList(self); - }); - PyIntegerSetConstraint::bind(m); - PyIntegerSetConstraintList::bind(m); } diff --git a/mlir/lib/Bindings/Python/IRModules.h b/mlir/lib/Bindings/Python/IRModule.h similarity index 99% rename from mlir/lib/Bindings/Python/IRModules.h rename to mlir/lib/Bindings/Python/IRModule.h index 8140d7043..5c710abe7 100644 --- a/mlir/lib/Bindings/Python/IRModules.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -747,7 +747,10 @@ class PyIntegerSet : public BaseContextObject { MlirIntegerSet integerSet; }; -void populateIRSubmodule(pybind11::module &m); +void populateIRAffine(pybind11::module &m); +void populateIRAttributes(pybind11::module &m); +void populateIRCore(pybind11::module &m); +void populateIRTypes(pybind11::module &m); } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp new file mode 100644 index 000000000..96f6bf666 --- /dev/null +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -0,0 +1,678 @@ +//===- IRTypes.cpp - Exports builtin and standard types -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "IRModule.h" + +#include "PybindUtils.h" + +#include "mlir-c/BuiltinTypes.h" + +namespace py = pybind11; +using namespace mlir; +using namespace mlir::python; + +using llvm::SmallVector; +using llvm::Twine; + +namespace { + +/// Checks whether the given type is an integer or float type. +static int mlirTypeIsAIntegerOrFloat(MlirType type) { + return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) || + mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type); +} + +/// CRTP base classes for Python types that subclass Type and should be +/// castable from it (i.e. via something like IntegerType(t)). +/// By default, type class hierarchies are one level deep (i.e. a +/// concrete type class extends PyType); however, intermediate python-visible +/// base classes can be modeled by specifying a BaseTy. +template +class PyConcreteType : public BaseTy { +public: + // Derived classes must define statics for: + // IsAFunctionTy isaFunction + // const char *pyClassName + using ClassTy = py::class_; + using IsAFunctionTy = bool (*)(MlirType); + + PyConcreteType() = default; + PyConcreteType(PyMlirContextRef contextRef, MlirType t) + : BaseTy(std::move(contextRef), t) {} + PyConcreteType(PyType &orig) + : PyConcreteType(orig.getContext(), castFrom(orig)) {} + + static MlirType castFrom(PyType &orig) { + if (!DerivedTy::isaFunction(orig)) { + auto origRepr = py::repr(py::cast(orig)).cast(); + throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") + + DerivedTy::pyClassName + + " (from " + origRepr + ")"); + } + return orig; + } + + static void bind(py::module &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName); + cls.def(py::init(), py::keep_alive<0, 1>()); + cls.def_static("isinstance", [](PyType &otherType) -> bool { + return DerivedTy::isaFunction(otherType); + }); + DerivedTy::bindDerived(cls); 
+ } + + /// Implemented by derived classes to add methods to the Python subclass. + static void bindDerived(ClassTy &m) {} +}; + +class PyIntegerType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; + static constexpr const char *pyClassName = "IntegerType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get_signless", + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeGet(context->get(), width); + return PyIntegerType(context->getRef(), t); + }, + py::arg("width"), py::arg("context") = py::none(), + "Create a signless integer type"); + c.def_static( + "get_signed", + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeSignedGet(context->get(), width); + return PyIntegerType(context->getRef(), t); + }, + py::arg("width"), py::arg("context") = py::none(), + "Create a signed integer type"); + c.def_static( + "get_unsigned", + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); + return PyIntegerType(context->getRef(), t); + }, + py::arg("width"), py::arg("context") = py::none(), + "Create an unsigned integer type"); + c.def_property_readonly( + "width", + [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, + "Returns the width of the integer type"); + c.def_property_readonly( + "is_signless", + [](PyIntegerType &self) -> bool { + return mlirIntegerTypeIsSignless(self); + }, + "Returns whether this is a signless integer"); + c.def_property_readonly( + "is_signed", + [](PyIntegerType &self) -> bool { + return mlirIntegerTypeIsSigned(self); + }, + "Returns whether this is a signed integer"); + c.def_property_readonly( + "is_unsigned", + [](PyIntegerType &self) -> bool { + return mlirIntegerTypeIsUnsigned(self); + }, + "Returns whether this is an unsigned integer"); + } +}; + +/// Index Type subclass - IndexType. 
+class PyIndexType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; + static constexpr const char *pyClassName = "IndexType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirIndexTypeGet(context->get()); + return PyIndexType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a index type."); + } +}; + +/// Floating Point Type subclass - BF16Type. +class PyBF16Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; + static constexpr const char *pyClassName = "BF16Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirBF16TypeGet(context->get()); + return PyBF16Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a bf16 type."); + } +}; + +/// Floating Point Type subclass - F16Type. +class PyF16Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; + static constexpr const char *pyClassName = "F16Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirF16TypeGet(context->get()); + return PyF16Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a f16 type."); + } +}; + +/// Floating Point Type subclass - F32Type. 
+class PyF32Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; + static constexpr const char *pyClassName = "F32Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirF32TypeGet(context->get()); + return PyF32Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a f32 type."); + } +}; + +/// Floating Point Type subclass - F64Type. +class PyF64Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; + static constexpr const char *pyClassName = "F64Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirF64TypeGet(context->get()); + return PyF64Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a f64 type."); + } +}; + +/// None Type subclass - NoneType. +class PyNoneType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; + static constexpr const char *pyClassName = "NoneType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirNoneTypeGet(context->get()); + return PyNoneType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a none type."); + } +}; + +/// Complex Type subclass - ComplexType. +class PyComplexType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; + static constexpr const char *pyClassName = "ComplexType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elementType) { + // The element must be a floating point or integer scalar type. 
+ if (mlirTypeIsAIntegerOrFloat(elementType)) { + MlirType t = mlirComplexTypeGet(elementType); + return PyComplexType(elementType.getContext(), t); + } + throw SetPyError( + PyExc_ValueError, + Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point or integer type."); + }, + "Create a complex type"); + c.def_property_readonly( + "element_type", + [](PyComplexType &self) -> PyType { + MlirType t = mlirComplexTypeGetElementType(self); + return PyType(self.getContext(), t); + }, + "Returns element type."); + } +}; + +class PyShapedType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped; + static constexpr const char *pyClassName = "ShapedType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_property_readonly( + "element_type", + [](PyShapedType &self) { + MlirType t = mlirShapedTypeGetElementType(self); + return PyType(self.getContext(), t); + }, + "Returns the element type of the shaped type."); + c.def_property_readonly( + "has_rank", + [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, + "Returns whether the given shaped type is ranked."); + c.def_property_readonly( + "rank", + [](PyShapedType &self) { + self.requireHasRank(); + return mlirShapedTypeGetRank(self); + }, + "Returns the rank of the given ranked shaped type."); + c.def_property_readonly( + "has_static_shape", + [](PyShapedType &self) -> bool { + return mlirShapedTypeHasStaticShape(self); + }, + "Returns whether the given shaped type has a static shape."); + c.def( + "is_dynamic_dim", + [](PyShapedType &self, intptr_t dim) -> bool { + self.requireHasRank(); + return mlirShapedTypeIsDynamicDim(self, dim); + }, + "Returns whether the dim-th dimension of the given shaped type is " + "dynamic."); + c.def( + "get_dim_size", + [](PyShapedType &self, intptr_t dim) { + self.requireHasRank(); + return mlirShapedTypeGetDimSize(self, dim); + }, + "Returns the 
dim-th dimension of the given ranked shaped type."); + c.def_static( + "is_dynamic_size", + [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, + "Returns whether the given dimension size indicates a dynamic " + "dimension."); + c.def( + "is_dynamic_stride_or_offset", + [](PyShapedType &self, int64_t val) -> bool { + self.requireHasRank(); + return mlirShapedTypeIsDynamicStrideOrOffset(val); + }, + "Returns whether the given value is used as a placeholder for dynamic " + "strides and offsets in shaped types."); + } + +private: + void requireHasRank() { + if (!mlirShapedTypeHasRank(*this)) { + throw SetPyError( + PyExc_ValueError, + "calling this method requires that the type has a rank."); + } + } +}; + +/// Vector Type subclass - VectorType. +class PyVectorType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; + static constexpr const char *pyClassName = "VectorType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector shape, PyType &elementType, + DefaultingPyLocation loc) { + MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), + elementType); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. + if (mlirTypeIsNull(t)) { + throw SetPyError( + PyExc_ValueError, + Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point or integer type."); + } + return PyVectorType(elementType.getContext(), t); + }, + py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(), + "Create a vector type"); + } +}; + +/// Ranked Tensor Type subclass - RankedTensorType. 
+class PyRankedTensorType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; + static constexpr const char *pyClassName = "RankedTensorType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector shape, PyType &elementType, + DefaultingPyLocation loc) { + MlirType t = mlirRankedTensorTypeGetChecked( + loc, shape.size(), shape.data(), elementType); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. + if (mlirTypeIsNull(t)) { + throw SetPyError( + PyExc_ValueError, + Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point, integer, vector or " + "complex " + "type."); + } + return PyRankedTensorType(elementType.getContext(), t); + }, + py::arg("shape"), py::arg("element_type"), py::arg("loc") = py::none(), + "Create a ranked tensor type"); + } +}; + +/// Unranked Tensor Type subclass - UnrankedTensorType. +class PyUnrankedTensorType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; + static constexpr const char *pyClassName = "UnrankedTensorType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elementType, DefaultingPyLocation loc) { + MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. + if (mlirTypeIsNull(t)) { + throw SetPyError( + PyExc_ValueError, + Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point, integer, vector or " + "complex " + "type."); + } + return PyUnrankedTensorType(elementType.getContext(), t); + }, + py::arg("element_type"), py::arg("loc") = py::none(), + "Create a unranked tensor type"); + } +}; + +class PyMemRefLayoutMapList; + +/// Ranked MemRef Type subclass - MemRefType. 
+class PyMemRefType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; + static constexpr const char *pyClassName = "MemRefType"; + using PyConcreteType::PyConcreteType; + + PyMemRefLayoutMapList getLayout(); + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector shape, PyType &elementType, + std::vector layout, PyAttribute *memorySpace, + DefaultingPyLocation loc) { + SmallVector maps; + maps.reserve(layout.size()); + for (PyAffineMap &map : layout) + maps.push_back(map); + + MlirAttribute memSpaceAttr = {}; + if (memorySpace) + memSpaceAttr = *memorySpace; + + MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(), + shape.data(), maps.size(), + maps.data(), memSpaceAttr); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. + if (mlirTypeIsNull(t)) { + throw SetPyError( + PyExc_ValueError, + Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point, integer, vector or " + "complex " + "type."); + } + return PyMemRefType(elementType.getContext(), t); + }, + py::arg("shape"), py::arg("element_type"), + py::arg("layout") = py::list(), py::arg("memory_space") = py::none(), + py::arg("loc") = py::none(), "Create a memref type") + .def_property_readonly("layout", &PyMemRefType::getLayout, + "The list of layout maps of the MemRef type.") + .def_property_readonly( + "memory_space", + [](PyMemRefType &self) -> PyAttribute { + MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); + return PyAttribute(self.getContext(), a); + }, + "Returns the memory space of the given MemRef type."); + } +}; + +/// A list of affine layout maps in a memref type. Internally, these are stored +/// as consecutive elements, random access is cheap. Both the type and the maps +/// are owned by the context, no need to worry about lifetime extension. 
+class PyMemRefLayoutMapList + : public Sliceable { +public: + static constexpr const char *pyClassName = "MemRefLayoutMapList"; + + PyMemRefLayoutMapList(PyMemRefType type, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirMemRefTypeGetNumAffineMaps(type) : length, + step), + memref(type) {} + + intptr_t getNumElements() { return mlirMemRefTypeGetNumAffineMaps(memref); } + + PyAffineMap getElement(intptr_t index) { + return PyAffineMap(memref.getContext(), + mlirMemRefTypeGetAffineMap(memref, index)); + } + + PyMemRefLayoutMapList slice(intptr_t startIndex, intptr_t length, + intptr_t step) { + return PyMemRefLayoutMapList(memref, startIndex, length, step); + } + +private: + PyMemRefType memref; +}; + +PyMemRefLayoutMapList PyMemRefType::getLayout() { + return PyMemRefLayoutMapList(*this); +} + +/// Unranked MemRef Type subclass - UnrankedMemRefType. +class PyUnrankedMemRefType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; + static constexpr const char *pyClassName = "UnrankedMemRefType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elementType, PyAttribute *memorySpace, + DefaultingPyLocation loc) { + MlirAttribute memSpaceAttr = {}; + if (memorySpace) + memSpaceAttr = *memorySpace; + + MlirType t = + mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); + // TODO: Rework error reporting once diagnostic engine is exposed + // in C API. 
+ if (mlirTypeIsNull(t)) { + throw SetPyError( + PyExc_ValueError, + Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point, integer, vector or " + "complex " + "type."); + } + return PyUnrankedMemRefType(elementType.getContext(), t); + }, + py::arg("element_type"), py::arg("memory_space"), + py::arg("loc") = py::none(), "Create a unranked memref type") + .def_property_readonly( + "memory_space", + [](PyUnrankedMemRefType &self) -> PyAttribute { + MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); + return PyAttribute(self.getContext(), a); + }, + "Returns the memory space of the given Unranked MemRef type."); + } +}; + +/// Tuple Type subclass - TupleType. +class PyTupleType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; + static constexpr const char *pyClassName = "TupleType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get_tuple", + [](py::list elementList, DefaultingPyMlirContext context) { + intptr_t num = py::len(elementList); + // Mapping py::list to SmallVector. + SmallVector elements; + for (auto element : elementList) + elements.push_back(element.cast()); + MlirType t = mlirTupleTypeGet(context->get(), num, elements.data()); + return PyTupleType(context->getRef(), t); + }, + py::arg("elements"), py::arg("context") = py::none(), + "Create a tuple type"); + c.def( + "get_type", + [](PyTupleType &self, intptr_t pos) -> PyType { + MlirType t = mlirTupleTypeGetType(self, pos); + return PyType(self.getContext(), t); + }, + "Returns the pos-th type in the tuple type."); + c.def_property_readonly( + "num_types", + [](PyTupleType &self) -> intptr_t { + return mlirTupleTypeGetNumTypes(self); + }, + "Returns the number of types contained in a tuple."); + } +}; + +/// Function type. 
+class PyFunctionType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; + static constexpr const char *pyClassName = "FunctionType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector inputs, std::vector results, + DefaultingPyMlirContext context) { + SmallVector inputsRaw(inputs.begin(), inputs.end()); + SmallVector resultsRaw(results.begin(), results.end()); + MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(), + inputsRaw.data(), resultsRaw.size(), + resultsRaw.data()); + return PyFunctionType(context->getRef(), t); + }, + py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(), + "Gets a FunctionType from a list of input and result types"); + c.def_property_readonly( + "inputs", + [](PyFunctionType &self) { + MlirType t = self; + auto contextRef = self.getContext(); + py::list types; + for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; + ++i) { + types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i))); + } + return types; + }, + "Returns the list of input types in the FunctionType."); + c.def_property_readonly( + "results", + [](PyFunctionType &self) { + auto contextRef = self.getContext(); + py::list types; + for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; + ++i) { + types.append( + PyType(contextRef, mlirFunctionTypeGetResult(self, i))); + } + return types; + }, + "Returns the list of result types in the FunctionType."); + } +}; + +} // namespace + +void mlir::python::populateIRTypes(py::module &m) { + PyIntegerType::bind(m); + PyIndexType::bind(m); + PyBF16Type::bind(m); + PyF16Type::bind(m); + PyF32Type::bind(m); + PyF64Type::bind(m); + PyNoneType::bind(m); + PyComplexType::bind(m); + PyShapedType::bind(m); + PyVectorType::bind(m); + PyRankedTensorType::bind(m); + PyUnrankedTensorType::bind(m); + PyMemRefType::bind(m); + PyMemRefLayoutMapList::bind(m); + 
PyUnrankedMemRefType::bind(m); + PyTupleType::bind(m); + PyFunctionType::bind(m); +} diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 9bfe8b09f..5fe0401af 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -12,7 +12,7 @@ #include "ExecutionEngine.h" #include "Globals.h" -#include "IRModules.h" +#include "IRModule.h" #include "Pass.h" namespace py = pybind11; @@ -211,7 +211,10 @@ PYBIND11_MODULE(_mlir, m) { // Define and populate IR submodule. auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); - populateIRSubmodule(irModule); + populateIRCore(irModule); + populateIRAffine(irModule); + populateIRAttributes(irModule); + populateIRTypes(irModule); // Define and populate PassManager submodule. auto passModule = diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index dd57647f0..0e2f5bafb 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -8,7 +8,7 @@ #include "Pass.h" -#include "IRModules.h" +#include "IRModule.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Pass.h" From 0a5b1ee4d5b8a8a6d3826ad38857441a09d9a604 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 19 Mar 2021 15:43:42 -0700 Subject: [PATCH 005/915] [mlir][python] Function decorator for capturing a FuncOp from a python function. * Moves this out of a test case where it was being developed to good effect and generalizes it. * Having tried a number of things like this, I think this balances concerns reasonably well. 
Differential Revision: https://reviews.llvm.org/D98989 --- .../Python/mlir/dialects/_builtin_ops_ext.py | 101 ++++++++++++++++++ 1 file changed, 101 insertions(+) diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py b/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py index b07892991..dc1d37e76 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py @@ -1,6 +1,11 @@ # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Optional, Sequence + +import inspect + from ..ir import * @@ -93,3 +98,99 @@ def add_entry_block(self): raise IndexError('The function already has an entry block!') self.body.blocks.append(*self.type.inputs) return self.body.blocks[0] + + @classmethod + def from_py_func(FuncOp, + *inputs: Type, + results: Optional[Sequence[Type]] = None, + name: Optional[str] = None): + """Decorator to define an MLIR FuncOp specified as a python function. + + Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are + active for the current thread (i.e. established in a `with` block). + + When applied as a decorator to a Python function, an entry block will + be constructed for the FuncOp with types as specified in `*inputs`. The + block arguments will be passed positionally to the Python function. In + addition, if the Python function accepts keyword arguments generally or + has a corresponding keyword argument, the following will be passed: + * `func_op`: The `func` op being defined. + + By default, the function name will be the Python function `__name__`. This + can be overriden by passing the `name` argument to the decorator. + + If `results` is not specified, then the decorator will implicitly + insert a `ReturnOp` with the `Value`'s returned from the decorated + function. 
It will also set the `FuncOp` type with the actual return + value types. If `results` is specified, then the decorated function + must return `None` and no implicit `ReturnOp` is added (nor are the result + types updated). The implicit behavior is intended for simple, single-block + cases, and users should specify result types explicitly for any complicated + cases. + + The decorated function can further be called from Python and will insert + a `CallOp` at the then-current insertion point, returning either None ( + if no return values), a unary Value (for one result), or a list of Values). + This mechanism cannot be used to emit recursive calls (by construction). + """ + + def decorator(f): + from . import std + # Introspect the callable for optional features. + sig = inspect.signature(f) + has_arg_func_op = False + for param in sig.parameters.values(): + if param.kind == param.VAR_KEYWORD: + has_arg_func_op = True + if param.name == "func_op" and (param.kind + == param.POSITIONAL_OR_KEYWORD or + param.kind == param.KEYWORD_ONLY): + has_arg_func_op = True + + # Emit the FuncOp. + implicit_return = results is None + symbol_name = name or f.__name__ + function_type = FunctionType.get( + inputs=inputs, results=[] if implicit_return else results) + func_op = FuncOp(name=symbol_name, type=function_type) + with InsertionPoint(func_op.add_entry_block()): + func_args = func_op.entry_block.arguments + func_kwargs = {} + if has_arg_func_op: + func_kwargs["func_op"] = func_op + return_values = f(*func_args, **func_kwargs) + if not implicit_return: + return_types = list(results) + assert return_values is None, ( + "Capturing a python function with explicit `results=` " + "requires that the wrapped function returns None.") + else: + # Coerce return values, add ReturnOp and rewrite func type. 
+ if return_values is None: + return_values = [] + elif isinstance(return_values, Value): + return_values = [return_values] + else: + return_values = list(return_values) + std.ReturnOp(return_values) + # Recompute the function type. + return_types = [v.type for v in return_values] + function_type = FunctionType.get(inputs=inputs, results=return_types) + func_op.attributes["type"] = TypeAttr.get(function_type) + + def emit_call_op(*call_args): + call_op = std.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name), + call_args) + if return_types is None: + return None + elif len(return_types) == 1: + return call_op.result + else: + return call_op.results + + wrapped = emit_call_op + wrapped.__name__ = f.__name__ + wrapped.func_op = func_op + return wrapped + + return decorator From 46bc061f523eb125966bd7344410d99462bd5160 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 19 Mar 2021 18:44:51 -0700 Subject: [PATCH 006/915] [mlir][python] Adapt to `segment_sizes` attribute type change. * Broken by https://reviews.llvm.org/rG1a75be0023cd80fd8560d689999a63d4368c90e6 --- mlir/lib/Bindings/Python/IRCore.cpp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 9d87aa52f..0a4c5fcb4 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1034,8 +1034,8 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList, py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); - std::vector operandSegmentLengths; - std::vector resultSegmentLengths; + std::vector operandSegmentLengths; + std::vector resultSegmentLengths; // Validate/determine region count. auto opRegionSpec = py::cast>(cls.attr("_ODS_REGIONS")); @@ -1247,8 +1247,8 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList, // Add result_segment_sizes attribute. 
if (!resultSegmentLengths.empty()) { int64_t size = resultSegmentLengths.size(); - MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get( - mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)), + MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( + mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), resultSegmentLengths.size(), resultSegmentLengths.data()); (*attributes)["result_segment_sizes"] = PyAttribute(context, segmentLengthAttr); @@ -1257,8 +1257,8 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList, // Add operand_segment_sizes attribute. if (!operandSegmentLengths.empty()) { int64_t size = operandSegmentLengths.size(); - MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get( - mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)), + MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( + mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), operandSegmentLengths.size(), operandSegmentLengths.data()); (*attributes)["operand_segment_sizes"] = PyAttribute(context, segmentLengthAttr); From 7b5688322b9919186bf5bc129d431fc7421c1126 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Thu, 11 Mar 2021 23:58:02 +0000 Subject: [PATCH 007/915] Define a `NoTerminator` traits that allows operations with a single block region to not provide a terminator In particular for Graph Regions, the terminator needs is just a historical artifact of the generalization of MLIR from CFG region. Operations like Module don't need a terminator, and before Module migrated to be an operation with region there wasn't any needed. To validate the feature, the ModuleOp is migrated to use this trait and the ModuleTerminator operation is deleted. This patch is likely to break clients, if you're in this case: - you may iterate on a ModuleOp with `getBody()->without_terminator()`, the solution is simple: just remove the ->without_terminator! 
- you created a builder with `Builder::atBlockTerminator(module_body)`, just use `Builder::atBlockEnd(module_body)` instead. - you were handling ModuleTerminator: it isn't needed anymore. - for generic code, a `Block::mayNotHaveTerminator()` may be used. Differential Revision: https://reviews.llvm.org/D98468 --- mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py b/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py index dc1d37e76..6598efe3e 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py @@ -16,8 +16,6 @@ def __init__(self, *, loc=None, ip=None): super().__init__(self.build_generic(results=[], operands=[], loc=loc, ip=ip)) body = self.regions[0].blocks.append() - with InsertionPoint(body): - Operation.create("module_terminator") @property def body(self): From 668ab83805e39df2e0b0ab296fd18e6ab2387bf3 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 19 Mar 2021 18:16:45 -0700 Subject: [PATCH 008/915] [mlir][linalg] Add an InitTensorOp python builder. * This has the API I want but I am not thrilled with the implementation. There are various things that could be improved both about the way that Python builders are mapped and the way the Linalg ops are factored to increase code sharing between C++/Python. * Landing this as-is since it at least makes the InitTensorOp usable with the right API. Will refactor underneath in follow-ons. 
Differential Revision: https://reviews.llvm.org/D99000 --- .../Python/mlir/dialects/_linalg_ops_ext.py | 42 +++++++++++++++++++ .../Python/mlir/dialects/_ods_common.py | 19 +++++---- 2 files changed, 53 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py b/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py index 74390d487..d35d10cc4 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py @@ -2,6 +2,48 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from typing import Optional, Sequence, Union +from ..ir import * +from ._ods_common import get_default_loc_context + + +class InitTensorOp: + """Extends the linalg.init_tensor op.""" + + def __init__(self, + sizes: Union[Sequence[int], Sequence[Value]], + element_type: Type, + *, + loc=None, + ip=None): + """Constructs an `init_tensor` with either static or dynamic sizes.""" + context = get_default_loc_context(loc) + operands = [] + attributes = {} + # TODO: Refactor the InitTensorOp to take an element type attribute and + # then use normal result type inference, unifying the Python and C++ side + # with a standard mechanism (versus stashing that in builders). + if sizes and isinstance(sizes[0], Value): + # Dynamic sizes. + operands.extend(sizes) + static_size_ints = [-1] * len(sizes) + result_type = RankedTensorType.get(static_size_ints, element_type) + else: + # Static sizes. 
+ result_type = RankedTensorType.get(sizes, element_type) + static_size_ints = sizes + + index_type = IndexType.get(context) + attributes["static_sizes"] = ArrayAttr.get( + [IntegerAttr.get(index_type, s) for s in static_size_ints], + context=context) + op = self.build_generic(results=[result_type], + operands=operands, + attributes=attributes, + loc=loc, + ip=ip) + OpView.__init__(self, op) + class StructuredOpMixin: """All structured ops use the same mixin class.""" diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_ods_common.py b/mlir/lib/Bindings/Python/mlir/dialects/_ods_common.py index 6d37700ec..d03044088 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/_ods_common.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/_ods_common.py @@ -17,14 +17,15 @@ def extend_opview_class(ext_module): """Decorator to extend an OpView class from an extension module. Extension modules can expose various entry-points: + Stand-alone class with the same name as a parent OpView class (i.e. + "ReturnOp"). A name-based match is attempted first before falling back + to a below mechanism. + def select_opview_mixin(parent_opview_cls): If defined, allows an appropriate mixin class to be selected dynamically based on the parent OpView class. Should return NotImplemented if a decision is not made. - Stand-alone class with the same name as a parent OpView class (i.e. - "ReturnOp"). - Args: ext_module: A module from which to locate extensions. Can be None if not available. @@ -38,16 +39,18 @@ def class_decorator(parent_opview_cls: type): if ext_module is None: return parent_opview_cls mixin_cls = NotImplemented + # First try to resolve by name. try: - select_mixin = getattr(ext_module, "select_opview_mixin") + mixin_cls = getattr(ext_module, parent_opview_cls.__name__) except AttributeError: - # Try to default resolve it. + # Fall back to a select_opview_mixin hook. 
try: - mixin_cls = getattr(ext_module, parent_opview_cls.__name__) + select_mixin = getattr(ext_module, "select_opview_mixin") except AttributeError: pass - else: - mixin_cls = select_mixin(parent_opview_cls) + else: + mixin_cls = select_mixin(parent_opview_cls) + if mixin_cls is NotImplemented or mixin_cls is None: return parent_opview_cls From 3066be7872b275a21a2505363e8ecd2f042459c1 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 25 Mar 2021 15:08:09 +0000 Subject: [PATCH 009/915] [mlir][python] NFC - Fix stale path in doc Differential Revision: https://reviews.llvm.org/D99345 --- .../Bindings/Python/mlir/dialects/linalg/opdsl/dump_oplib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/dump_oplib.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/dump_oplib.py index 98bf2e247..bacc0c302 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/dump_oplib.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/dump_oplib.py @@ -14,7 +14,7 @@ Sample usage: # Dump the YAML op definitions for the core named ops (as in the dialect # source tree). - python -m mlir.tools.linalg_opdsl.dump_oplib .ops.core_named_ops + python -m mlir.dialects.linalg.opdsl.dump_oplib .ops.core_named_ops Note: YAML output is emitted in "document list" format with each operation as its own "document". Practically, this means that each operation (or group From 3ac0f7e0f82497f950a27921feadad811afea5db Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 26 Mar 2021 19:04:36 +0100 Subject: [PATCH 010/915] [mlir] Register Linalg passes in C API and Python Bindings Provide a registration mechanism for Linalg dialect-specific passes in C API and Python bindings. These are being built into the dialect library but exposed in separate headers (C) or modules (Python). 
Differential Revision: https://reviews.llvm.org/D99431 --- mlir/include/mlir-c/Dialect/Linalg.h | 3 +++ mlir/lib/Bindings/Python/CMakeLists.txt | 8 ++++++ mlir/lib/Bindings/Python/LinalgPasses.cpp | 22 ++++++++++++++++ .../mlir/dialects/linalg/passes/__init__.py | 6 +++++ mlir/lib/CAPI/Dialect/CMakeLists.txt | 7 +++++ mlir/lib/CAPI/Dialect/LinalgPasses.cpp | 26 +++++++++++++++++++ 6 files changed, 72 insertions(+) create mode 100644 mlir/lib/Bindings/Python/LinalgPasses.cpp create mode 100644 mlir/lib/Bindings/Python/mlir/dialects/linalg/passes/__init__.py create mode 100644 mlir/lib/CAPI/Dialect/LinalgPasses.cpp diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h index 56258ac19..be73a5c8c 100644 --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -11,6 +11,7 @@ #define MLIR_C_DIALECT_LINALG_H #include "mlir-c/Registration.h" +#include "mlir-c/Support.h" #ifdef __cplusplus extern "C" { @@ -22,4 +23,6 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg); } #endif +#include "mlir/Dialect/Linalg/Passes.capi.h.inc" + #endif // MLIR_C_DIALECT_LINALG_H diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt index 5fefa8039..43d6275d4 100644 --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/CMakeLists.txt @@ -118,3 +118,11 @@ endif() add_subdirectory(Transforms) add_subdirectory(Conversions) + +add_mlir_python_extension(MLIRLinalgPassesBindingsPythonExtension _mlirLinalgPasses + INSTALL_DIR + python + SOURCES + LinalgPasses.cpp +) +add_dependencies(MLIRBindingsPythonExtension MLIRLinalgPassesBindingsPythonExtension) diff --git a/mlir/lib/Bindings/Python/LinalgPasses.cpp b/mlir/lib/Bindings/Python/LinalgPasses.cpp new file mode 100644 index 000000000..3f230207a --- /dev/null +++ b/mlir/lib/Bindings/Python/LinalgPasses.cpp @@ -0,0 +1,22 @@ +//===- LinalgPasses.cpp - Pybind module for the Linalg passes -------------===// +// +// 
Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Linalg.h" + +#include + +// ----------------------------------------------------------------------------- +// Module initialization. +// ----------------------------------------------------------------------------- + +PYBIND11_MODULE(_mlirLinalgPasses, m) { + m.doc() = "MLIR Linalg Dialect Passes"; + + // Register all Linalg passes on load. + mlirRegisterLinalgPasses(); +} diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/passes/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/passes/__init__.py new file mode 100644 index 000000000..6555ad69a --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/passes/__init__.py @@ -0,0 +1,6 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ...._cext_loader import _load_extension +_cextLinalgPasses = _load_extension("_mlirLinalgPasses") diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index d256309bf..41c659d6a 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -1,6 +1,7 @@ # TODO: Make the check source feature optional as an argument on *_add_library. 
set(LLVM_OPTIONAL_SOURCES Linalg.cpp + LinalgPasses.cpp SCF.cpp Shape.cpp Standard.cpp @@ -9,10 +10,16 @@ set(LLVM_OPTIONAL_SOURCES add_mlir_public_c_api_library(MLIRCAPILinalg Linalg.cpp + LinalgPasses.cpp + + DEPENDS + MLIRLinalgPassIncGen LINK_LIBS PUBLIC MLIRCAPIIR MLIRLinalg + MLIRPass + MLIRLinalgTransforms ) add_mlir_public_c_api_library(MLIRCAPISCF diff --git a/mlir/lib/CAPI/Dialect/LinalgPasses.cpp b/mlir/lib/CAPI/Dialect/LinalgPasses.cpp new file mode 100644 index 000000000..6677476d8 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/LinalgPasses.cpp @@ -0,0 +1,26 @@ +//===- LinalgPasses.cpp - C API for Linalg Dialect Passes -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/CAPI/Pass.h" +#include "mlir/Dialect/Linalg/Passes.h" +#include "mlir/Pass/Pass.h" + +// Must include the declarations as they carry important visibility attributes. +#include "mlir/Dialect/Linalg/Passes.capi.h.inc" + +using namespace mlir; + +#ifdef __cplusplus +extern "C" { +#endif + +#include "mlir/Dialect/Linalg/Passes.capi.cpp.inc" + +#ifdef __cplusplus +} +#endif From 1f00af730522d90e69a3afbbe3f9cf5a69537a09 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 26 Mar 2021 08:40:07 +0000 Subject: [PATCH 011/915] [mlir][Linalg] Allow calling named ops when available and make it the default. 
Differential Revision: https://reviews.llvm.org/D99419 --- .../Python/mlir/dialects/_linalg_ops_ext.py | 3 -- .../Python/mlir/dialects/linalg/__init__.py | 48 +++++++++++++++++++ .../linalg/opdsl/lang/comprehension.py | 12 ++--- .../mlir/dialects/linalg/opdsl/lang/dsl.py | 15 ++++-- .../dialects/linalg/opdsl/lang/emitter.py | 35 ++++++++++++-- 5 files changed, 95 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py b/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py index d35d10cc4..d787943d1 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py @@ -49,9 +49,6 @@ class StructuredOpMixin: """All structured ops use the same mixin class.""" def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None): - if outputs and results: - raise ValueError( - "Structured ops must have outputs or results, but not both.") super().__init__( self.build_generic(results=list(results), operands=[list(inputs), list(outputs)], diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py index 81949b8f8..976718337 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py @@ -2,4 +2,52 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# These are the backing OpView classes generated from the linalg tablegen +# definitions following these steps: +# DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py. from .._linalg_ops_gen import * + +# These are the ground truth functions defined as: +# ``` +# @linalg_structured_op +# def matmul(A=TensorDef(T1, S.M, S.K), +# B=TensorDef(T2, S.K, S.N), +# C=TensorDef(U, S.M, S.N, output=True)): +# ``` +# using the linalg-py eDSL. 
+# The linalg-py eDSL builds a python representation (PyRepr) that is +# used in following ways: +# 1. PyRepr -> YAML to generate the C++ and Python .td files. These +# then turn into the core C++ Op classes and Python OpView classes +# respectively (made available in _linalg_ops_gen). The generic OpView class +# mechanism makes the C++ classes available to python through the CAPI. +# PyRepr -> YAML currently occurs before compiler compile time. +# The other steps in this category occur at compiler compile time. +# 2. PyRepr -> linalg.core_named_ops calls: piggybacks on the +# _linalg_ops_gen classes and the OpView mechanism to build IR at +# runtime in python: +# a. by default, the Named Op Form is emitted, e.g.: +# `linalg.matmul(lhs, rhs, outs=[out])` creates the following IR: +# ``` +# %1 = linalg.matmul ins(%arg0, %arg1 : tensor<4x16xf32>, tensor<16x8xf32>) +# outs(%0 : tensor<4x8xf32>) +# -> tensor<4x8xf32> +# ``` +# b. by setting emit_generic=True, the Generic Op Form is emitted, e.g.: +# `linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)` creates the following IR: +# ``` +# %1 = linalg.generic {indexing_maps = [...], iterator_types = [...]} +# ins(%arg0, %arg1 : tensor<4x16xf32>, tensor<16x8xf32>) +# outs(%0 : tensor<4x8xf32>) { +# ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): +# ... +# linalg.yield %3 : f32 +# } -> tensor<4x8xf32> +# ``` +# 3. PyRepr -> Runtime Custom Op definitions: directly generates a +# linalg.generic form like in 2.b. +# !!!WARNING!!!: if one creates a runtime custom op with the same name +# as an existing core named op, step 2. will likely take precedence. +# TODO: guard against surprises and fail create Runtime Custom Ops with +# the same name as existing Core Named Ops. 
+from .opdsl.ops.core_named_ops import * diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 6bc6ff979..85da3323c 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -359,16 +359,16 @@ class OpMetadataDef(YAMLObject): """Metadata about the op (generally not behavior impacting).""" yaml_tag = "!LinalgOpMetadata" - def __init__(self, name: str, cpp_op_name: Optional[str], doc: Optional[str]): + def __init__(self, name: str, cpp_class_name: Optional[str], doc: Optional[str]): self.name = name - self.cpp_op_name = cpp_op_name if cpp_op_name is not None else name + self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name self.doc = doc self.implements = [] # type: List[OpInterfaceDef] def to_yaml_custom_dict(self): d = dict( name=self.name, - cpp_op_name=self.cpp_op_name, + cpp_class_name=self.cpp_class_name, doc=self.doc, ) if self.implements: @@ -381,9 +381,9 @@ class LinalgOpDef: def __init__(self, name: str, - cpp_op_name: Optional[str] = None, + cpp_class_name: Optional[str] = None, doc: Optional[str] = None): - self.metadata = OpMetadataDef(name=name, cpp_op_name=cpp_op_name, doc=doc) + self.metadata = OpMetadataDef(name=name, cpp_class_name=cpp_class_name, doc=doc) self.registered_tensors = dict() # type: Dict[str, TensorDef] self.comprehensions = list() # type: List[Comprehension] self._affine_state = AffineBuildState() @@ -413,7 +413,7 @@ def tensor(self, name): def __repr__(self): lines = [ - f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_op_name}," + f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name}," ] for name, tensor in self.registered_tensors.items(): lines.append(f" {tensor}") diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py 
b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py index cbff41db2..d6dc9895f 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -44,7 +44,7 @@ def __init__(self, op_name: str, model: LinalgOpDef): self.op_name = op_name self.model = model - def __call__(self, *args, emit_generic: bool = True, **kwargs): + def __call__(self, *args, emit_generic: bool = False, **kwargs): """Emits the corresponding op definition as IR. Most arguments are passed through to the underlying emitter. The following @@ -61,14 +61,21 @@ def __call__(self, *args, emit_generic: bool = True, **kwargs): raise NotImplementedError( f"Emission of composite linalg ops not supported: {op_configs}") + # TODO: this file should probably not be called dsl.py but rather is a client + # of the dsl.py. + from .... import linalg as linalg_ops + emit_generic = (emit_generic or + (not self.model.metadata.cpp_class_name in linalg_ops.__dict__.keys())) + op_config = op_configs[0] if op_config.structured_op: if emit_generic: return emit_generic_structured_op(op_config.structured_op, *args, **kwargs) else: - return emit_named_structured_op(op_config.structured_op, *args, - **kwargs) + return emit_named_structured_op( + op_config.structured_op, self.op_name, + self.model.metadata.cpp_class_name, *args, **kwargs) raise NotImplementedError( f"Emission of linalg op type not supported: {op_config}") @@ -91,7 +98,7 @@ def linalg_structured_op(dsl_func=None, op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op" tc_model = LinalgOpDef(name=op_name, - cpp_op_name=op_class_name, + cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func)) # Extract arguments and TensorDefs from the signature. 
diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py index 9a18993e9..e8e7eb5c3 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -17,9 +17,9 @@ ] -def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, - *ins: Value, - outs: Value = ()): +def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, + *ins: Value, + outs: Value): all_arg_defs = op_config.ordered_tensor_args in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "input"] out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "output"] @@ -49,6 +49,18 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, [AffineMapAttr.get(am) for am in op_config.indexing_maps]) iterator_types_attr = ArrayAttr.get( [StringAttr.get(s) for s in op_config.iterator_types]) + + return (all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, + type_mapping, indexing_maps_attr, iterator_types_attr) + + +def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, + *ins: Value, + outs: Value = ()): + all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, \ + type_mapping, indexing_maps_attr, iterator_types_attr = \ + prepare_common_structured_op(op_config, *ins, outs = outs) + generic_op = linalg.GenericOp( result_tensors=out_types, inputs=ins, @@ -77,10 +89,23 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, def emit_named_structured_op(op_config: LinalgStructuredOpConfig, + op_name: str, + op_class_name: str, *ins: Value, outs: Value = ()): - raise NotImplementedError( - f"Emission of named structured ops is not supported: {op_config}") + all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, \ + type_mapping, indexing_maps_attr, iterator_types_attr = \ + prepare_common_structured_op(op_config, *ins, outs = outs) + + if not 
op_class_name in linalg.__dict__.keys(): + raise NotImplementedError( + f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}") + + named_op = getattr(linalg, op_class_name)(ins, outs, out_types) + if len(out_arg_defs) == 1: + return named_op.result + else: + return named_op.results class _BodyBuilder: From 97380e9dcf6bec0bd1a960177f9e05b17e87b1ab Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 29 Mar 2021 18:30:50 +0000 Subject: [PATCH 012/915] NFC: Update MLIR python bindings docs to install deps via requirements.txt. * Also adds some verbiage about upgrading `pip` itself, since this is a common source of issues. Differential Revision: https://reviews.llvm.org/D99522 --- mlir/lib/Bindings/Python/requirements.txt | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 mlir/lib/Bindings/Python/requirements.txt diff --git a/mlir/lib/Bindings/Python/requirements.txt b/mlir/lib/Bindings/Python/requirements.txt new file mode 100644 index 000000000..51b35c22e --- /dev/null +++ b/mlir/lib/Bindings/Python/requirements.txt @@ -0,0 +1,3 @@ +numpy +pybind11>=2.6.0 +PyYAML From f45dbc8332393a65a13d664cfb11ac2e3fa53ee8 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Tue, 30 Mar 2021 04:35:36 +0000 Subject: [PATCH 013/915] Add a "register_runtime" method to the mlir.execution_engine and show calling back from MLIR into Python This exposes the ability to register Python functions with the JIT and exposes them to the MLIR jitted code. The provided test case illustrates the mechanism. 
Differential Revision: https://reviews.llvm.org/D99562 --- mlir/include/mlir-c/ExecutionEngine.h | 6 ++++++ mlir/lib/Bindings/Python/ExecutionEngine.cpp | 12 +++++++++++- mlir/lib/Bindings/Python/mlir/execution_engine.py | 8 ++++++++ mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp | 12 ++++++++++++ 4 files changed, 37 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h index c25635771..5210f108e 100644 --- a/mlir/include/mlir-c/ExecutionEngine.h +++ b/mlir/include/mlir-c/ExecutionEngine.h @@ -61,6 +61,12 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirExecutionEngineInvokePacked( MLIR_CAPI_EXPORTED void *mlirExecutionEngineLookup(MlirExecutionEngine jit, MlirStringRef name); +/// Register a symbol with the jit: this symbol will be accessible to the jitted +/// code. +MLIR_CAPI_EXPORTED void +mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit, MlirStringRef name, + void *sym); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/Bindings/Python/ExecutionEngine.cpp b/mlir/lib/Bindings/Python/ExecutionEngine.cpp index 5ca9b1f68..0e8ae8b38 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngine.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngine.cpp @@ -81,7 +81,17 @@ void mlir::python::populateExecutionEngineSubmodule(py::module &m) { auto *res = mlirExecutionEngineLookup( executionEngine.get(), mlirStringRefCreate(func.c_str(), func.size())); - return (int64_t)res; + return reinterpret_cast(res); + }, + "Lookup function `func` in the ExecutionEngine.") + .def( + "raw_register_runtime", + [](PyExecutionEngine &executionEngine, const std::string &name, + uintptr_t sym) { + mlirExecutionEngineRegisterSymbol( + executionEngine.get(), + mlirStringRefCreate(name.c_str(), name.size()), + reinterpret_cast(sym)); }, "Lookup function `func` in the ExecutionEngine."); } diff --git a/mlir/lib/Bindings/Python/mlir/execution_engine.py b/mlir/lib/Bindings/Python/mlir/execution_engine.py index 89bd4aad5..39d9501d9 
100644 --- a/mlir/lib/Bindings/Python/mlir/execution_engine.py +++ b/mlir/lib/Bindings/Python/mlir/execution_engine.py @@ -29,3 +29,11 @@ def invoke(self, name, *ctypes_args): for argNum in range(len(ctypes_args)): packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p) func(packed_args) + + def register_runtime(self, name, ctypes_callback): + """Register a runtime function available to the jitted code + under the provided `name`. The `ctypes_callback` must be a + `CFuncType` that outlives the execution engine. + """ + callback = ctypes.cast(ctypes_callback, ctypes.c_void_p).value + self.raw_register_runtime("_mlir_ciface_" + name, callback) diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index 68137c067..345eac219 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -11,6 +11,7 @@ #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "llvm/ExecutionEngine/Orc/Mangling.h" #include "llvm/Support/TargetSelect.h" using namespace mlir; @@ -54,3 +55,14 @@ extern "C" void *mlirExecutionEngineLookup(MlirExecutionEngine jit, return nullptr; return reinterpret_cast(*expectedFPtr); } + +extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit, + MlirStringRef name, + void *sym) { + unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) { + llvm::orc::SymbolMap symbolMap; + symbolMap[interner(unwrap(name))] = + llvm::JITEvaluatedSymbol::fromPointer(sym); + return symbolMap; + }); +} From 03c480124a06450a573b9031ffdabca9c645dcc5 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 30 Mar 2021 22:19:10 -0700 Subject: [PATCH 014/915] [mlir] Add C and python API for is_registered_operation. 
* Suggested to be broken out of D99578 Differential Revision: https://reviews.llvm.org/D99638 --- mlir/include/mlir-c/IR.h | 7 +++++++ mlir/lib/Bindings/Python/IRCore.cpp | 7 ++++++- mlir/lib/CAPI/IR/IR.cpp | 4 ++++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index d807cd46d..048bd4667 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -119,6 +119,13 @@ mlirContextGetNumLoadedDialects(MlirContext context); MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name); +/// Returns whether the given fully-qualified operation (i.e. +/// 'dialect.operation') is registered with the context. This will return true +/// if the dialect is loaded and the operation is registered within the +/// dialect. +MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, + MlirStringRef name); + //===----------------------------------------------------------------------===// // Dialect API. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 0a4c5fcb4..5046eedb1 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1752,7 +1752,12 @@ void mlir::python::populateIRCore(py::module &m) { }, [](PyMlirContext &self, bool value) { mlirContextSetAllowUnregisteredDialects(self.get(), value); - }); + }) + .def("is_registered_operation", + [](PyMlirContext &self, std::string &name) { + return mlirContextIsRegisteredOperation( + self.get(), MlirStringRef{name.data(), name.size()}); + }); //---------------------------------------------------------------------------- // Mapping of PyDialectDescriptor diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 67032a4b5..14cde9633 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -60,6 +60,10 @@ MlirDialect mlirContextGetOrLoadDialect(MlirContext context, return wrap(unwrap(context)->getOrLoadDialect(unwrap(name))); } +bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) { + return unwrap(context)->isOperationRegistered(unwrap(name)); +} + //===----------------------------------------------------------------------===// // Dialect API. //===----------------------------------------------------------------------===// From 2e1215dd189ee1f94b7b2dc92f99c55413c9d4c4 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Tue, 30 Mar 2021 11:41:41 +0000 Subject: [PATCH 015/915] [mlir][Linalg][Python] Create the body of builtin named Linalg ops This revision adds support to properly add the body of registered builtin named linalg ops. At this time, indexing_map and iterator_type support is still missing so the op is not executable yet. 
Differential Revision: https://reviews.llvm.org/D99578 --- mlir/include/mlir-c/Dialect/Linalg.h | 5 +++ mlir/lib/Bindings/Python/CMakeLists.txt | 1 + mlir/lib/Bindings/Python/DialectLinalg.cpp | 34 +++++++++++++++++++ mlir/lib/Bindings/Python/DialectLinalg.h | 22 ++++++++++++ mlir/lib/Bindings/Python/MainModule.cpp | 6 ++++ .../mlir/dialects/linalg/opdsl/lang/dsl.py | 9 +++-- .../dialects/linalg/opdsl/lang/emitter.py | 13 +++++-- mlir/lib/CAPI/Dialect/Linalg.cpp | 29 ++++++++++++++-- 8 files changed, 110 insertions(+), 9 deletions(-) create mode 100644 mlir/lib/Bindings/Python/DialectLinalg.cpp create mode 100644 mlir/lib/Bindings/Python/DialectLinalg.h diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h index be73a5c8c..06f15f062 100644 --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -17,6 +17,11 @@ extern "C" { #endif +/// Apply the special region builder for the builtin named Linalg op. +/// Assert that `op` is a builtin named Linalg op. 
+MLIR_CAPI_EXPORTED void +mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op); + MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg); #ifdef __cplusplus diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt index 43d6275d4..39192cc54 100644 --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/CMakeLists.txt @@ -69,6 +69,7 @@ add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir INSTALL_DIR python SOURCES + DialectLinalg.cpp MainModule.cpp IRAffine.cpp IRAttributes.cpp diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp new file mode 100644 index 000000000..e4ef69411 --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -0,0 +1,34 @@ +//===- DialectLinalg.cpp - Pybind module for Linalg dialect API support --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "IRModule.h" +#include "mlir-c/Dialect/Linalg.h" +#include "mlir-c/IR.h" + +#include + +namespace py = pybind11; +using namespace mlir; +using namespace mlir::python; + +namespace mlir { +namespace python { + +void populateDialectLinalgSubmodule(py::module &m) { + m.def( + "fill_builtin_region", + [](PyDialectDescriptor &dialect, PyOperation &op) { + return mlirLinalgFillBuiltinNamedOpRegion(dialect.get(), op.get()); + }, + py::arg("dialect"), py::arg("op"), + "Fill the region for `op`, which is assumed to be a builtin named Linalg " + "op."); +} + +} // namespace python +} // namespace mlir diff --git a/mlir/lib/Bindings/Python/DialectLinalg.h b/mlir/lib/Bindings/Python/DialectLinalg.h new file mode 100644 index 000000000..3735dbf6f --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectLinalg.h @@ -0,0 +1,22 @@ +//===- DialectLinalg.h - Linalg dialect submodule of pybind module --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_DIALECTLINALG_H +#define MLIR_BINDINGS_PYTHON_DIALECTLINALG_H + +#include "PybindUtils.h" + +namespace mlir { +namespace python { + +void populateDialectLinalgSubmodule(pybind11::module &m); + +} // namespace python +} // namespace mlir + +#endif // MLIR_BINDINGS_PYTHON_DIALECTLINALG_H diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 5fe0401af..79128f267 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -10,6 +10,7 @@ #include "PybindUtils.h" +#include "DialectLinalg.h" #include "ExecutionEngine.h" #include "Globals.h" #include "IRModule.h" @@ -225,4 +226,9 @@ PYBIND11_MODULE(_mlir, m) { auto executionEngineModule = m.def_submodule("execution_engine", "MLIR JIT Execution Engine"); populateExecutionEngineSubmodule(executionEngineModule); + + // Define and populate Linalg submodule. + auto dialectsModule = m.def_submodule("dialects"); + auto linalgModule = dialectsModule.def_submodule("linalg"); + populateDialectLinalgSubmodule(linalgModule); } diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py index d6dc9895f..002ae51ba 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -61,11 +61,10 @@ def __call__(self, *args, emit_generic: bool = False, **kwargs): raise NotImplementedError( f"Emission of composite linalg ops not supported: {op_configs}") - # TODO: this file should probably not be called dsl.py but rather is a client - # of the dsl.py. - from .... 
import linalg as linalg_ops - emit_generic = (emit_generic or - (not self.model.metadata.cpp_class_name in linalg_ops.__dict__.keys())) + ctx = ir.Context.current + linalgDialect = ctx.get_dialect_descriptor("linalg") + fully_qualified_name = 'linalg.' + self.op_name + emit_generic = (emit_generic or not ctx.is_registered_operation(fully_qualified_name)) op_config = op_configs[0] if op_config.structured_op: diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py index e8e7eb5c3..2395a422e 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -7,6 +7,9 @@ from mlir.ir import * from mlir.dialects import linalg from mlir.dialects import std +# TODO: resolve name collision for Linalg functionality that is injected inside +# the _mlir.dialects.linalg directly via pybind. +from _mlir.dialects.linalg import fill_builtin_region from .scalar_expr import * from .config import * @@ -16,7 +19,6 @@ "emit_named_structured_op", ] - def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, outs: Value): @@ -97,11 +99,18 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, type_mapping, indexing_maps_attr, iterator_types_attr = \ prepare_common_structured_op(op_config, *ins, outs = outs) - if not op_class_name in linalg.__dict__.keys(): + # If we get here, there must exist a builtin class `op_class_name`. + ctx = Context.current + fully_qualified_name = 'linalg.' 
+ op_name + if (not ctx.is_registered_operation(fully_qualified_name) or + not op_class_name in linalg.__dict__.keys()): raise NotImplementedError( f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}") named_op = getattr(linalg, op_class_name)(ins, outs, out_types) + linalgDialect = ctx.get_dialect_descriptor("linalg") + fill_builtin_region(linalgDialect, named_op.operation) + if len(out_arg_defs) == 1: return named_op.result else: diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index da6fd4846..1c50aa612 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -10,5 +10,30 @@ #include "mlir/CAPI/Registration.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" -MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, - mlir::linalg::LinalgDialect) +using namespace mlir; +using namespace mlir::linalg; + +/// Apply the special region builder for the builtin named Linalg op. +/// Assert that `op` is a builtin named Linalg op. +void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, + MlirOperation mlirOp) { + Operation *op = unwrap(mlirOp); + LinalgDialect::RegionBuilderFunType fun = + static_cast(unwrap(linalgDialect)) + ->getRegionBuilder(op->getName().getStringRef()); + assert(fun && "Expected a builtin named Linalg op."); + assert(op->getNumRegions() == 1 && "Expected Linalg op with 1 region"); + assert(op->getRegion(0).getBlocks().empty() && + "Expected Linalg op with 0 blocks"); + SmallVector argTypes; + auto linalgOp = cast(op); + for (auto t : linalgOp.getShapedOperandTypes()) + argTypes.push_back(getElementTypeOrSelf(t)); + OpBuilder b(op->getContext()); + Region ®ion = op->getRegion(0); + Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes); + // TODO: allow captures. 
+ fun(*body, ValueRange{}); +} + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect) From dd91c6a54ef22bce2cac368d3ffabe28f788c700 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 31 Mar 2021 09:33:08 +0000 Subject: [PATCH 016/915] [mlir][Python][Linalg] Add missing attributes to linalg ops This revision tightens up the handling of attributes for both named and generic linalg ops. To demonstrate the IR validity, a working e2e Linalg example is added. Differential Revision: https://reviews.llvm.org/D99430 --- mlir/include/mlir-c/AffineMap.h | 11 +++++ mlir/lib/Bindings/Python/IRAffine.cpp | 17 +++++++ .../dialects/linalg/opdsl/lang/emitter.py | 45 ++++++++++++++----- mlir/lib/CAPI/IR/AffineMap.cpp | 11 +++++ 4 files changed, 72 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir-c/AffineMap.h b/mlir/include/mlir-c/AffineMap.h index de4f42f09..e35b7cc6b 100644 --- a/mlir/include/mlir-c/AffineMap.h +++ b/mlir/include/mlir-c/AffineMap.h @@ -169,6 +169,17 @@ mlirAffineMapGetMajorSubMap(MlirAffineMap affineMap, intptr_t numResults); MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapGetMinorSubMap(MlirAffineMap affineMap, intptr_t numResults); +/// Returns the simplified affine map resulting from dropping the symbols that +/// do not appear in any of the individual maps in `affineMaps`. +/// Asserts that all maps in `affineMaps` are normalized to the same number of +/// dims and symbols. +/// Takes a callback `populateResult` to fill the `res` container with value +/// `m` at entry `idx`. This allows returning without worrying about ownership +/// considerations. 
+MLIR_CAPI_EXPORTED void mlirAffineMapCompressUnusedSymbols( + MlirAffineMap *affineMaps, intptr_t size, void *result, + void (*populateResult)(void *res, intptr_t idx, MlirAffineMap m)); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index 73a57d95e..5d3b790b3 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -538,6 +538,23 @@ void mlir::python::populateIRAffine(py::module &m) { printAccum.parts.append(")"); return printAccum.join(); }) + .def_static("compress_unused_symbols", + [](py::list affineMaps, DefaultingPyMlirContext context) { + SmallVector maps; + pyListToVector( + affineMaps, maps, "attempting to create an AffineMap"); + std::vector compressed(affineMaps.size()); + auto populate = [](void *result, intptr_t idx, + MlirAffineMap m) { + static_cast(result)[idx] = (m); + }; + mlirAffineMapCompressUnusedSymbols( + maps.data(), maps.size(), compressed.data(), populate); + std::vector res; + for (auto m : compressed) + res.push_back(PyAffineMap(context->getRef(), m)); + return res; + }) .def_property_readonly( "context", [](PyAffineMap &self) { return self.getContext().getObject(); }, diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py index 2395a422e..682f19138 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -19,6 +19,13 @@ "emit_named_structured_op", ] +def isa(cls : Type, ty : Type): + try: + cls(ty) + return True + except ValueError: + return False + def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, outs: Value): @@ -37,6 +44,8 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins, out_arg_defs, outs) + 
result_types = [t for t in out_types if isa(RankedTensorType, t)] + # Extract type vars for input/output based types. type_mapping = dict() # type: Dict[str, Type] for arg_def, arg_element_type in zip( @@ -48,30 +57,37 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, # Emit the generic op. # TODO: Support emission of pure memref form. indexing_maps_attr = ArrayAttr.get( - [AffineMapAttr.get(am) for am in op_config.indexing_maps]) + [AffineMapAttr.get(am) + # TODO: linalg verification does not currently allow symbols. + # Compress them for now. + for am in AffineMap.compress_unused_symbols(op_config.indexing_maps, Context.current)]) iterator_types_attr = ArrayAttr.get( [StringAttr.get(s) for s in op_config.iterator_types]) + sparse_attr = ArrayAttr.get( + [BoolAttr.get(False) for s in list(ins) + list(outs) if isa(RankedTensorType, s.type)]) + if len(sparse_attr) == 0: + sparse_attr = None - return (all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, - type_mapping, indexing_maps_attr, iterator_types_attr) + return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, + type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr) def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, outs: Value = ()): - all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, \ - type_mapping, indexing_maps_attr, iterator_types_attr = \ + all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \ + type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \ prepare_common_structured_op(op_config, *ins, outs = outs) generic_op = linalg.GenericOp( - result_tensors=out_types, + result_tensors=result_types, inputs=ins, outputs=outs, indexing_maps=indexing_maps_attr, iterator_types=iterator_types_attr, doc=None, # TODO: Make optional. library_call=None, # TODO: Make optional. - sparse=BoolAttr.get(False)) # TODO: Make optional. + sparse=sparse_attr) # TODO: Make optional. # Construct the body. 
block_arg_names = _get_tensor_def_names(*in_arg_defs, *out_arg_defs) @@ -84,7 +100,7 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, body_builder.assign(assignment) body_builder.yield_outputs(*_get_tensor_def_names(*out_arg_defs)) - if len(out_arg_defs) == 1: + if len(result_types) == 1: return generic_op.result else: return generic_op.results @@ -95,8 +111,8 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_class_name: str, *ins: Value, outs: Value = ()): - all_arg_defs, in_arg_defs, out_arg_defs, outs, out_types, \ - type_mapping, indexing_maps_attr, iterator_types_attr = \ + all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \ + type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \ prepare_common_structured_op(op_config, *ins, outs = outs) # If we get here, there must exist a builtin class `op_class_name`. @@ -107,11 +123,16 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, raise NotImplementedError( f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}") - named_op = getattr(linalg, op_class_name)(ins, outs, out_types) + named_op = getattr(linalg, op_class_name)(ins, outs, result_types) linalgDialect = ctx.get_dialect_descriptor("linalg") fill_builtin_region(linalgDialect, named_op.operation) + # Note: mlir-linalg-ods-yaml-gen.cpp uses a special linalg.memoized_indexing_maps + # attribute that the non-yaml path does not. The non-yaml path hardcodes the + # indexing_maps in C++ directly. + named_op.operation.attributes["linalg.memoized_indexing_maps"] = indexing_maps_attr + # iterator_types are hardcoded in C++ both in the yaml and non-yaml path. 
- if len(out_arg_defs) == 1: + if len(result_types) == 1: return named_op.result else: return named_op.results diff --git a/mlir/lib/CAPI/IR/AffineMap.cpp b/mlir/lib/CAPI/IR/AffineMap.cpp index f532d5dae..e0c07afc3 100644 --- a/mlir/lib/CAPI/IR/AffineMap.cpp +++ b/mlir/lib/CAPI/IR/AffineMap.cpp @@ -137,3 +137,14 @@ MlirAffineMap mlirAffineMapGetMinorSubMap(MlirAffineMap affineMap, intptr_t numResults) { return wrap(unwrap(affineMap).getMinorSubMap(numResults)); } + +void mlirAffineMapCompressUnusedSymbols( + MlirAffineMap *affineMaps, intptr_t size, void *result, + void (*populateResult)(void *res, intptr_t idx, MlirAffineMap m)) { + SmallVector maps; + for (intptr_t idx = 0; idx < size; ++idx) + maps.push_back(unwrap(affineMaps[idx])); + intptr_t idx = 0; + for (auto m : mlir::compressUnusedSymbols(maps)) + populateResult(result, idx++, wrap(m)); +} From 098978f95e20e649ae1bc208af4adb5f2399b974 Mon Sep 17 00:00:00 2001 From: John Demme Date: Tue, 6 Apr 2021 14:15:22 -0700 Subject: [PATCH 017/915] [MLIR] [Python] Add capsule methods for pybind11 to PyOperation Add the `getCapsule()` and `createFromCapsule()` methods to the PyOperation class. 
Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D99927 --- mlir/lib/Bindings/Python/IRCore.cpp | 16 ++++++++++++++++ mlir/lib/Bindings/Python/IRModule.h | 8 ++++++++ 2 files changed, 24 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 5046eedb1..7a7bae92c 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -868,6 +868,19 @@ PyBlock PyOperation::getBlock() { return PyBlock{std::move(parentOperation), block}; } +py::object PyOperation::getCapsule() { + return py::reinterpret_steal(mlirPythonOperationToCapsule(get())); +} + +py::object PyOperation::createFromCapsule(py::object capsule) { + MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); + if (mlirOperationIsNull(rawOperation)) + throw py::error_already_set(); + MlirContext rawCtxt = mlirOperationGetContext(rawOperation); + return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) + .releaseObject(); +} + py::object PyOperation::create( std::string name, llvm::Optional> results, llvm::Optional> operands, @@ -2031,6 +2044,9 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("successors") = py::none(), py::arg("regions") = 0, py::arg("loc") = py::none(), py::arg("ip") = py::none(), kOperationCreateDocstring) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyOperation::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) .def_property_readonly("name", [](PyOperation &self) { MlirOperation operation = self.get(); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 5c710abe7..861673abc 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -454,6 +454,14 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// no parent. 
PyOperationRef getParentOperation(); + /// Gets a capsule wrapping the void* within the MlirOperation. + pybind11::object getCapsule(); + + /// Creates a PyOperation from the MlirOperation wrapped by a capsule. + /// Ownership of the underlying MlirOperation is taken by calling this + /// function. + static pybind11::object createFromCapsule(pybind11::object capsule); + /// Creates an operation. See corresponding python docstring. static pybind11::object create(std::string name, llvm::Optional> results, From 6944cd534e82b30a50955f27adadd964937f0d99 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Mon, 12 Apr 2021 09:28:41 -0700 Subject: [PATCH 018/915] [mlir] introduce "encoding" attribute to tensor type This CL introduces a generic attribute (called "encoding") on tensors. The attribute currently does not carry any concrete information, but the type system already correctly determines that tensor<8xi1,123> != tensor<8xi1,321>. The attribute will be given meaning through an interface in subsequent CLs. 
See ongoing discussion on discourse: [RFC] Introduce a sparse tensor type to core MLIR https://llvm.discourse.group/t/rfc-introduce-a-sparse-tensor-type-to-core-mlir/2944 A sparse tensor will look something like this: ``` // named alias with all properties we hold dear: #CSR = { // individual named attributes } // actual sparse tensor type: tensor<?x?xf64, #CSR> ``` I see the following rough 5 step plan going forward: (1) introduce this format attribute in this CL, currently still empty (2) introduce attribute interface that gives it "meaning", focused on sparse in first phase (3) rewrite sparse compiler to use new type, remove linalg interface and "glue" (4) teach passes to deal with new attribute, by rejecting/asserting on non-empty attribute as simplest solution, or doing meaningful rewrite in the longer run (5) add FE support, document, test, publicize new features, extend "format" meaning to other domains if useful Reviewed By: stellaraccident, bondhugula Differential Revision: https://reviews.llvm.org/D99548 --- mlir/include/mlir-c/BuiltinAttributes.h | 3 +++ mlir/include/mlir-c/BuiltinTypes.h | 15 +++++++++------ mlir/lib/Bindings/Python/IRAttributes.cpp | 5 +++-- mlir/lib/Bindings/Python/IRTypes.cpp | 4 +++- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 2 ++ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 11 ++++++----- 6 files changed, 26 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 29df9cf60..c85825c8d 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -22,6 +22,9 @@ extern "C" { #endif +/// Returns an empty attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull(); + //===----------------------------------------------------------------------===// // Affine map attribute. 
//===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index b2ec37c9d..7d45452af 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -188,17 +188,20 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsARankedTensor(MlirType type); /// Checks whether the given type is an unranked tensor type. MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedTensor(MlirType type); -/// Creates a tensor type of a fixed rank with the given shape and element type -/// in the same context as the element type. The type is owned by the context. +/// Creates a tensor type of a fixed rank with the given shape, element type, +/// and optional encoding in the same context as the element type. The type is +/// owned by the context. Tensor types without any specific encoding field +/// should assign mlirAttributeGetNull() to this parameter. MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape, - MlirType elementType); + MlirType elementType, + MlirAttribute encoding); /// Same as "mlirRankedTensorTypeGet" but returns a nullptr wrapping MlirType on /// illegal arguments, emitting appropriate diagnostics. -MLIR_CAPI_EXPORTED MlirType -mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank, - const int64_t *shape, MlirType elementType); +MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGetChecked( + MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType, + MlirAttribute encoding); /// Creates an unranked tensor type with the given element type in the same /// context as the element type. The type is owned by the context. 
diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 6f9206c1b..b5e3c5c9c 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -502,8 +502,9 @@ class PyDenseElementsAttribute MlirType mlirElementType, py::buffer_info &arrayInfo) { SmallVector shape(arrayInfo.shape.begin(), arrayInfo.shape.begin() + arrayInfo.ndim); - auto shapedType = - mlirRankedTensorTypeGet(shape.size(), shape.data(), mlirElementType); + MlirAttribute encodingAttr = mlirAttributeGetNull(); + auto shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), + mlirElementType, encodingAttr); intptr_t numElements = arrayInfo.size; const ElementTy *contents = static_cast(arrayInfo.ptr); return ctor(shapedType, numElements, contents); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 96f6bf666..421df4dab 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -10,6 +10,7 @@ #include "PybindUtils.h" +#include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" namespace py = pybind11; @@ -381,8 +382,9 @@ class PyRankedTensorType "get", [](std::vector shape, PyType &elementType, DefaultingPyLocation loc) { + MlirAttribute encodingAttr = mlirAttributeGetNull(); MlirType t = mlirRankedTensorTypeGetChecked( - loc, shape.size(), shape.data(), elementType); + loc, shape.size(), shape.data(), elementType, encodingAttr); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. 
if (mlirTypeIsNull(t)) { diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index a54006db2..7580786de 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -15,6 +15,8 @@ using namespace mlir; +MlirAttribute mlirAttributeGetNull() { return {nullptr}; } + //===----------------------------------------------------------------------===// // Affine map attribute. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index c84ced177..1e5fa8a32 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -191,18 +191,19 @@ bool mlirTypeIsAUnrankedTensor(MlirType type) { } MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape, - MlirType elementType) { + MlirType elementType, MlirAttribute encoding) { return wrap(RankedTensorType::get( - llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType))); + llvm::makeArrayRef(shape, static_cast(rank)), unwrap(elementType), + unwrap(encoding))); } MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, - MlirType elementType) { + MlirType elementType, + MlirAttribute encoding) { return wrap(RankedTensorType::getChecked( unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType))); + unwrap(elementType), unwrap(encoding))); } MlirType mlirUnrankedTensorTypeGet(MlirType elementType) { From d19f307a0b1044f8df73bcbcfc96964a0ea3ff7f Mon Sep 17 00:00:00 2001 From: Prashant Kumar Date: Thu, 15 Apr 2021 19:55:56 +0000 Subject: [PATCH 019/915] Add support for numpy arrays to memref conversions. This offers the ability to pass numpy arrays to the corresponding memref argument. 
Reviewed By: mehdi_amini, nicolasvasilache Differential Revision: https://reviews.llvm.org/D100077 --- .../Bindings/Python/mlir/runtime/__init__.py | 1 + .../Python/mlir/runtime/np_to_memref.py | 119 ++++++++++++++++++ 2 files changed, 120 insertions(+) create mode 100644 mlir/lib/Bindings/Python/mlir/runtime/__init__.py create mode 100644 mlir/lib/Bindings/Python/mlir/runtime/np_to_memref.py diff --git a/mlir/lib/Bindings/Python/mlir/runtime/__init__.py b/mlir/lib/Bindings/Python/mlir/runtime/__init__.py new file mode 100644 index 000000000..8a28fd935 --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/runtime/__init__.py @@ -0,0 +1 @@ +from .np_to_memref import * diff --git a/mlir/lib/Bindings/Python/mlir/runtime/np_to_memref.py b/mlir/lib/Bindings/Python/mlir/runtime/np_to_memref.py new file mode 100644 index 000000000..43ef95435 --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/runtime/np_to_memref.py @@ -0,0 +1,119 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +# This file contains functions to convert between Memrefs and NumPy arrays and vice-versa. + +import numpy as np +import ctypes + + +def make_nd_memref_descriptor(rank, dtype): + class MemRefDescriptor(ctypes.Structure): + """ + Build an empty descriptor for the given rank/dtype, where rank>0. + """ + + _fields_ = [ + ("allocated", ctypes.c_longlong), + ("aligned", ctypes.POINTER(dtype)), + ("offset", ctypes.c_longlong), + ("shape", ctypes.c_longlong * rank), + ("strides", ctypes.c_longlong * rank), + ] + + return MemRefDescriptor + + +def make_zero_d_memref_descriptor(dtype): + class MemRefDescriptor(ctypes.Structure): + """ + Build an empty descriptor for the given dtype, where rank=0. 
+ """ + + _fields_ = [ + ("allocated", ctypes.c_longlong), + ("aligned", ctypes.POINTER(dtype)), + ("offset", ctypes.c_longlong), + ] + + return MemRefDescriptor + + +class UnrankedMemRefDescriptor(ctypes.Structure): + """ Creates a ctype struct for memref descriptor""" + + _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)] + + +def get_ranked_memref_descriptor(nparray): + """ + Return a ranked memref descriptor for the given numpy array. + """ + if nparray.ndim == 0: + x = make_zero_d_memref_descriptor(np.ctypeslib.as_ctypes_type(nparray.dtype))() + x.allocated = nparray.ctypes.data + x.aligned = nparray.ctypes.data_as( + ctypes.POINTER(np.ctypeslib.as_ctypes_type(nparray.dtype)) + ) + x.offset = ctypes.c_longlong(0) + return x + + x = make_nd_memref_descriptor( + nparray.ndim, np.ctypeslib.as_ctypes_type(nparray.dtype) + )() + x.allocated = nparray.ctypes.data + x.aligned = nparray.ctypes.data_as( + ctypes.POINTER(np.ctypeslib.as_ctypes_type(nparray.dtype)) + ) + x.offset = ctypes.c_longlong(0) + x.shape = nparray.ctypes.shape + + # Numpy uses byte quantities to express strides, MLIR OTOH uses the + # torch abstraction which specifies strides in terms of elements. + strides_ctype_t = ctypes.c_longlong * nparray.ndim + x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides]) + return x + + +def get_unranked_memref_descriptor(nparray): + """ + Return a generic/unranked memref descriptor for the given numpy array. + """ + d = UnrankedMemRefDescriptor() + d.rank = nparray.ndim + x = get_ranked_memref_descriptor(nparray) + d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p) + return d + + +def unranked_memref_to_numpy(unranked_memref, np_dtype): + """ + Converts unranked memrefs to numpy arrays. 
+ """ + descriptor = make_nd_memref_descriptor( + unranked_memref[0].rank, np.ctypeslib.as_ctypes_type(np_dtype) + ) + val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor)) + np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape) + strided_arr = np.lib.stride_tricks.as_strided( + np_arr, + np.ctypeslib.as_array(val[0].shape), + np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize, + ) + return strided_arr + + +def ranked_memref_to_numpy(ranked_memref): + """ + Converts ranked memrefs to numpy arrays. + """ + np_arr = np.ctypeslib.as_array( + ranked_memref[0].aligned, shape=ranked_memref[0].shape + ) + strided_arr = np.lib.stride_tricks.as_strided( + np_arr, + np.ctypeslib.as_array(ranked_memref[0].shape), + np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize, + ) + return strided_arr From f2ded9d1c4b06182224c2ac3bc64e8ae676fa56f Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Tue, 13 Apr 2021 06:25:47 +0000 Subject: [PATCH 020/915] [mlir][Python][Linalg] Add support for captures in body builder. When Linalg named ops support was added, captures were omitted from the body builder. This revision adds support for captures which allows us to write FillOp in a more idiomatic fashion using the _linalg_ops_ext mixin support. This raises an issue in the generation of `_linalg_ops_gen.py` where ``` @property def result(self): return self.operation.results[0] if len(self.operation.results) > 1 else None ```. The condition should be `== 1`. This will be fixed in a separate commit. 
Differential Revision: https://reviews.llvm.org/D100363 --- mlir/include/mlir-c/Dialect/Linalg.h | 8 ++-- mlir/lib/Bindings/Python/DialectLinalg.cpp | 11 +++-- .../Python/mlir/dialects/_linalg_ops_ext.py | 41 +++++++++++++++++++ mlir/lib/CAPI/Dialect/Linalg.cpp | 16 ++++++-- 4 files changed, 67 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h index 06f15f062..6e20eec16 100644 --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -1,11 +1,11 @@ -//===-- mlir-c/Dialect/Linalg.h - C API for Linalg dialect --------*- C -*-===// +//===-- mlir-c/Dialect/Linalg.h - C API for Linalg dialect -------*- C -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM // Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -//===----------------------------------------------------------------------===// +//===---------------------------------------------------------------------===// #ifndef MLIR_C_DIALECT_LINALG_H #define MLIR_C_DIALECT_LINALG_H @@ -18,9 +18,11 @@ extern "C" { #endif /// Apply the special region builder for the builtin named Linalg op. +/// The list of `capture` MlirValue is passed as-is to the region builder. /// Assert that `op` is a builtin named Linalg op. 
MLIR_CAPI_EXPORTED void -mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op); +mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op, + intptr_t n, MlirValue const *mlirCaptures); MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg); diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp index e4ef69411..849a0039a 100644 --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -22,10 +22,15 @@ namespace python { void populateDialectLinalgSubmodule(py::module &m) { m.def( "fill_builtin_region", - [](PyDialectDescriptor &dialect, PyOperation &op) { - return mlirLinalgFillBuiltinNamedOpRegion(dialect.get(), op.get()); + [](PyDialectDescriptor &dialect, PyOperation &op, py::list captures) { + llvm::SmallVector mlirOperands; + mlirOperands.reserve(captures.size()); + for (auto v : captures) + mlirOperands.push_back(py::cast(v)->get()); + mlirLinalgFillBuiltinNamedOpRegion( + dialect.get(), op.get(), mlirOperands.size(), mlirOperands.data()); }, - py::arg("dialect"), py::arg("op"), + py::arg("dialect"), py::arg("op"), py::arg("captures") = py::list(), "Fill the region for `op`, which is assumed to be a builtin named Linalg " "op."); } diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py b/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py index d787943d1..4714e69b3 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py @@ -5,6 +5,47 @@ from typing import Optional, Sequence, Union from ..ir import * from ._ods_common import get_default_loc_context +# TODO: resolve name collision for Linalg functionality that is injected inside +# the _mlir.dialects.linalg directly via pybind. 
+from _mlir.dialects.linalg import fill_builtin_region + + +def isa(cls : Type, ty : Type): + try: + cls(ty) + return True + except ValueError: + return False + + +class FillOp: + """Extends the linalg.fill op.""" + + def __init__(self, + output: Value, + value: Value, + *, + loc=None, + ip=None): + results = [] + if isa(RankedTensorType, output.type): + results = [output.type] + op = self.build_generic(results=results, + operands=[output, value], + attributes=None, + loc=loc, + ip=ip) + OpView.__init__(self, op) + linalgDialect = Context.current.get_dialect_descriptor("linalg") + fill_builtin_region(linalgDialect, self.operation, [value]) + # TODO: self.result is None. When len(results) == 1 we expect it to be + # results[0] as per _linalg_ops_gen.py. This seems like an orthogonal bug + # in the generator of _linalg_ops_gen.py where we have: + # ``` + # def result(self): + # return self.operation.results[0] \ + # if len(self.operation.results) > 1 else None + # ``` class InitTensorOp: diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 1c50aa612..6f6e090d7 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -8,6 +8,7 @@ #include "mlir-c/Dialect/Linalg.h" #include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" using namespace mlir; @@ -16,8 +17,14 @@ using namespace mlir::linalg; /// Apply the special region builder for the builtin named Linalg op. /// Assert that `op` is a builtin named Linalg op. 
void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, - MlirOperation mlirOp) { + MlirOperation mlirOp, intptr_t n, + MlirValue const *mlirCaptures) { Operation *op = unwrap(mlirOp); + SmallVector captures; + captures.reserve(n); + for (unsigned idx = 0; idx < n; ++idx) + captures.push_back(unwrap(mlirCaptures[idx])); + LinalgDialect::RegionBuilderFunType fun = static_cast(unwrap(linalgDialect)) ->getRegionBuilder(op->getName().getStringRef()); @@ -25,15 +32,18 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, assert(op->getNumRegions() == 1 && "Expected Linalg op with 1 region"); assert(op->getRegion(0).getBlocks().empty() && "Expected Linalg op with 0 blocks"); + SmallVector argTypes; auto linalgOp = cast(op); for (auto t : linalgOp.getShapedOperandTypes()) argTypes.push_back(getElementTypeOrSelf(t)); + OpBuilder b(op->getContext()); Region ®ion = op->getRegion(0); Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes); - // TODO: allow captures. 
- fun(*body, ValueRange{}); + b.setInsertionPointToStart(body); + mlir::edsc::ScopedContext scope(b, op->getLoc()); + fun(*body, captures); } MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect) From d38cd4cc3ddbe1668cc9352d37d6b72ab5460b96 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 16 Apr 2021 12:54:43 +0000 Subject: [PATCH 021/915] [mlir][python] Add simple debugging and printing helpers Differential Revision: https://reviews.llvm.org/D100643 --- mlir/include/mlir-c/IR.h | 11 +++++++++++ mlir/include/mlir-c/Pass.h | 8 ++++++++ mlir/lib/Bindings/Python/IRCore.cpp | 9 +++++++++ mlir/lib/Bindings/Python/Pass.cpp | 12 ++++++++++++ mlir/lib/Bindings/Python/mlir/ir.py | 5 +++++ mlir/lib/CAPI/IR/IR.cpp | 12 ++++++++++++ mlir/lib/CAPI/IR/Pass.cpp | 8 ++++++++ 7 files changed, 65 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 048bd4667..c64ec174d 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -75,6 +75,13 @@ struct MlirNamedAttribute { }; typedef struct MlirNamedAttribute MlirNamedAttribute; +//===----------------------------------------------------------------------===// +// Global API. +//===----------------------------------------------------------------------===// + +/// Set the global debugging flag. +MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable); + //===----------------------------------------------------------------------===// // Context API. //===----------------------------------------------------------------------===// @@ -119,6 +126,10 @@ mlirContextGetNumLoadedDialects(MlirContext context); MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name); +/// Set threading mode (must be set to false to print-ir-after-all). +MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, + bool enable); + /// Returns whether the given fully-qualified operation (i.e. 
/// 'dialect.operation') is registered with the context. This will return true /// if the dialect is loaded and the operation is registered within the diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index 9669a53cd..d8b216812 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -65,6 +65,14 @@ mlirPassManagerGetAsOpPassManager(MlirPassManager passManager); MLIR_CAPI_EXPORTED MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager, MlirModule module); +/// Enable print-ir-after-all. +MLIR_CAPI_EXPORTED void +mlirPassManagerEnableIRPrinting(MlirPassManager passManager); + +/// Enable / disable verify-each. +MLIR_CAPI_EXPORTED void +mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable); + /// Nest an OpPassManager under the top-level PassManager, the nested /// passmanager will only run on operations matching the provided name. /// The returned OpPassManager will be destroyed when the parent is destroyed. diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 7a7bae92c..0f3a1c0dc 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1712,6 +1712,11 @@ class PyOpAttributeMap { //------------------------------------------------------------------------------ void mlir::python::populateIRCore(py::module &m) { + //---------------------------------------------------------------------------- + // Mapping of Global functions + //---------------------------------------------------------------------------- + m.def("_enable_debug", [](bool enable) { mlirEnableGlobalDebug(enable); }); + //---------------------------------------------------------------------------- // Mapping of MlirContext //---------------------------------------------------------------------------- @@ -1766,6 +1771,10 @@ void mlir::python::populateIRCore(py::module &m) { [](PyMlirContext &self, bool value) { mlirContextSetAllowUnregisteredDialects(self.get(), 
value); }) + .def("enable_multithreading", + [](PyMlirContext &self, bool enable) { + mlirContextEnableMultithreading(self.get(), enable); + }) .def("is_registered_operation", [](PyMlirContext &self, std::string &name) { return mlirContextIsRegisteredOperation( diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 0e2f5bafb..f2433573b 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -68,6 +68,18 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule) .def("_testing_release", &PyPassManager::release, "Releases (leaks) the backing pass manager (testing)") + .def( + "enable_ir_printing", + [](PyPassManager &passManager) { + mlirPassManagerEnableIRPrinting(passManager.get()); + }, + "Enable print-ir-after-all.") + .def( + "enable_verifier", + [](PyPassManager &passManager, bool enable) { + mlirPassManagerEnableVerifier(passManager.get(), enable); + }, + "Enable / disable verify-each.") .def_static( "parse", [](const std::string pipeline, DefaultingPyMlirContext context) { diff --git a/mlir/lib/Bindings/Python/mlir/ir.py b/mlir/lib/Bindings/Python/mlir/ir.py index e5ba1bdb0..e2c785c50 100644 --- a/mlir/lib/Bindings/Python/mlir/ir.py +++ b/mlir/lib/Bindings/Python/mlir/ir.py @@ -6,3 +6,8 @@ from ._cext_loader import _reexport_cext _reexport_cext("ir", __name__) del _reexport_cext + +# Extra functions that are not visible to _reexport_cext. +# TODO: is this really necessary? 
+from _mlir.ir import _enable_debug +_enable_debug = _enable_debug \ No newline at end of file diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 14cde9633..616caae1e 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -21,8 +21,16 @@ #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Parser.h" +#include "llvm/Support/Debug.h" + using namespace mlir; +//===----------------------------------------------------------------------===// +// Global API. +//===----------------------------------------------------------------------===// + +void mlirEnableGlobalDebug(bool enable) { ::llvm::DebugFlag = true; } + //===----------------------------------------------------------------------===// // Context API. //===----------------------------------------------------------------------===// @@ -64,6 +72,10 @@ bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name) { return unwrap(context)->isOperationRegistered(unwrap(name)); } +void mlirContextEnableMultithreading(MlirContext context, bool enable) { + return unwrap(context)->enableMultithreading(enable); +} + //===----------------------------------------------------------------------===// // Dialect API. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index b3685ddf4..4bfc9d013 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -38,6 +38,14 @@ MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager, return wrap(unwrap(passManager)->run(unwrap(module))); } +void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) { + return unwrap(passManager)->enableIRPrinting(); +} + +void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) { + unwrap(passManager)->enableVerifier(enable); +} + MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager, MlirStringRef operationName) { return wrap(&unwrap(passManager)->nest(unwrap(operationName))); From 0e2f2f0907f8b2b3491eac2240a366b0a080727c Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 19 Apr 2021 13:37:01 +0200 Subject: [PATCH 022/915] [mlir] Improve debug flag management in Python bindings Expose the debug flag as a readable and assignable property of a dedicated class instead of a write-only function. Actually test the fact of setting the flag. Move test to a dedicated file, it has zero relation to context_managers.py where it was added. Arguably, it should be promoted from mlir.ir to mlir module, but we are not re-exporting the latter and this functionality is purposefully hidden so can stay in IR for now. Drop unnecessary export code. Refactor C API and put Debug into a separate library, fix it to actually set the flag to the given value. 
Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D100757 --- mlir/include/mlir-c/Debug.h | 30 +++++++++++++++++++++++++++++ mlir/include/mlir-c/IR.h | 7 ------- mlir/lib/Bindings/Python/IRCore.cpp | 27 +++++++++++++++++++------- mlir/lib/Bindings/Python/mlir/ir.py | 4 ---- mlir/lib/CAPI/CMakeLists.txt | 1 + mlir/lib/CAPI/Debug/CMakeLists.txt | 6 ++++++ mlir/lib/CAPI/Debug/Debug.cpp | 18 +++++++++++++++++ mlir/lib/CAPI/IR/IR.cpp | 6 ------ 8 files changed, 75 insertions(+), 24 deletions(-) create mode 100644 mlir/include/mlir-c/Debug.h create mode 100644 mlir/lib/CAPI/Debug/CMakeLists.txt create mode 100644 mlir/lib/CAPI/Debug/Debug.cpp diff --git a/mlir/include/mlir-c/Debug.h b/mlir/include/mlir-c/Debug.h new file mode 100644 index 000000000..2502f2fa2 --- /dev/null +++ b/mlir/include/mlir-c/Debug.h @@ -0,0 +1,30 @@ +//===-- mlir-c/Debug.h - C API for MLIR/LLVM debugging functions --*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Support.h" + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/// Sets the global debugging flag. +MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable); + +/// Retuns `true` if the global debugging flag is set, false otherwise. 
+MLIR_CAPI_EXPORTED bool mlirIsGlobalDebugEnabled(); + +#ifdef __cplusplus +} +#endif + +#ifndef MLIR_C_DEBUG_H +#define MLIR_C_DEBUG_H +#endif // MLIR_C_DEBUG_H diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index c64ec174d..8e92510ae 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -75,13 +75,6 @@ struct MlirNamedAttribute { }; typedef struct MlirNamedAttribute MlirNamedAttribute; -//===----------------------------------------------------------------------===// -// Global API. -//===----------------------------------------------------------------------===// - -/// Set the global debugging flag. -MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable); - //===----------------------------------------------------------------------===// // Context API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 0f3a1c0dc..a2655d9c4 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -14,6 +14,7 @@ #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" +#include "mlir-c/Debug.h" #include "mlir-c/Registration.h" #include "llvm/ADT/SmallVector.h" #include @@ -129,7 +130,7 @@ equivalent to printing the operation that produced it. // Utilities. //------------------------------------------------------------------------------ -// Helper for creating an @classmethod. +/// Helper for creating an @classmethod. template py::object classmethod(Func f, Args... args) { py::object cf = py::cpp_function(f, args...); @@ -153,6 +154,20 @@ static MlirStringRef toMlirStringRef(const std::string &s) { return mlirStringRefCreate(s.data(), s.size()); } +/// Wrapper for the global LLVM debugging flag. 
+struct PyGlobalDebugFlag { + static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } + + static bool get(py::object) { return mlirIsGlobalDebugEnabled(); } + + static void bind(py::module &m) { + // Debug flags. + py::class_(m, "_GlobalDebug") + .def_property_static("flag", &PyGlobalDebugFlag::get, + &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); + } +}; + //------------------------------------------------------------------------------ // Collections. //------------------------------------------------------------------------------ @@ -1713,12 +1728,7 @@ class PyOpAttributeMap { void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- - // Mapping of Global functions - //---------------------------------------------------------------------------- - m.def("_enable_debug", [](bool enable) { mlirEnableGlobalDebug(enable); }); - - //---------------------------------------------------------------------------- - // Mapping of MlirContext + // Mapping of MlirContext. //---------------------------------------------------------------------------- py::class_(m, "Context") .def(py::init<>(&PyMlirContext::createNewContextForInit)) @@ -2384,4 +2394,7 @@ void mlir::python::populateIRCore(py::module &m) { PyOpResultList::bind(m); PyRegionIterator::bind(m); PyRegionList::bind(m); + + // Debug bindings. + PyGlobalDebugFlag::bind(m); } diff --git a/mlir/lib/Bindings/Python/mlir/ir.py b/mlir/lib/Bindings/Python/mlir/ir.py index e2c785c50..2b420511d 100644 --- a/mlir/lib/Bindings/Python/mlir/ir.py +++ b/mlir/lib/Bindings/Python/mlir/ir.py @@ -7,7 +7,3 @@ _reexport_cext("ir", __name__) del _reexport_cext -# Extra functions that are not visible to _reexport_cext. -# TODO: is this really necessary? 
-from _mlir.ir import _enable_debug -_enable_debug = _enable_debug \ No newline at end of file diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt index ba58d99a7..db77cc1f6 100644 --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(Debug) add_subdirectory(Dialect) add_subdirectory(Conversion) add_subdirectory(ExecutionEngine) diff --git a/mlir/lib/CAPI/Debug/CMakeLists.txt b/mlir/lib/CAPI/Debug/CMakeLists.txt new file mode 100644 index 000000000..fdffe304d --- /dev/null +++ b/mlir/lib/CAPI/Debug/CMakeLists.txt @@ -0,0 +1,6 @@ +add_mlir_public_c_api_library(MLIRCAPIDebug + Debug.cpp + + LINK_LIBS PUBLIC + MLIRSupport +) diff --git a/mlir/lib/CAPI/Debug/Debug.cpp b/mlir/lib/CAPI/Debug/Debug.cpp new file mode 100644 index 000000000..288ecd601 --- /dev/null +++ b/mlir/lib/CAPI/Debug/Debug.cpp @@ -0,0 +1,18 @@ +//===- Debug.cpp - C Interface for MLIR/LLVM Debugging Functions ----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Debug.h" +#include "mlir-c/Support.h" + +#include "mlir/CAPI/Support.h" + +#include "llvm/Support/Debug.h" + +void mlirEnableGlobalDebug(bool enable) { llvm::DebugFlag = enable; } + +bool mlirIsGlobalDebugEnabled() { return llvm::DebugFlag; } diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 616caae1e..000b8f565 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -25,12 +25,6 @@ using namespace mlir; -//===----------------------------------------------------------------------===// -// Global API. 
-//===----------------------------------------------------------------------===// - -void mlirEnableGlobalDebug(bool enable) { ::llvm::DebugFlag = true; } - //===----------------------------------------------------------------------===// // Context API. //===----------------------------------------------------------------------===// From 984c383147b0a558906e7826ced2c492ef276f33 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Mon, 19 Apr 2021 19:30:29 +0000 Subject: [PATCH 023/915] [mlir][python] ExecutionEngine can dump to object file Differential Revision: https://reviews.llvm.org/D100786 --- mlir/include/mlir-c/ExecutionEngine.h | 5 +++++ mlir/lib/Bindings/Python/ExecutionEngine.cpp | 10 +++++++++- mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp | 5 +++++ 3 files changed, 19 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h index 5210f108e..4a5d6ad9f 100644 --- a/mlir/include/mlir-c/ExecutionEngine.h +++ b/mlir/include/mlir-c/ExecutionEngine.h @@ -67,6 +67,11 @@ MLIR_CAPI_EXPORTED void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit, MlirStringRef name, void *sym); +/// Dump as an object in `fileName`. 
+MLIR_CAPI_EXPORTED void +mlirExecutionEngineDumpToObjectFile(MlirExecutionEngine jit, + MlirStringRef fileName); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/Bindings/Python/ExecutionEngine.cpp b/mlir/lib/Bindings/Python/ExecutionEngine.cpp index 0e8ae8b38..b5c8dde75 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngine.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngine.cpp @@ -93,5 +93,13 @@ void mlir::python::populateExecutionEngineSubmodule(py::module &m) { mlirStringRefCreate(name.c_str(), name.size()), reinterpret_cast(sym)); }, - "Lookup function `func` in the ExecutionEngine."); + "Lookup function `func` in the ExecutionEngine.") + .def( + "dump_to_object_file", + [](PyExecutionEngine &executionEngine, const std::string &fileName) { + mlirExecutionEngineDumpToObjectFile( + executionEngine.get(), + mlirStringRefCreate(fileName.c_str(), fileName.size())); + }, + "Dump ExecutionEngine to an object file."); } diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index 345eac219..36f24ed88 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -66,3 +66,8 @@ extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit, return symbolMap; }); } + +extern "C" void mlirExecutionEngineDumpToObjectFile(MlirExecutionEngine jit, + MlirStringRef name) { + unwrap(jit)->dumpToObjectFile(unwrap(name)); +} From bccbd3bb76f7f0f7a1652a5a7942928e60579065 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 22 Apr 2021 15:52:01 +0200 Subject: [PATCH 024/915] [mlir] Move PyConcreteAttribute to header. NFC. This allows out-of-tree users to derive PyConcreteAttribute to bind custom attributes. 
Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D101063 --- mlir/lib/Bindings/Python/IRAttributes.cpp | 40 ----------------------- mlir/lib/Bindings/Python/IRModule.h | 40 +++++++++++++++++++++++ 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index b5e3c5c9c..0af762d93 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -27,46 +27,6 @@ static MlirStringRef toMlirStringRef(const std::string &s) { return mlirStringRefCreate(s.data(), s.size()); } -/// CRTP base classes for Python attributes that subclass Attribute and should -/// be castable from it (i.e. via something like StringAttr(attr)). -/// By default, attribute class hierarchies are one level deep (i.e. a -/// concrete attribute class extends PyAttribute); however, intermediate -/// python-visible base classes can be modeled by specifying a BaseTy. 
-template -class PyConcreteAttribute : public BaseTy { -public: - // Derived classes must define statics for: - // IsAFunctionTy isaFunction - // const char *pyClassName - using ClassTy = py::class_; - using IsAFunctionTy = bool (*)(MlirAttribute); - - PyConcreteAttribute() = default; - PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) - : BaseTy(std::move(contextRef), attr) {} - PyConcreteAttribute(PyAttribute &orig) - : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} - - static MlirAttribute castFrom(PyAttribute &orig) { - if (!DerivedTy::isaFunction(orig)) { - auto origRepr = py::repr(py::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, Twine("Cannot cast attribute to ") + - DerivedTy::pyClassName + - " (from " + origRepr + ")"); - } - return orig; - } - - static void bind(py::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName, py::buffer_protocol()); - cls.def(py::init(), py::keep_alive<0, 1>()); - DerivedTy::bindDerived(cls); - } - - /// Implemented by derived classes to add methods to the Python subclass. - static void bindDerived(ClassTy &m) {} -}; - class PyAffineMapAttribute : public PyConcreteAttribute { public: static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 861673abc..f3f5ee5ed 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -642,6 +642,46 @@ class PyNamedAttribute { std::unique_ptr ownedName; }; +/// CRTP base classes for Python attributes that subclass Attribute and should +/// be castable from it (i.e. via something like StringAttr(attr)). +/// By default, attribute class hierarchies are one level deep (i.e. a +/// concrete attribute class extends PyAttribute); however, intermediate +/// python-visible base classes can be modeled by specifying a BaseTy. 
+template +class PyConcreteAttribute : public BaseTy { +public: + // Derived classes must define statics for: + // IsAFunctionTy isaFunction + // const char *pyClassName + using ClassTy = pybind11::class_; + using IsAFunctionTy = bool (*)(MlirAttribute); + + PyConcreteAttribute() = default; + PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) + : BaseTy(std::move(contextRef), attr) {} + PyConcreteAttribute(PyAttribute &orig) + : PyConcreteAttribute(orig.getContext(), castFrom(orig)) {} + + static MlirAttribute castFrom(PyAttribute &orig) { + if (!DerivedTy::isaFunction(orig)) { + auto origRepr = pybind11::repr(pybind11::cast(orig)).cast(); + throw SetPyError(PyExc_ValueError, + llvm::Twine("Cannot cast attribute to ") + + DerivedTy::pyClassName + " (from " + origRepr + ")"); + } + return orig; + } + + static void bind(pybind11::module &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol()); + cls.def(pybind11::init(), pybind11::keep_alive<0, 1>()); + DerivedTy::bindDerived(cls); + } + + /// Implemented by derived classes to add methods to the Python subclass. + static void bindDerived(ClassTy &m) {} +}; + /// Wrapper around the generic MlirType. /// The lifetime of a type is bound by the PyContext that created it. class PyType : public BaseContextObject { From 8b0065e21e750ccbccc39e87e99c5491504a66ec Mon Sep 17 00:00:00 2001 From: Mike Urbach Date: Thu, 22 Apr 2021 00:07:30 -0600 Subject: [PATCH 025/915] [MLIR][Python] Add capsule methods for pybind11 to PyValue. Add the `getCapsule()` and `createFromCapsule()` methods to the PyValue class, as well as the necessary interoperability. 
Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D101090 --- mlir/include/mlir-c/Bindings/Python/Interop.h | 20 ++++++++++++++++ mlir/lib/Bindings/Python/IRCore.cpp | 24 +++++++++++++++++++ mlir/lib/Bindings/Python/IRModule.h | 7 ++++++ 3 files changed, 51 insertions(+) diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h index d85315946..882f73d84 100644 --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -42,6 +42,7 @@ #define MLIR_PYTHON_CAPSULE_OPERATION "mlir.ir.Operation._CAPIPtr" #define MLIR_PYTHON_CAPSULE_TYPE "mlir.ir.Type._CAPIPtr" #define MLIR_PYTHON_CAPSULE_PASS_MANAGER "mlir.passmanager.PassManager._CAPIPtr" +#define MLIR_PYTHON_CAPSULE_VALUE "mlir.ir.Value._CAPIPtr" /** Attribute on MLIR Python objects that expose their C-API pointer. * This will be a type-specific capsule created as per one of the helpers @@ -285,6 +286,25 @@ mlirPythonCapsuleToExecutionEngine(PyObject *capsule) { return jit; } +/** Creates a capsule object encapsulating the raw C-API MlirValue. + * The returned capsule does not extend or affect ownership of any Python + * objects that reference the operation in any way. + */ +static inline PyObject *mlirPythonValueToCapsule(MlirValue value) { + return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(value), + MLIR_PYTHON_CAPSULE_VALUE, NULL); +} + +/** Extracts an MlirValue from a capsule as produced from + * mlirPythonValueToCapsule. If the capsule is not of the right type, then a + * null type is returned (as checked via mlirValueIsNull). In such a case, the + * Python APIs will have already set an error. 
*/ +static inline MlirValue mlirPythonCapsuleToValue(PyObject *capsule) { + void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_VALUE); + MlirValue value = {ptr}; + return value; +} + #ifdef __cplusplus } #endif diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index a2655d9c4..b93786e05 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -15,6 +15,7 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Debug.h" +#include "mlir-c/IR.h" #include "mlir-c/Registration.h" #include "llvm/ADT/SmallVector.h" #include @@ -1467,6 +1468,27 @@ PyType PyType::createFromCapsule(py::object capsule) { // PyValue and subclases. //------------------------------------------------------------------------------ +pybind11::object PyValue::getCapsule() { + return py::reinterpret_steal(mlirPythonValueToCapsule(get())); +} + +PyValue PyValue::createFromCapsule(pybind11::object capsule) { + MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); + if (mlirValueIsNull(value)) + throw py::error_already_set(); + MlirOperation owner; + if (mlirValueIsAOpResult(value)) + owner = mlirOpResultGetOwner(value); + if (mlirValueIsABlockArgument(value)) + owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); + if (mlirOperationIsNull(owner)) + throw py::error_already_set(); + MlirContext ctx = mlirOperationGetContext(owner); + PyOperationRef ownerRef = + PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); + return PyValue(ownerRef, value); +} + namespace { /// CRTP base class for Python MLIR values that subclass Value and should be /// castable from it. The value hierarchy is one level deep and is not supposed @@ -2353,6 +2375,8 @@ void mlir::python::populateIRCore(py::module &m) { // Mapping of Value. 
//---------------------------------------------------------------------------- py::class_(m, "Value") + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) .def_property_readonly( "context", [](PyValue &self) { return self.getParentOperation()->getContext(); }, diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index f3f5ee5ed..ff3faeefd 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -721,6 +721,13 @@ class PyValue { void checkValid() { return parentOperation->checkValid(); } + /// Gets a capsule wrapping the void* within the MlirValue. + pybind11::object getCapsule(); + + /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of + /// the underlying MlirValue is still tied to the owning operation. + static PyValue createFromCapsule(pybind11::object capsule); + private: PyOperationRef parentOperation; MlirValue value; From 7ed37244d3c3535ce645a0b88d8231d1fa187770 Mon Sep 17 00:00:00 2001 From: Mike Urbach Date: Fri, 23 Apr 2021 20:27:43 -0600 Subject: [PATCH 026/915] [mlir] Support setting operand values in C and Python APIs. This adds `mlirOperationSetOperand` to the IR C API, similar to the function to get an operand. In the Python API, this adds `operands[index] = value` syntax, similar to the syntax to get an operand with `operands[index]`. 
Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D101398 --- mlir/include/mlir-c/IR.h | 4 ++++ mlir/lib/Bindings/Python/IRCore.cpp | 9 +++++++++ mlir/lib/Bindings/Python/PybindUtils.h | 17 +++++++++++------ mlir/lib/CAPI/IR/IR.cpp | 5 +++++ 4 files changed, 29 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 8e92510ae..1b243165c 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -366,6 +366,10 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumOperands(MlirOperation op); MLIR_CAPI_EXPORTED MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos); +/// Sets the `pos`-th operand of the operation. +MLIR_CAPI_EXPORTED void mlirOperationSetOperand(MlirOperation op, intptr_t pos, + MlirValue newValue); + /// Returns the number of results of the operation. MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumResults(MlirOperation op); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index b93786e05..0945753f9 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1640,6 +1640,15 @@ class PyOpOperandList : public Sliceable { return PyOpOperandList(operation, startIndex, length, step); } + void dunderSetItem(intptr_t index, PyValue value) { + index = wrapIndex(index); + mlirOperationSetOperand(operation->get(), index, value.get()); + } + + static void bindDerived(ClassTy &c) { + c.def("__setitem__", &PyOpOperandList::dunderSetItem); + } + private: PyOperationRef operation; }; diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h index 0cea24482..7a9b8ecb9 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -215,6 +215,16 @@ class Sliceable { protected: using ClassTy = pybind11::class_; + intptr_t wrapIndex(intptr_t index) { + if (index < 0) + index = length + index; + if (index < 0 || index >= length) { + throw 
python::SetPyError(PyExc_IndexError, + "attempt to access out of bounds"); + } + return index; + } + public: explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step) : startIndex(startIndex), length(length), step(step) { @@ -228,12 +238,7 @@ class Sliceable { /// by taking elements in inverse order. Throws if the index is out of bounds. ElementTy dunderGetItem(intptr_t index) { // Negative indices mean we count from the end. - if (index < 0) - index = length + index; - if (index < 0 || index >= length) { - throw python::SetPyError(PyExc_IndexError, - "attempt to access out of bounds"); - } + index = wrapIndex(index); // Compute the linear index given the current slice properties. int linearIndex = index * step + startIndex; diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 000b8f565..4e2183516 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -351,6 +351,11 @@ MlirValue mlirOperationGetOperand(MlirOperation op, intptr_t pos) { return wrap(unwrap(op)->getOperand(static_cast(pos))); } +void mlirOperationSetOperand(MlirOperation op, intptr_t pos, + MlirValue newValue) { + unwrap(op)->setOperand(static_cast(pos), unwrap(newValue)); +} + intptr_t mlirOperationGetNumResults(MlirOperation op) { return static_cast(unwrap(op)->getNumResults()); } From 988a2858b2e4fd3066c3e015cbf7278a85c8ae06 Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Wed, 28 Apr 2021 07:38:36 +0000 Subject: [PATCH 027/915] [mlir][Python][Linalg] Fixing typos (NFC). 
--- .../Python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 229458855..b52a0e2d6 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -10,7 +10,7 @@ def matmul(A=TensorDef(T1, S.M, S.K), B=TensorDef(T2, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True)): - """Performs a matrix multiplacation of two 2D inputs. + """Performs a matrix multiplication of two 2D inputs. Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. @@ -23,7 +23,7 @@ def matmul(A=TensorDef(T1, S.M, S.K), def batch_matmul(A=TensorDef(T1, Batch, S.M, S.K), B=TensorDef(T2, Batch, S.K, S.N), C=TensorDef(U, Batch, S.M, S.N, output=True)): - """Performs a batched matrix multiplacation of two 3D inputs. + """Performs a batched matrix multiplication of two 3D inputs. Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. @@ -49,7 +49,7 @@ def matvec(A=TensorDef(T1, S.M, S.N), def vecmat(y=TensorDef(T1, S.M), A=TensorDef(T2, S.M, S.N), x=TensorDef(U, S.N, output=True)): - """Performs a vector-matrix multiplacation. + """Performs a vector-matrix multiplication. Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. From 3e967c6b2c6cb25e5fd29564127e7d6895767f82 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Tue, 27 Apr 2021 19:57:56 +0000 Subject: [PATCH 028/915] [mlir][python] Add python support for async dialect and passes. since the `async` keyword is reserved in python, the dialect is called async_dialect. 
Differential Revision: https://reviews.llvm.org/D101447 --- mlir/include/mlir-c/Dialect/Async.h | 28 +++++++++++++++++++ mlir/lib/Bindings/Python/AsyncOps.td | 15 ++++++++++ mlir/lib/Bindings/Python/AsyncPasses.cpp | 22 +++++++++++++++ mlir/lib/Bindings/Python/CMakeLists.txt | 13 +++++++++ .../mlir/dialects/async_dialect/__init__.py | 5 ++++ .../dialects/async_dialect/passes/__init__.py | 6 ++++ mlir/lib/CAPI/Dialect/Async.cpp | 13 +++++++++ mlir/lib/CAPI/Dialect/AsyncPasses.cpp | 26 +++++++++++++++++ mlir/lib/CAPI/Dialect/CMakeLists.txt | 16 +++++++++++ 9 files changed, 144 insertions(+) create mode 100644 mlir/include/mlir-c/Dialect/Async.h create mode 100644 mlir/lib/Bindings/Python/AsyncOps.td create mode 100644 mlir/lib/Bindings/Python/AsyncPasses.cpp create mode 100644 mlir/lib/Bindings/Python/mlir/dialects/async_dialect/__init__.py create mode 100644 mlir/lib/Bindings/Python/mlir/dialects/async_dialect/passes/__init__.py create mode 100644 mlir/lib/CAPI/Dialect/Async.cpp create mode 100644 mlir/lib/CAPI/Dialect/AsyncPasses.cpp diff --git a/mlir/include/mlir-c/Dialect/Async.h b/mlir/include/mlir-c/Dialect/Async.h new file mode 100644 index 000000000..50b6413ef --- /dev/null +++ b/mlir/include/mlir-c/Dialect/Async.h @@ -0,0 +1,28 @@ +//===-- mlir-c/Dialect/Async.h - C API for Async dialect ---------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_ASYNC_H +#define MLIR_C_DIALECT_ASYNC_H + +#include "mlir-c/Registration.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Async, async); + +#ifdef __cplusplus +} +#endif + +#include "mlir/Dialect/Async/Passes.capi.h.inc" + +#endif // MLIR_C_DIALECT_ASYNC_H diff --git a/mlir/lib/Bindings/Python/AsyncOps.td b/mlir/lib/Bindings/Python/AsyncOps.td new file mode 100644 index 000000000..b65b9bafd --- /dev/null +++ b/mlir/lib/Bindings/Python/AsyncOps.td @@ -0,0 +1,15 @@ +//===-- AsyncOps.td - Entry point async_dialect bindings --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_ASYNC_OPS +#define PYTHON_BINDINGS_ASYNC_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/Async/IR/AsyncOps.td" + +#endif diff --git a/mlir/lib/Bindings/Python/AsyncPasses.cpp b/mlir/lib/Bindings/Python/AsyncPasses.cpp new file mode 100644 index 000000000..2b83ed40d --- /dev/null +++ b/mlir/lib/Bindings/Python/AsyncPasses.cpp @@ -0,0 +1,22 @@ +//===- AsyncPasses.cpp - Pybind module for the Async passes -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Async.h" + +#include + +// ----------------------------------------------------------------------------- +// Module initialization. +// ----------------------------------------------------------------------------- + +PYBIND11_MODULE(_mlirAsyncPasses, m) { + m.doc() = "MLIR Async Dialect Passes"; + + // Register all Async passes on load. + mlirRegisterAsyncPasses(); +} diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt index 39192cc54..eba4d2886 100644 --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/CMakeLists.txt @@ -31,6 +31,11 @@ endforeach() # Generate dialect-specific bindings. ################################################################################ +add_mlir_dialect_python_bindings(MLIRBindingsPythonAsyncOps + TD_FILE AsyncOps.td + DIALECT_NAME async_dialect) +add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonAsyncOps) + add_mlir_dialect_python_bindings(MLIRBindingsPythonBuiltinOps TD_FILE BuiltinOps.td DIALECT_NAME builtin) @@ -120,6 +125,14 @@ endif() add_subdirectory(Transforms) add_subdirectory(Conversions) +add_mlir_python_extension(MLIRAsyncPassesBindingsPythonExtension _mlirAsyncPasses + INSTALL_DIR + python + SOURCES + AsyncPasses.cpp +) +add_dependencies(MLIRBindingsPythonExtension MLIRAsyncPassesBindingsPythonExtension) + add_mlir_python_extension(MLIRLinalgPassesBindingsPythonExtension _mlirLinalgPasses INSTALL_DIR python diff --git a/mlir/lib/Bindings/Python/mlir/dialects/async_dialect/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/async_dialect/__init__.py new file mode 100644 index 000000000..dcf9d6cb2 --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/dialects/async_dialect/__init__.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._async_dialect_ops_gen import * diff --git a/mlir/lib/Bindings/Python/mlir/dialects/async_dialect/passes/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/async_dialect/passes/__init__.py new file mode 100644 index 000000000..88a7b539c --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/dialects/async_dialect/passes/__init__.py @@ -0,0 +1,6 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ...._cext_loader import _load_extension +_cextAsyncPasses = _load_extension("_mlirAsyncPasses") diff --git a/mlir/lib/CAPI/Dialect/Async.cpp b/mlir/lib/CAPI/Dialect/Async.cpp new file mode 100644 index 000000000..182cbf9df --- /dev/null +++ b/mlir/lib/CAPI/Dialect/Async.cpp @@ -0,0 +1,13 @@ +//===- Async.cpp - C Interface for Async dialect --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Async/IR/Async.h" +#include "mlir-c/Dialect/Async.h" +#include "mlir/CAPI/Registration.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Async, async, mlir::async::AsyncDialect) diff --git a/mlir/lib/CAPI/Dialect/AsyncPasses.cpp b/mlir/lib/CAPI/Dialect/AsyncPasses.cpp new file mode 100644 index 000000000..aa2074dcd --- /dev/null +++ b/mlir/lib/CAPI/Dialect/AsyncPasses.cpp @@ -0,0 +1,26 @@ +//===- AsyncPasses.cpp - C API for Async Dialect Passes -----------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/CAPI/Pass.h" +#include "mlir/Dialect/Async/Passes.h" +#include "mlir/Pass/Pass.h" + +// Must include the declarations as they carry important visibility attributes. +#include "mlir/Dialect/Async/Passes.capi.h.inc" + +using namespace mlir; + +#ifdef __cplusplus +extern "C" { +#endif + +#include "mlir/Dialect/Async/Passes.capi.cpp.inc" + +#ifdef __cplusplus +} +#endif diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index 41c659d6a..3f6265e8a 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -1,5 +1,7 @@ # TODO: Make the check source feature optional as an argument on *_add_library. set(LLVM_OPTIONAL_SOURCES + Async.cpp + AsyncPasses.cpp Linalg.cpp LinalgPasses.cpp SCF.cpp @@ -8,6 +10,20 @@ set(LLVM_OPTIONAL_SOURCES Tensor.cpp ) +add_mlir_public_c_api_library(MLIRCAPIAsync + Async.cpp + AsyncPasses.cpp + + DEPENDS + MLIRAsyncPassIncGen + + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRAsync + MLIRAsyncTransforms + MLIRPass +) + add_mlir_public_c_api_library(MLIRCAPILinalg Linalg.cpp LinalgPasses.cpp From 9cf1c9373283eae88f3db09181a609eca02df4e6 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 28 Apr 2021 13:22:34 +0000 Subject: [PATCH 029/915] [mlir][python] Add basic python support for GPU dialect and passes Differential Revision: https://reviews.llvm.org/D101449 --- mlir/include/mlir-c/Dialect/GPU.h | 28 +++++++++++++++++++ mlir/lib/Bindings/Python/CMakeLists.txt | 13 +++++++++ mlir/lib/Bindings/Python/GPUOps.td | 15 ++++++++++ mlir/lib/Bindings/Python/GPUPasses.cpp | 22 +++++++++++++++ .../Python/mlir/dialects/gpu/__init__.py | 5 ++++ .../mlir/dialects/gpu/passes/__init__.py | 6 ++++ mlir/lib/CAPI/Dialect/CMakeLists.txt | 15 ++++++++++ 
mlir/lib/CAPI/Dialect/GPU.cpp | 13 +++++++++ mlir/lib/CAPI/Dialect/GPUPasses.cpp | 26 +++++++++++++++++ 9 files changed, 143 insertions(+) create mode 100644 mlir/include/mlir-c/Dialect/GPU.h create mode 100644 mlir/lib/Bindings/Python/GPUOps.td create mode 100644 mlir/lib/Bindings/Python/GPUPasses.cpp create mode 100644 mlir/lib/Bindings/Python/mlir/dialects/gpu/__init__.py create mode 100644 mlir/lib/Bindings/Python/mlir/dialects/gpu/passes/__init__.py create mode 100644 mlir/lib/CAPI/Dialect/GPU.cpp create mode 100644 mlir/lib/CAPI/Dialect/GPUPasses.cpp diff --git a/mlir/include/mlir-c/Dialect/GPU.h b/mlir/include/mlir-c/Dialect/GPU.h new file mode 100644 index 000000000..e4797a7ee --- /dev/null +++ b/mlir/include/mlir-c/Dialect/GPU.h @@ -0,0 +1,28 @@ +//===-- mlir-c/Dialect/GPU.h - C API for GPU dialect -------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_GPU_H +#define MLIR_C_DIALECT_GPU_H + +#include "mlir-c/Registration.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(GPU, gpu); + +#ifdef __cplusplus +} +#endif + +#include "mlir/Dialect/GPU/Passes.capi.h.inc" + +#endif // MLIR_C_DIALECT_GPU_H diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt index eba4d2886..bbccea63c 100644 --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/CMakeLists.txt @@ -41,6 +41,11 @@ add_mlir_dialect_python_bindings(MLIRBindingsPythonBuiltinOps DIALECT_NAME builtin) add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonBuiltinOps) +add_mlir_dialect_python_bindings(MLIRBindingsPythonGPUOps + TD_FILE GPUOps.td + DIALECT_NAME gpu) 
+add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonGPUOps) + add_mlir_dialect_python_bindings(MLIRBindingsPythonLinalgOps TD_FILE LinalgOps.td DIALECT_NAME linalg @@ -133,6 +138,14 @@ add_mlir_python_extension(MLIRAsyncPassesBindingsPythonExtension _mlirAsyncPasse ) add_dependencies(MLIRBindingsPythonExtension MLIRAsyncPassesBindingsPythonExtension) +add_mlir_python_extension(MLIRGPUPassesBindingsPythonExtension _mlirGPUPasses + INSTALL_DIR + python + SOURCES + GPUPasses.cpp +) +add_dependencies(MLIRBindingsPythonExtension MLIRGPUPassesBindingsPythonExtension) + add_mlir_python_extension(MLIRLinalgPassesBindingsPythonExtension _mlirLinalgPasses INSTALL_DIR python diff --git a/mlir/lib/Bindings/Python/GPUOps.td b/mlir/lib/Bindings/Python/GPUOps.td new file mode 100644 index 000000000..bf0980f29 --- /dev/null +++ b/mlir/lib/Bindings/Python/GPUOps.td @@ -0,0 +1,15 @@ +//===-- GPUOps.td - Entry point GPU_dialect bindings ------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_GPU_OPS +#define PYTHON_BINDINGS_GPU_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/GPU/GPUOps.td" + +#endif diff --git a/mlir/lib/Bindings/Python/GPUPasses.cpp b/mlir/lib/Bindings/Python/GPUPasses.cpp new file mode 100644 index 000000000..cb623a11b --- /dev/null +++ b/mlir/lib/Bindings/Python/GPUPasses.cpp @@ -0,0 +1,22 @@ +//===- GPUPasses.cpp - Pybind module for the GPU passes ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#include "mlir-c/Dialect/GPU.h" + +#include + +// ----------------------------------------------------------------------------- +// Module initialization. +// ----------------------------------------------------------------------------- + +PYBIND11_MODULE(_mlirGPUPasses, m) { + m.doc() = "MLIR GPU Dialect Passes"; + + // Register all GPU passes on load. + mlirRegisterGPUPasses(); +} diff --git a/mlir/lib/Bindings/Python/mlir/dialects/gpu/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/gpu/__init__.py new file mode 100644 index 000000000..67bf7bd85 --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/dialects/gpu/__init__.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._gpu_ops_gen import * diff --git a/mlir/lib/Bindings/Python/mlir/dialects/gpu/passes/__init__.py b/mlir/lib/Bindings/Python/mlir/dialects/gpu/passes/__init__.py new file mode 100644 index 000000000..dd28e91a4 --- /dev/null +++ b/mlir/lib/Bindings/Python/mlir/dialects/gpu/passes/__init__.py @@ -0,0 +1,6 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ...._cext_loader import _load_extension +_cextGPUPasses = _load_extension("_mlirGPUPasses") diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index 3f6265e8a..dd9bd6f67 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -2,6 +2,8 @@ set(LLVM_OPTIONAL_SOURCES Async.cpp AsyncPasses.cpp + GPU.cpp + GPUPasses.cpp Linalg.cpp LinalgPasses.cpp SCF.cpp @@ -24,6 +26,19 @@ add_mlir_public_c_api_library(MLIRCAPIAsync MLIRPass ) +add_mlir_public_c_api_library(MLIRCAPIGPU + GPU.cpp + GPUPasses.cpp + + DEPENDS + MLIRGPUPassIncGen + + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRGPU + MLIRPass +) + add_mlir_public_c_api_library(MLIRCAPILinalg Linalg.cpp LinalgPasses.cpp diff --git a/mlir/lib/CAPI/Dialect/GPU.cpp b/mlir/lib/CAPI/Dialect/GPU.cpp new file mode 100644 index 000000000..0de2cfa33 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/GPU.cpp @@ -0,0 +1,13 @@ +//===- GPUc.cpp - C Interface for GPU dialect ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/GPU.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/GPU/GPUDialect.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(GPU, gpu, mlir::gpu::GPUDialect) diff --git a/mlir/lib/CAPI/Dialect/GPUPasses.cpp b/mlir/lib/CAPI/Dialect/GPUPasses.cpp new file mode 100644 index 000000000..4ec167f88 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/GPUPasses.cpp @@ -0,0 +1,26 @@ +//===- GPUPasses.cpp - C API for GPU Dialect Passes ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/CAPI/Pass.h" +#include "mlir/Dialect/GPU/Passes.h" +#include "mlir/Pass/Pass.h" + +// Must include the declarations as they carry important visibility attributes. +#include "mlir/Dialect/GPU/Passes.capi.h.inc" + +using namespace mlir; + +#ifdef __cplusplus +extern "C" { +#endif + +#include "mlir/Dialect/GPU/Passes.capi.cpp.inc" + +#ifdef __cplusplus +} +#endif From ea42e2f5b465986c2c114ac33879d8b1c43ec6bb Mon Sep 17 00:00:00 2001 From: Mike Urbach Date: Fri, 23 Apr 2021 20:32:54 -0600 Subject: [PATCH 030/915] [mlir][python] Update `PyOpResult.owner` to get the parent object. Previously, this API would return the PyObjectRef, rather than the underlying PyOperation. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D101416 --- mlir/lib/Bindings/Python/IRCore.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 0945753f9..781e9aed6 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1566,7 +1566,7 @@ class PyOpResult : public PyConcreteValue { mlirOperationEqual(self.getParentOperation()->get(), mlirOpResultGetOwner(self.get())) && "expected the owner of the value in Python to match that in the IR"); - return self.getParentOperation(); + return self.getParentOperation().getObject(); }); c.def_property_readonly("result_number", [](PyOpResult &self) { return mlirOpResultGetResultNumber(self.get()); From fd3b1b4797c13c8df260742e98b830d706fc0d15 Mon Sep 17 00:00:00 2001 From: John Demme Date: Wed, 28 Apr 2021 16:16:45 -0700 Subject: [PATCH 031/915] [mlir] Move PyConcreteType to header. NFC. This allows out-of-tree users to derive PyConcreteType to bind custom types. 
The Type version of https://reviews.llvm.org/D101063/new/ Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D101496 --- mlir/lib/Bindings/Python/IRModule.h | 43 ++++++++++++++++++++++++++++ mlir/lib/Bindings/Python/IRTypes.cpp | 43 ---------------------------- 2 files changed, 43 insertions(+), 43 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index ff3faeefd..292080d91 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -705,6 +705,49 @@ class PyType : public BaseContextObject { MlirType type; }; +/// CRTP base classes for Python types that subclass Type and should be +/// castable from it (i.e. via something like IntegerType(t)). +/// By default, type class hierarchies are one level deep (i.e. a +/// concrete type class extends PyType); however, intermediate python-visible +/// base classes can be modeled by specifying a BaseTy. +template +class PyConcreteType : public BaseTy { +public: + // Derived classes must define statics for: + // IsAFunctionTy isaFunction + // const char *pyClassName + using ClassTy = pybind11::class_; + using IsAFunctionTy = bool (*)(MlirType); + + PyConcreteType() = default; + PyConcreteType(PyMlirContextRef contextRef, MlirType t) + : BaseTy(std::move(contextRef), t) {} + PyConcreteType(PyType &orig) + : PyConcreteType(orig.getContext(), castFrom(orig)) {} + + static MlirType castFrom(PyType &orig) { + if (!DerivedTy::isaFunction(orig)) { + auto origRepr = pybind11::repr(pybind11::cast(orig)).cast(); + throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") + + DerivedTy::pyClassName + + " (from " + origRepr + ")"); + } + return orig; + } + + static void bind(pybind11::module &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName); + cls.def(pybind11::init(), pybind11::keep_alive<0, 1>()); + cls.def_static("isinstance", [](PyType &otherType) -> bool { + return DerivedTy::isaFunction(otherType); + }); + 
DerivedTy::bindDerived(cls); + } + + /// Implemented by derived classes to add methods to the Python subclass. + static void bindDerived(ClassTy &m) {} +}; + /// Wrapper around the generic MlirValue. /// Values are managed completely by the operation that resulted in their /// definition. For op result value, this is the operation that defines the diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 421df4dab..b6875c76e 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -28,49 +28,6 @@ static int mlirTypeIsAIntegerOrFloat(MlirType type) { mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type); } -/// CRTP base classes for Python types that subclass Type and should be -/// castable from it (i.e. via something like IntegerType(t)). -/// By default, type class hierarchies are one level deep (i.e. a -/// concrete type class extends PyType); however, intermediate python-visible -/// base classes can be modeled by specifying a BaseTy. 
-template -class PyConcreteType : public BaseTy { -public: - // Derived classes must define statics for: - // IsAFunctionTy isaFunction - // const char *pyClassName - using ClassTy = py::class_; - using IsAFunctionTy = bool (*)(MlirType); - - PyConcreteType() = default; - PyConcreteType(PyMlirContextRef contextRef, MlirType t) - : BaseTy(std::move(contextRef), t) {} - PyConcreteType(PyType &orig) - : PyConcreteType(orig.getContext(), castFrom(orig)) {} - - static MlirType castFrom(PyType &orig) { - if (!DerivedTy::isaFunction(orig)) { - auto origRepr = py::repr(py::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, Twine("Cannot cast type to ") + - DerivedTy::pyClassName + - " (from " + origRepr + ")"); - } - return orig; - } - - static void bind(py::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName); - cls.def(py::init(), py::keep_alive<0, 1>()); - cls.def_static("isinstance", [](PyType &otherType) -> bool { - return DerivedTy::isaFunction(otherType); - }); - DerivedTy::bindDerived(cls); - } - - /// Implemented by derived classes to add methods to the Python subclass. - static void bindDerived(ClassTy &m) {} -}; - class PyIntegerType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; From edbcad98f7744c835d37de6da188b1efb840e20f Mon Sep 17 00:00:00 2001 From: Mike Urbach Date: Fri, 23 Apr 2021 20:54:04 -0600 Subject: [PATCH 032/915] [mlir][python] Add `destroy` method to PyOperation. This adds a method to directly invoke `mlirOperationDestroy` on the MlirOperation wrapped by a PyOperation. 
Reviewed By: stellaraccident, mehdi_amini Differential Revision: https://reviews.llvm.org/D101422 --- mlir/lib/Bindings/Python/IRCore.cpp | 26 +++++++++++++++++++++++++- mlir/lib/Bindings/Python/IRModule.h | 4 ++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 781e9aed6..160e35b21 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -753,6 +753,9 @@ PyOperation::PyOperation(PyMlirContextRef contextRef, MlirOperation operation) : BaseContextObject(std::move(contextRef)), operation(operation) {} PyOperation::~PyOperation() { + // If the operation has already been invalidated there is nothing to do. + if (!valid) + return; auto &liveOperations = getContext()->liveOperations; assert(liveOperations.count(operation.ptr) == 1 && "destroying operation not in live map"); @@ -869,6 +872,7 @@ py::object PyOperationBase::getAsm(bool binary, } PyOperationRef PyOperation::getParentOperation() { + checkValid(); if (!isAttached()) throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); MlirOperation operation = mlirOperationGetParentOperation(get()); @@ -878,6 +882,7 @@ PyOperationRef PyOperation::getParentOperation() { } PyBlock PyOperation::getBlock() { + checkValid(); PyOperationRef parentOperation = getParentOperation(); MlirBlock block = mlirOperationGetBlock(get()); assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); @@ -885,6 +890,7 @@ PyBlock PyOperation::getBlock() { } py::object PyOperation::getCapsule() { + checkValid(); return py::reinterpret_steal(mlirPythonOperationToCapsule(get())); } @@ -1032,6 +1038,7 @@ py::object PyOperation::create( } py::object PyOperation::createOpView() { + checkValid(); MlirIdentifier ident = mlirOperationGetName(get()); MlirStringRef identStr = mlirIdentifierStr(ident); auto opViewClass = PyGlobals::get().lookupRawOpViewClass( @@ -1041,6 +1048,18 @@ py::object 
PyOperation::createOpView() { return py::cast(PyOpView(getRef().getObject())); } +void PyOperation::erase() { + checkValid(); + // TODO: Fix memory hazards when erasing a tree of operations for which a deep + // Python reference to a child operation is live. All children should also + // have their `valid` bit set to false. + auto &liveOperations = getContext()->liveOperations; + if (liveOperations.count(operation.ptr)) + liveOperations.erase(operation.ptr); + mlirOperationDestroy(operation); + valid = false; +} + //------------------------------------------------------------------------------ // PyOpView //------------------------------------------------------------------------------ @@ -2094,11 +2113,13 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("successors") = py::none(), py::arg("regions") = 0, py::arg("loc") = py::none(), py::arg("ip") = py::none(), kOperationCreateDocstring) + .def("erase", &PyOperation::erase) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) .def_property_readonly("name", [](PyOperation &self) { + self.checkValid(); MlirOperation operation = self.get(); MlirStringRef name = mlirIdentifierStr( mlirOperationGetName(operation)); @@ -2106,7 +2127,10 @@ void mlir::python::populateIRCore(py::module &m) { }) .def_property_readonly( "context", - [](PyOperation &self) { return self.getContext().getObject(); }, + [](PyOperation &self) { + self.checkValid(); + return self.getContext().getObject(); + }, "Context that owns the Operation") .def_property_readonly("opview", &PyOperation::createOpView); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 292080d91..79c480e94 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -473,6 +473,10 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// Creates an OpView suitable for this operation. 
pybind11::object createOpView(); + /// Erases the underlying MlirOperation, removes its pointer from the + /// parent context's live operations map, and sets the valid bit false. + void erase(); + private: PyOperation(PyMlirContextRef contextRef, MlirOperation operation); static PyOperationRef createInstance(PyMlirContextRef contextRef, From 88b6d902578364d0f5399f47fa4b41e99fbc984c Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Thu, 29 Apr 2021 06:45:34 +0000 Subject: [PATCH 033/915] [mlir][Python][Linalg] Adding const, capture, and index support to the OpDSL. The patch extends the OpDSL with support for: - Constant values - Capture scalar parameters - Access the iteration indices using the index operation - Provide predefined floating point and integer types. Up to now the patch only supports emitting the new nodes. The C++/yaml path is not fully implemented. The fill_rng_2d operation defined in emit_structured_generic.py makes use of the new DSL constructs. Differential Revision: https://reviews.llvm.org/D101364 --- .../mlir/dialects/linalg/opdsl/lang/affine.py | 2 - .../linalg/opdsl/lang/comprehension.py | 160 +++++++++++++----- .../mlir/dialects/linalg/opdsl/lang/config.py | 49 +++++- .../mlir/dialects/linalg/opdsl/lang/dsl.py | 10 +- .../dialects/linalg/opdsl/lang/emitter.py | 132 +++++++++++---- .../dialects/linalg/opdsl/lang/scalar_expr.py | 64 ++++++- .../mlir/dialects/linalg/opdsl/lang/types.py | 12 ++ 7 files changed, 349 insertions(+), 80 deletions(-) diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/affine.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/affine.py index 34a8d6d30..6db3bcfcc 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/affine.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/affine.py @@ -232,7 +232,6 @@ class DimDef(AffineExprDef): """ ALL_DIMS = dict() # type: Dict[str, "DimDef"] - dimname: str def __new__(cls, dimname: str): existing = 
"""Visits all tensor expressions reachable from this expression."""
identifies the captures and an index determines their position in the
Got: {repr(type_var)}") + self.owner = None # type: Optional["LinalgOpDef"] + self.type_var = type_var + self.capture_name = None # type: Optional[str] + self.registered_index = -1 # type: int + + def attach(self, index: int, capture_name: str, owner: "LinalgOpDef"): + if self.owner: + raise ValueError(f"CaptureDef already registered with op: {self}") + self.registered_index = index + self.capture_name = capture_name + self.owner = owner + + def to_scalar_expression(self) -> ScalarExpression: + return ScalarCapture(self.capture_name).expr() + + def __repr__(self): + return (f"{self.capture_name}:CaptureDef({repr(self.type_var)})") class Comprehension: """Represents a single comprehension.""" @@ -279,17 +326,52 @@ def to_scalar_expression(self) -> ScalarExpression: *[arg.to_scalar_expression() for arg in self.args ]).expr() - def visit_affine_exprs(self, callback): - for arg in self.args: - arg.visit_affine_exprs(callback) - - def collect_uses(self, uses: Set["TensorUse"]): + def visit_tensor_exprs(self, callback): + super().visit_tensor_exprs(callback) for arg in self.args: - arg.collect_uses(uses) + arg.visit_tensor_exprs(callback) def __repr__(self): return f"{repr(self.prim)}({', '.join(repr(a) for a in self.args)})" +class const(TensorExpression): + """Returns the given constant floating point or integer value.""" + + def __init__(self, type_var: TypeVar, value: Any): + if not isinstance(type_var, TypeVar): + raise ValueError(f"const requires a TypeVar. Got: {repr(type_var)}") + if not (isinstance(value, float) or isinstance(value, int)): + raise ValueError(f"const requires int or float. Got: {type(value)}") + self.type_var = type_var + self.value = value + + def to_scalar_expression(self) -> ScalarExpression: + return ScalarConst(self.type_var, self.value).expr() + + def __repr__(self): + return f"const({self.type_var}, {self.value})" + +class index(TensorExpression): + """Returns the iteration index for a given dimension name. 
+ + Resolves the given dimension name to obtain its position in the iteration + domain of the operation. + """ + + def __init__(self, dim : DimDef): + self.dim_def = dim + self.dim = -1 + + def resolve_dimension_name(self, affine_state: AffineBuildState): + self.dim = affine_state.get_dim(self.dim_def.dimname) + + def to_scalar_expression(self) -> ScalarExpression: + assert self.dim != -1, "Dimension name not resolved" + return ScalarIndex(self.dim).expr() + + def __repr__(self): + return f"index({repr(self.dim)})" + class cast(TensorExpression): """Casts the element type to a type (typically symbolic TypeVar).""" @@ -302,11 +384,9 @@ def to_scalar_expression(self) -> ScalarExpression: return ScalarSymbolicCast(self.to_type, self.operand.to_scalar_expression()).expr() - def visit_affine_exprs(self, callback): - self.operand.visit_affine_exprs(callback) - - def collect_uses(self, uses: Set["TensorUse"]): - self.operand.collect_uses(uses) + def visit_tensor_exprs(self, callback): + super().visit_tensor_exprs(callback) + self.operand.visit_tensor_exprs(callback) def __repr__(self): return f"cast({self.to_type}, {repr(self.operand)})" @@ -331,15 +411,9 @@ def to_scalar_expression(self) -> ScalarExpression: ] + [arg.to_scalar_expression() for arg in self.args] return ScalarApplyFn(self.reduce.operator.prim_name, *full_args).expr() - def visit_affine_exprs(self, callback): - for ind in self.reduce.reduce_dims: - ind.visit_affine_exprs(callback) - for arg in self.args: - arg.visit_affine_exprs(callback) - - def collect_uses(self, uses: Set["TensorUse"]): + def visit_tensor_exprs(self, callback): for arg in self.args: - arg.collect_uses(uses) + arg.visit_tensor_exprs(callback) def __repr__(self): return f"{repr(self.reduce)}({', '.join(repr(a) for a in self.args)})" @@ -385,6 +459,7 @@ def __init__(self, doc: Optional[str] = None): self.metadata = OpMetadataDef(name=name, cpp_class_name=cpp_class_name, doc=doc) self.registered_tensors = dict() # type: Dict[str, TensorDef] 
f"to {self.registered_captures[capture_name]}")
f"The dimension {index.dim_def.dimname} is not part of the iteration "
self.assignments = [ ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression()) @@ -186,6 +218,11 @@ def ordered_tensor_uses(self) -> Sequence[TensorUseConfig]: return sorted(self.uses.values(), key=lambda tuc: tuc.tensor_use.tensor_def.registered_index) + @property + def ordered_capture_args(self) -> Sequence[CaptureDefConfig]: + return sorted(self.capture_args.values(), + key=lambda cdc: cdc.capture_def.registered_index) + @property def ordered_dims(self) -> Sequence[Tuple[str, int]]: """Gets the ordered list of dim bindings (symbolic name, position). @@ -245,6 +282,12 @@ def add_use(self, tensor_use: TensorUse): use_config = TensorUseConfig(tensor_use, indexing_map) self.uses[tensor_use] = use_config + def add_capture_arg(self, capture_def: CaptureDef): + if capture_def in self.capture_args: + return + def_config = CaptureDefConfig(capture_def) + self.capture_args[capture_def] = def_config + def _normalize_affine_map(self, affine_map: _ir.AffineMap, with_dims: bool = True) -> _ir.AffineMap: @@ -258,6 +301,7 @@ def _normalize_affine_map(self, def to_yaml_custom_dict(self): self_dict = dict( args=self.ordered_tensor_args, + captures=self.ordered_capture_args, # TODO: Refactor the hierarchy internally when supporting more # than static (preserving this serialized form). 
indexing_maps=LinalgIndexingMapsConfig( @@ -272,6 +316,9 @@ def __repr__(self): lines.append("tensor_args=[") for def_config in self.ordered_tensor_args: lines.append(f" {repr(def_config)}") + lines.append("], capture_args=[") + for def_config in self.ordered_capture_args: + lines.append(f" {repr(def_config)}") lines.append("], indexing_maps=[") for m in self.indexing_maps: lines.append(f" {repr(m)}") diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py index 002ae51ba..428eadfe0 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -105,11 +105,15 @@ def linalg_structured_op(dsl_func=None, sig = inspect.signature(dsl_func) for param_name, param in sig.parameters.items(): param_default = param.default - if not isinstance(param_default, TensorDef): + if isinstance(param_default, TensorDef): + tc_model.add_tensor(param_name, param_default) + elif isinstance(param_default, CaptureDef): + tc_model.add_capture(param_name, param_default) + else: raise ValueError(f"@tc_def_op function parameters must be defaulted as " - f"TensorDef(...): Found {param_name}: {param_default}") + f"TensorDef(...) or CaptureDef(...): Found {param_name}" + f": {param_default}") dsl_func_args.append(param_default) - tc_model.add_tensor(param_name, param_default) # Invoke the DSL func to finish populating the model. with bind_op_def(tc_model): diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py index 682f19138..4a037025d 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -2,7 +2,7 @@ # See https://llvm.org/LICENSE.txt for license information. 
  if not isinstance(captures, Sequence):
    raise ValueError(f"Expected named argument captures to have type Sequence "
                     f"but got {type(captures)}")
for arg_def, arg_element_type in zip( in_arg_defs + out_arg_defs, _get_shaped_element_types_from_values(*ins, *outs)): - tv_name = arg_def.tensor_def.type_var.name - type_mapping[tv_name] = arg_element_type + _add_type_mapping(arg_def.tensor_def.type_var.name, arg_element_type, + type_mapping) + + # Extract type vars for captures and compute capture argument mapping. + capture_arg_mapping = dict() # type: Dict[str, Value] + for arg_def, capture_value in zip(capture_arg_defs, captures): + _add_type_mapping(arg_def.capture_def.type_var.name, capture_value.type, + type_mapping) + capture_arg_mapping[arg_def.capture_def.capture_name] = capture_value # Emit the generic op. # TODO: Support emission of pure memref form. @@ -63,21 +89,22 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, for am in AffineMap.compress_unused_symbols(op_config.indexing_maps, Context.current)]) iterator_types_attr = ArrayAttr.get( [StringAttr.get(s) for s in op_config.iterator_types]) - sparse_attr = ArrayAttr.get( - [BoolAttr.get(False) for s in list(ins) + list(outs) if isa(RankedTensorType, s.type)]) - if len(sparse_attr) == 0: - sparse_attr = None + # TODO: Add support for sparse operands once there is a stable interface. 
+ sparse_attr = None return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, - type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr) + type_mapping, capture_arg_mapping, indexing_maps_attr, + iterator_types_attr, sparse_attr) def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, - outs: Value = ()): - all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \ - type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \ - prepare_common_structured_op(op_config, *ins, outs = outs) + outs: Sequence[Value] = (), + captures: Sequence[Value] = ()): + all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ + capture_arg_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \ + prepare_common_structured_op(op_config, *ins, outs = outs, + captures=captures) generic_op = linalg.GenericOp( result_tensors=result_types, @@ -95,7 +122,8 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, block = generic_op.regions[0].blocks.append(*block_arg_types) block_arg_mapping = dict(zip(block_arg_names, block.arguments)) with InsertionPoint(block): - body_builder = _BodyBuilder(type_mapping, block_arg_mapping) + body_builder = _BodyBuilder(type_mapping, block_arg_mapping, + capture_arg_mapping) for assignment in op_config.assignments: body_builder.assign(assignment) body_builder.yield_outputs(*_get_tensor_def_names(*out_arg_defs)) @@ -110,10 +138,12 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str, op_class_name: str, *ins: Value, - outs: Value = ()): - all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, \ - type_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \ - prepare_common_structured_op(op_config, *ins, outs = outs) + outs: Sequence[Value] = (), + captures: Sequence[Value] = ()): + all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ + capture_arg_mapping, indexing_maps_attr, 
iterator_types_attr, sparse_attr = \ + prepare_common_structured_op(op_config, *ins, outs = outs, + captures = captures) # If we get here, there must exist a builtin class `op_class_name`. ctx = Context.current @@ -127,7 +157,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, linalgDialect = ctx.get_dialect_descriptor("linalg") fill_builtin_region(linalgDialect, named_op.operation) # Note: mlir-linalg-ods-yaml-gen.cpp uses a special linalg.memoized_indexing_maps - # attribute that the non-yaml path does not. The non-yaml path hardcodes the + # attribute that the non-yaml path does not. The non-yaml path hardcodes the # indexing_maps in C++ directly. named_op.operation.attributes["linalg.memoized_indexing_maps"] = indexing_maps_attr # iterator_types are hardcoded in C++ both in the yaml and non-yaml path. @@ -141,10 +171,13 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, class _BodyBuilder: """Constructs a structured op body by evaluating assignments.""" - def __init__(self, type_mapping: Dict[str, Type], - block_arg_mapping: Dict[str, Value]): + def __init__(self, + type_mapping: Dict[str, Type], + block_arg_mapping: Dict[str, Value], + capture_arg_mapping: Dict[str, Value]): self.type_mapping = type_mapping self.block_arg_mapping = block_arg_mapping + self.capture_arg_mapping = capture_arg_mapping self.yield_mapping = dict() # type: Dict[str, Value] def assign(self, assignment: ScalarAssign): @@ -161,6 +194,16 @@ def expression(self, expr: ScalarExpression) -> Value: except KeyError: raise ValueError(f"Argument {expr.scalar_arg.arg} is not bound for " f"this structured op.") + elif expr.scalar_capture: + try: + return self.capture_arg_mapping[expr.scalar_capture.capture] + except KeyError: + raise ValueError(f"Capture {expr.scalar_capture.capture} is not bound for " + f"this structured op.") + elif expr.scalar_const: + return self.constant(expr.scalar_const.type_var.name, expr.scalar_const.value) + elif expr.scalar_index: + 
f"expected one of {self.type_mapping.keys()})")
raise NotImplementedError(f"Unsupported 'sub' operand: {lhs}")
IndexType.isinstance(t) def _get_floating_point_width(t: Type) -> int: # TODO: Create a FloatType in the Python API and implement the switch diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py index 9ebf7a9a0..bb1938d71 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py @@ -13,7 +13,7 @@ can be easily consumed from the C++ side, not necessarily for ergonomics. """ -from typing import Optional, Sequence +from typing import Any, Optional, Sequence from .yaml_helper import * from .types import * @@ -22,6 +22,9 @@ "ScalarAssign", "ScalarApplyFn", "ScalarArg", + "ScalarCapture", + "ScalarConst", + "ScalarIndex", "ScalarExpression", "ScalarSymbolicCast", ] @@ -53,6 +56,42 @@ def expr(self) -> "ScalarExpression": def __repr__(self): return f"(ScalarArg({self.arg})" +class ScalarCapture: + """A type of ScalarExpression that references a named capture.""" + + def __init__(self, capture: str): + self.capture = capture + + def expr(self) -> "ScalarExpression": + return ScalarExpression(scalar_capture=self) + + def __repr__(self): + return f"(ScalarCapture({self.capture})" + +class ScalarConst: + """A type of ScalarExpression representing a constant.""" + + def __init__(self, type_var: TypeVar, value: Any): + self.type_var = type_var + self.value = value + + def expr(self) -> "ScalarExpression": + return ScalarExpression(scalar_const=self) + + def __repr__(self): + return f"(ScalarConst({self.type_var}, {self.value})" + +class ScalarIndex: + """A type of ScalarExpression accessing an iteration index.""" + + def __init__(self, dim : int): + self.dim = dim + + def expr(self) -> "ScalarExpression": + return ScalarExpression(scalar_index=self) + + def __repr__(self): + return f"(ScalarIndex({self.dim})" class ScalarSymbolicCast: """A type of ScalarExpression that 
symbolically casts an operand to a TypeVar. @@ -75,6 +114,9 @@ class ScalarExpression(YAMLObject): Can be one of: - ScalarApplyFn - ScalarArg + - ScalarCapture + - ScalarConst + - ScalarIndex - ScalarSymbolicCast """ yaml_tag = "!ScalarExpression" @@ -82,13 +124,20 @@ class ScalarExpression(YAMLObject): def __init__(self, scalar_apply: Optional[ScalarApplyFn] = None, scalar_arg: Optional[ScalarArg] = None, + scalar_capture: Optional[ScalarCapture] = None, + scalar_const: Optional[ScalarConst] = None, + scalar_index: Optional[ScalarIndex] = None, symbolic_cast: Optional[ScalarSymbolicCast] = None): - if (bool(scalar_apply) + bool(scalar_arg) + bool(symbolic_cast)) != 1: + if (bool(scalar_apply) + bool(scalar_arg) + bool(scalar_capture) + + bool(scalar_const) + bool(scalar_index) + bool(symbolic_cast)) != 1: raise ValueError( - "One of 'scalar_apply', 'scalar_block_arg', 'symbolic_cast' must be " - "specified") + "One of 'scalar_apply', 'scalar_arg', 'scalar_capture', 'scalar_const', " + "'scalar_index', 'symbolic_cast' must be specified") self.scalar_apply = scalar_apply self.scalar_arg = scalar_arg + self.scalar_capture = scalar_capture + self.scalar_const = scalar_const + self.scalar_index = scalar_index self.symbolic_cast = symbolic_cast def to_yaml_custom_dict(self): @@ -99,6 +148,13 @@ def to_yaml_custom_dict(self): )) elif self.scalar_arg: return dict(scalar_arg=self.scalar_arg.arg) + elif self.scalar_capture: + return dict(scalar_capture=self.scalar_capture.capture) + elif self.scalar_const: + return dict(scalar_const=dict(type_var=self.scalar_const.type_var.name, + attributes=[self.scalar_const.value])) + elif self.scalar_index: + return dict(scalar_index=self.scalar_index.dim) elif self.symbolic_cast: # Note that even though operands must be arity 1, we write it the # same way as for apply because it allows handling code to be more diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/types.py 
b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/types.py index 35bbfe712..ddac87287 100644 --- a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/types.py +++ b/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/types.py @@ -22,6 +22,12 @@ "TypeVar", "TV", + # Predefined types. + "I32", + "I64", + "F32", + "F64", + # TypeVar aliases. "T", "U", @@ -63,6 +69,12 @@ def __getattr__(self, n): # Expando access via TV.foo TV = TypeVar.create_expando() +# Predefined types. +I32 = TV.I32 +I64 = TV.I64 +F32 = TV.F32 +F64 = TV.F64 + # Some common type name aliases. T = TV.T U = TV.U From fb3c782a85432264b6c9cd6398b7029919851962 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 22 Apr 2021 17:32:10 +0200 Subject: [PATCH 034/915] [mlir] Split out Python bindings entry point into a separate file This will allow the bindings to be built as a library and reused in out-of-tree projects that want to provide bindings on top of MLIR bindings. Reviewed By: stellaraccident, mikeurbach Differential Revision: https://reviews.llvm.org/D101075 --- mlir/lib/Bindings/Python/CMakeLists.txt | 1 + mlir/lib/Bindings/Python/IRModule.cpp | 146 ++++++++++++++++++++++++ mlir/lib/Bindings/Python/MainModule.cpp | 129 --------------------- 3 files changed, 147 insertions(+), 129 deletions(-) create mode 100644 mlir/lib/Bindings/Python/IRModule.cpp diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt index bbccea63c..580405f09 100644 --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/CMakeLists.txt @@ -84,6 +84,7 @@ add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir IRAffine.cpp IRAttributes.cpp IRCore.cpp + IRModule.cpp IRTypes.cpp PybindUtils.cpp Pass.cpp diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp new file mode 100644 index 000000000..08ce06da8 --- /dev/null +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -0,0 +1,146 @@ +//===- IRModule.cpp - IR 
pybind module ------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "IRModule.h" +#include "Globals.h" +#include "PybindUtils.h" + +#include + +namespace py = pybind11; +using namespace mlir; +using namespace mlir::python; + +// ----------------------------------------------------------------------------- +// PyGlobals +// ----------------------------------------------------------------------------- + +PyGlobals *PyGlobals::instance = nullptr; + +PyGlobals::PyGlobals() { + assert(!instance && "PyGlobals already constructed"); + instance = this; +} + +PyGlobals::~PyGlobals() { instance = nullptr; } + +void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { + py::gil_scoped_acquire(); + if (loadedDialectModulesCache.contains(dialectNamespace)) + return; + // Since re-entrancy is possible, make a copy of the search prefixes. + std::vector localSearchPrefixes = dialectSearchPrefixes; + py::object loaded; + for (std::string moduleName : localSearchPrefixes) { + moduleName.push_back('.'); + moduleName.append(dialectNamespace.data(), dialectNamespace.size()); + + try { + py::gil_scoped_release(); + loaded = py::module::import(moduleName.c_str()); + } catch (py::error_already_set &e) { + if (e.matches(PyExc_ModuleNotFoundError)) { + continue; + } else { + throw; + } + } + break; + } + + // Note: Iterator cannot be shared from prior to loading, since re-entrancy + // may have occurred, which may do anything. 
+ loadedDialectModulesCache.insert(dialectNamespace); +} + +void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, + py::object pyClass) { + py::gil_scoped_acquire(); + py::object &found = dialectClassMap[dialectNamespace]; + if (found) { + throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + + dialectNamespace + + "' is already registered."); + } + found = std::move(pyClass); +} + +void PyGlobals::registerOperationImpl(const std::string &operationName, + py::object pyClass, + py::object rawOpViewClass) { + py::gil_scoped_acquire(); + py::object &found = operationClassMap[operationName]; + if (found) { + throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + + operationName + + "' is already registered."); + } + found = std::move(pyClass); + rawOpViewClassMap[operationName] = std::move(rawOpViewClass); +} + +llvm::Optional +PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { + py::gil_scoped_acquire(); + loadDialectModule(dialectNamespace); + // Fast match against the class map first (common case). + const auto foundIt = dialectClassMap.find(dialectNamespace); + if (foundIt != dialectClassMap.end()) { + if (foundIt->second.is_none()) + return llvm::None; + assert(foundIt->second && "py::object is defined"); + return foundIt->second; + } + + // Not found and loading did not yield a registration. Negative cache. + dialectClassMap[dialectNamespace] = py::none(); + return llvm::None; +} + +llvm::Optional +PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { + { + py::gil_scoped_acquire(); + auto foundIt = rawOpViewClassMapCache.find(operationName); + if (foundIt != rawOpViewClassMapCache.end()) { + if (foundIt->second.is_none()) + return llvm::None; + assert(foundIt->second && "py::object is defined"); + return foundIt->second; + } + } + + // Not found. Load the dialect namespace. 
+ auto split = operationName.split('.'); + llvm::StringRef dialectNamespace = split.first; + loadDialectModule(dialectNamespace); + + // Attempt to find from the canonical map and cache. + { + py::gil_scoped_acquire(); + auto foundIt = rawOpViewClassMap.find(operationName); + if (foundIt != rawOpViewClassMap.end()) { + if (foundIt->second.is_none()) + return llvm::None; + assert(foundIt->second && "py::object is defined"); + // Positive cache. + rawOpViewClassMapCache[operationName] = foundIt->second; + return foundIt->second; + } else { + // Negative cache. + rawOpViewClassMap[operationName] = py::none(); + return llvm::None; + } + } +} + +void PyGlobals::clearImportCache() { + py::gil_scoped_acquire(); + loadedDialectModulesCache.clear(); + rawOpViewClassMapCache.clear(); +} diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 79128f267..60c282d1d 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -20,135 +20,6 @@ namespace py = pybind11; using namespace mlir; using namespace mlir::python; -// ----------------------------------------------------------------------------- -// PyGlobals -// ----------------------------------------------------------------------------- - -PyGlobals *PyGlobals::instance = nullptr; - -PyGlobals::PyGlobals() { - assert(!instance && "PyGlobals already constructed"); - instance = this; -} - -PyGlobals::~PyGlobals() { instance = nullptr; } - -void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { - py::gil_scoped_acquire(); - if (loadedDialectModulesCache.contains(dialectNamespace)) - return; - // Since re-entrancy is possible, make a copy of the search prefixes. 
- std::vector localSearchPrefixes = dialectSearchPrefixes; - py::object loaded; - for (std::string moduleName : localSearchPrefixes) { - moduleName.push_back('.'); - moduleName.append(dialectNamespace.data(), dialectNamespace.size()); - - try { - py::gil_scoped_release(); - loaded = py::module::import(moduleName.c_str()); - } catch (py::error_already_set &e) { - if (e.matches(PyExc_ModuleNotFoundError)) { - continue; - } else { - throw; - } - } - break; - } - - // Note: Iterator cannot be shared from prior to loading, since re-entrancy - // may have occurred, which may do anything. - loadedDialectModulesCache.insert(dialectNamespace); -} - -void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, - py::object pyClass) { - py::gil_scoped_acquire(); - py::object &found = dialectClassMap[dialectNamespace]; - if (found) { - throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + - dialectNamespace + - "' is already registered."); - } - found = std::move(pyClass); -} - -void PyGlobals::registerOperationImpl(const std::string &operationName, - py::object pyClass, - py::object rawOpViewClass) { - py::gil_scoped_acquire(); - py::object &found = operationClassMap[operationName]; - if (found) { - throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + - operationName + - "' is already registered."); - } - found = std::move(pyClass); - rawOpViewClassMap[operationName] = std::move(rawOpViewClass); -} - -llvm::Optional -PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { - py::gil_scoped_acquire(); - loadDialectModule(dialectNamespace); - // Fast match against the class map first (common case). - const auto foundIt = dialectClassMap.find(dialectNamespace); - if (foundIt != dialectClassMap.end()) { - if (foundIt->second.is_none()) - return llvm::None; - assert(foundIt->second && "py::object is defined"); - return foundIt->second; - } - - // Not found and loading did not yield a registration. Negative cache. 
- dialectClassMap[dialectNamespace] = py::none(); - return llvm::None; -} - -llvm::Optional -PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { - { - py::gil_scoped_acquire(); - auto foundIt = rawOpViewClassMapCache.find(operationName); - if (foundIt != rawOpViewClassMapCache.end()) { - if (foundIt->second.is_none()) - return llvm::None; - assert(foundIt->second && "py::object is defined"); - return foundIt->second; - } - } - - // Not found. Load the dialect namespace. - auto split = operationName.split('.'); - llvm::StringRef dialectNamespace = split.first; - loadDialectModule(dialectNamespace); - - // Attempt to find from the canonical map and cache. - { - py::gil_scoped_acquire(); - auto foundIt = rawOpViewClassMap.find(operationName); - if (foundIt != rawOpViewClassMap.end()) { - if (foundIt->second.is_none()) - return llvm::None; - assert(foundIt->second && "py::object is defined"); - // Positive cache. - rawOpViewClassMapCache[operationName] = foundIt->second; - return foundIt->second; - } else { - // Negative cache. - rawOpViewClassMap[operationName] = py::none(); - return llvm::None; - } - } -} - -void PyGlobals::clearImportCache() { - py::gil_scoped_acquire(); - loadedDialectModulesCache.clear(); - rawOpViewClassMapCache.clear(); -} - // ----------------------------------------------------------------------------- // Module initialization. // ----------------------------------------------------------------------------- From 6f61818d950d8afc638ada90a4cb668160650ca2 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sun, 2 May 2021 15:15:21 -0700 Subject: [PATCH 035/915] [mlir][Python] Add casting constructor to Type and Attribute. * This makes them consistent with custom types/attributes, whose constructors will do a type checked conversion. Of course, the base classes can represent everything so never error. * More importantly, this makes it possible to subclass Type and Attribute out of tree in sensible ways. 
Differential Revision: https://reviews.llvm.org/D101734 --- mlir/lib/Bindings/Python/IRCore.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 160e35b21..d11edb1c6 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2255,6 +2255,10 @@ void mlir::python::populateIRCore(py::module &m) { // Mapping of PyAttribute. //---------------------------------------------------------------------------- py::class_(m, "Attribute") + // Delegate to the PyAttribute copy constructor, which will also lifetime + // extend the backing context which owns the MlirAttribute. + .def(py::init(), py::arg("cast_from_type"), + "Casts the passed attribute to the generic Attribute") .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) @@ -2358,6 +2362,10 @@ void mlir::python::populateIRCore(py::module &m) { // Mapping of PyType. //---------------------------------------------------------------------------- py::class_(m, "Type") + // Delegate to the PyType copy constructor, which will also lifetime + // extend the backing context which owns the MlirType. + .def(py::init(), py::arg("cast_from_type"), + "Casts the passed type to the generic Type") .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) .def_static( From 9ef83bdd0806f1a421ee84f0e7c85bcd97b4fa8e Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 28 Apr 2021 20:04:17 +0000 Subject: [PATCH 036/915] Move MLIR python sources to mlir/python. * NFC but has some fixes for CMake glitches discovered along the way (things not cleaning properly, co-mingled depends). * Includes previously unsubmitted fix in D98681 and a TODO to fix it more appropriately in a smaller followup. 
Differential Revision: https://reviews.llvm.org/D101493 --- mlir/lib/Bindings/Python/CMakeLists.txt | 109 +----------------- .../Bindings/Python => python}/.style.yapf | 0 mlir/python/CMakeLists.txt | 49 ++++++++ .../Python => python}/mlir/_cext_loader.py | 0 .../Python => python}/mlir/_dlloader.py | 0 .../mlir/conversions/__init__.py | 0 .../mlir/dialects}/AsyncOps.td | 0 .../mlir/dialects}/BuiltinOps.td | 0 mlir/python/mlir/dialects/CMakeLists.txt | 71 ++++++++++++ .../Python => python/mlir/dialects}/GPUOps.td | 0 .../mlir/dialects}/LinalgOps.td | 0 .../mlir/dialects}/MemRefOps.td | 0 .../mlir/dialects}/ShapeOps.td | 0 .../mlir/dialects}/StandardOps.td | 0 .../mlir/dialects}/TensorOps.td | 0 .../mlir/dialects/_builtin_ops_ext.py | 0 .../mlir/dialects/_linalg_ops_ext.py | 0 .../mlir/dialects/_ods_common.py | 0 .../mlir/dialects/async_dialect/__init__.py | 0 .../dialects/async_dialect/passes/__init__.py | 0 .../mlir/dialects/builtin.py | 0 .../mlir/dialects/gpu/__init__.py | 0 .../mlir/dialects/gpu/passes/__init__.py | 0 .../mlir/dialects/linalg/__init__.py | 0 .../mlir/dialects/linalg/opdsl/__init__.py | 0 .../mlir/dialects/linalg/opdsl/dump_oplib.py | 0 .../dialects/linalg/opdsl/lang/__init__.py | 0 .../mlir/dialects/linalg/opdsl/lang/affine.py | 0 .../linalg/opdsl/lang/comprehension.py | 0 .../mlir/dialects/linalg/opdsl/lang/config.py | 0 .../mlir/dialects/linalg/opdsl/lang/dsl.py | 0 .../dialects/linalg/opdsl/lang/emitter.py | 0 .../dialects/linalg/opdsl/lang/scalar_expr.py | 0 .../mlir/dialects/linalg/opdsl/lang/types.py | 0 .../dialects/linalg/opdsl/lang/yaml_helper.py | 0 .../dialects/linalg/opdsl/ops/__init__.py | 0 .../linalg/opdsl/ops/core_named_ops.py | 0 .../mlir/dialects/linalg/passes/__init__.py | 0 .../Python => python}/mlir/dialects/memref.py | 0 .../mlir/dialects/python_test.py | 0 .../Python => python}/mlir/dialects/shape.py | 0 .../Python => python}/mlir/dialects/std.py | 0 .../Python => python}/mlir/dialects/tensor.py | 0 
.../mlir/execution_engine.py | 0 .../Bindings/Python => python}/mlir/ir.py | 0 .../Python => python}/mlir/passmanager.py | 0 .../mlir/runtime/__init__.py | 0 .../mlir/runtime/np_to_memref.py | 0 .../mlir/transforms/__init__.py | 0 .../Python => python}/requirements.txt | 0 50 files changed, 121 insertions(+), 108 deletions(-) rename mlir/{lib/Bindings/Python => python}/.style.yapf (100%) create mode 100644 mlir/python/CMakeLists.txt rename mlir/{lib/Bindings/Python => python}/mlir/_cext_loader.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/_dlloader.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/conversions/__init__.py (100%) rename mlir/{lib/Bindings/Python => python/mlir/dialects}/AsyncOps.td (100%) rename mlir/{lib/Bindings/Python => python/mlir/dialects}/BuiltinOps.td (100%) create mode 100644 mlir/python/mlir/dialects/CMakeLists.txt rename mlir/{lib/Bindings/Python => python/mlir/dialects}/GPUOps.td (100%) rename mlir/{lib/Bindings/Python => python/mlir/dialects}/LinalgOps.td (100%) rename mlir/{lib/Bindings/Python => python/mlir/dialects}/MemRefOps.td (100%) rename mlir/{lib/Bindings/Python => python/mlir/dialects}/ShapeOps.td (100%) rename mlir/{lib/Bindings/Python => python/mlir/dialects}/StandardOps.td (100%) rename mlir/{lib/Bindings/Python => python/mlir/dialects}/TensorOps.td (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/_builtin_ops_ext.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/_linalg_ops_ext.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/_ods_common.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/async_dialect/__init__.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/async_dialect/passes/__init__.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/builtin.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/gpu/__init__.py (100%) rename mlir/{lib/Bindings/Python => 
python}/mlir/dialects/gpu/passes/__init__.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/linalg/__init__.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/linalg/opdsl/__init__.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/linalg/opdsl/dump_oplib.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/linalg/opdsl/lang/__init__.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/linalg/opdsl/lang/affine.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/linalg/opdsl/lang/comprehension.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/linalg/opdsl/lang/config.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/linalg/opdsl/lang/dsl.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/linalg/opdsl/lang/emitter.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/linalg/opdsl/lang/scalar_expr.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/linalg/opdsl/lang/types.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/linalg/opdsl/lang/yaml_helper.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/linalg/opdsl/ops/__init__.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/linalg/passes/__init__.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/memref.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/python_test.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/shape.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/std.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/dialects/tensor.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/execution_engine.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/ir.py (100%) rename 
mlir/{lib/Bindings/Python => python}/mlir/passmanager.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/runtime/__init__.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/runtime/np_to_memref.py (100%) rename mlir/{lib/Bindings/Python => python}/mlir/transforms/__init__.py (100%) rename mlir/{lib/Bindings/Python => python}/requirements.txt (100%) diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt index 580405f09..a2e972dc1 100644 --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/CMakeLists.txt @@ -1,77 +1,6 @@ -include(AddMLIRPythonExtension) +include(AddMLIRPython) add_custom_target(MLIRBindingsPythonExtension) -################################################################################ -# Copy python source tree. -################################################################################ - -file(GLOB_RECURSE PY_SRC_FILES - RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" - "${CMAKE_CURRENT_SOURCE_DIR}/mlir/*.py") - -add_custom_target(MLIRBindingsPythonSources ALL - DEPENDS ${PY_SRC_FILES} -) -add_dependencies(MLIRBindingsPythonExtension MLIRBindingsPythonSources) - -foreach(PY_SRC_FILE ${PY_SRC_FILES}) - set(PY_DEST_FILE "${PROJECT_BINARY_DIR}/python/${PY_SRC_FILE}") - get_filename_component(PY_DEST_DIR "${PY_DEST_FILE}" DIRECTORY) - file(MAKE_DIRECTORY "${PY_DEST_DIR}") - add_custom_command( - TARGET MLIRBindingsPythonSources PRE_BUILD - COMMENT "Copying python source ${PY_SRC_FILE} -> ${PY_DEST_FILE}" - DEPENDS "${PY_SRC_FILE}" - COMMAND "${CMAKE_COMMAND}" -E create_symlink - "${CMAKE_CURRENT_SOURCE_DIR}/${PY_SRC_FILE}" "${PY_DEST_FILE}" - ) -endforeach() - -################################################################################ -# Generate dialect-specific bindings. 
-################################################################################ - -add_mlir_dialect_python_bindings(MLIRBindingsPythonAsyncOps - TD_FILE AsyncOps.td - DIALECT_NAME async_dialect) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonAsyncOps) - -add_mlir_dialect_python_bindings(MLIRBindingsPythonBuiltinOps - TD_FILE BuiltinOps.td - DIALECT_NAME builtin) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonBuiltinOps) - -add_mlir_dialect_python_bindings(MLIRBindingsPythonGPUOps - TD_FILE GPUOps.td - DIALECT_NAME gpu) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonGPUOps) - -add_mlir_dialect_python_bindings(MLIRBindingsPythonLinalgOps - TD_FILE LinalgOps.td - DIALECT_NAME linalg - DEPENDS LinalgOdsGen) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonLinalgOps) - -add_mlir_dialect_python_bindings(MLIRBindingsPythonMemRefOps - TD_FILE MemRefOps.td - DIALECT_NAME memref) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonMemRefOps) - -add_mlir_dialect_python_bindings(MLIRBindingsPythonShapeOps - TD_FILE ShapeOps.td - DIALECT_NAME shape) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonShapeOps) - -add_mlir_dialect_python_bindings(MLIRBindingsPythonStandardOps - TD_FILE StandardOps.td - DIALECT_NAME std) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonStandardOps) - -add_mlir_dialect_python_bindings(MLIRBindingsPythonTensorOps - TD_FILE TensorOps.td - DIALECT_NAME tensor) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonTensorOps) - ################################################################################ # Build core python extension ################################################################################ @@ -92,42 +21,6 @@ add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir ) add_dependencies(MLIRBindingsPythonExtension MLIRCoreBindingsPythonExtension) -# Note that we copy from the source tree just like for headers because -# it 
will not be polluted with py_cache runtime artifacts (from testing and -# such). -install( - DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/mlir - DESTINATION python - COMPONENT MLIRBindingsPythonSources - FILES_MATCHING PATTERN "*.py" -) - -if (NOT LLVM_ENABLE_IDE) - add_llvm_install_targets( - install-MLIRBindingsPythonSources - DEPENDS MLIRBindingsPythonSources - COMPONENT MLIRBindingsPythonSources) -endif() - -# Dialect sources are generated. Install separately. -# Note that __pycache__ directories may have been left by tests and other -# executions. And __init__.py is handled as a regular source file. -install( - DIRECTORY ${PROJECT_BINARY_DIR}/python/mlir/dialects - DESTINATION python/mlir - COMPONENT MLIRBindingsPythonDialects - FILES_MATCHING PATTERN "*.py" - PATTERN "__pycache__" EXCLUDE - PATTERN "__init__.py" EXCLUDE -) - -if (NOT LLVM_ENABLE_IDE) - add_llvm_install_targets( - install-MLIRBindingsPythonDialects - DEPENDS MLIRBindingsPythonSources - COMPONENT MLIRBindingsPythonDialects) -endif() - add_subdirectory(Transforms) add_subdirectory(Conversions) diff --git a/mlir/lib/Bindings/Python/.style.yapf b/mlir/python/.style.yapf similarity index 100% rename from mlir/lib/Bindings/Python/.style.yapf rename to mlir/python/.style.yapf diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt new file mode 100644 index 000000000..1b9705a03 --- /dev/null +++ b/mlir/python/CMakeLists.txt @@ -0,0 +1,49 @@ +################################################################################ +# Copy python source tree. 
+################################################################################ + +file(GLOB_RECURSE PY_SRC_FILES + RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" + "${CMAKE_CURRENT_SOURCE_DIR}/mlir/*.py") + +add_custom_target(MLIRBindingsPythonSources ALL + DEPENDS + ${PY_SRC_FILES} +) + +foreach(PY_SRC_FILE ${PY_SRC_FILES}) + set(PY_DEST_FILE "${PROJECT_BINARY_DIR}/python/${PY_SRC_FILE}") + get_filename_component(PY_DEST_DIR "${PY_DEST_FILE}" DIRECTORY) + file(MAKE_DIRECTORY "${PY_DEST_DIR}") + add_custom_command( + TARGET MLIRBindingsPythonSources PRE_BUILD + COMMENT "Copying python source ${PY_SRC_FILE} -> ${PY_DEST_FILE}" + DEPENDS "${PY_SRC_FILE}" + BYPRODUCTS "${PY_DEST_FILE}" + COMMAND "${CMAKE_COMMAND}" -E create_symlink + "${CMAKE_CURRENT_SOURCE_DIR}/${PY_SRC_FILE}" "${PY_DEST_FILE}" + ) +endforeach() + +# Note that we copy from the source tree just like for headers because +# it will not be polluted with py_cache runtime artifacts (from testing and +# such). +install( + DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/mlir + DESTINATION python + COMPONENT MLIRBindingsPythonSources + FILES_MATCHING PATTERN "*.py" +) + +if (NOT LLVM_ENABLE_IDE) + add_llvm_install_targets( + install-MLIRBindingsPythonSources + DEPENDS MLIRBindingsPythonSources + COMPONENT MLIRBindingsPythonSources) +endif() + +################################################################################ +# Generated sources. 
+################################################################################ + +add_subdirectory(mlir/dialects) diff --git a/mlir/lib/Bindings/Python/mlir/_cext_loader.py b/mlir/python/mlir/_cext_loader.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/_cext_loader.py rename to mlir/python/mlir/_cext_loader.py diff --git a/mlir/lib/Bindings/Python/mlir/_dlloader.py b/mlir/python/mlir/_dlloader.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/_dlloader.py rename to mlir/python/mlir/_dlloader.py diff --git a/mlir/lib/Bindings/Python/mlir/conversions/__init__.py b/mlir/python/mlir/conversions/__init__.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/conversions/__init__.py rename to mlir/python/mlir/conversions/__init__.py diff --git a/mlir/lib/Bindings/Python/AsyncOps.td b/mlir/python/mlir/dialects/AsyncOps.td similarity index 100% rename from mlir/lib/Bindings/Python/AsyncOps.td rename to mlir/python/mlir/dialects/AsyncOps.td diff --git a/mlir/lib/Bindings/Python/BuiltinOps.td b/mlir/python/mlir/dialects/BuiltinOps.td similarity index 100% rename from mlir/lib/Bindings/Python/BuiltinOps.td rename to mlir/python/mlir/dialects/BuiltinOps.td diff --git a/mlir/python/mlir/dialects/CMakeLists.txt b/mlir/python/mlir/dialects/CMakeLists.txt new file mode 100644 index 000000000..31a4ee55b --- /dev/null +++ b/mlir/python/mlir/dialects/CMakeLists.txt @@ -0,0 +1,71 @@ +include(AddMLIRPython) + +################################################################################ +# Generate dialect-specific bindings. 
+################################################################################ + +add_mlir_dialect_python_bindings(MLIRBindingsPythonAsyncOps + TD_FILE AsyncOps.td + DIALECT_NAME async_dialect) +add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonAsyncOps) + +add_mlir_dialect_python_bindings(MLIRBindingsPythonBuiltinOps + TD_FILE BuiltinOps.td + DIALECT_NAME builtin) +add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonBuiltinOps) + +add_mlir_dialect_python_bindings(MLIRBindingsPythonGPUOps + TD_FILE GPUOps.td + DIALECT_NAME gpu) +add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonGPUOps) + +add_mlir_dialect_python_bindings(MLIRBindingsPythonLinalgOps + TD_FILE LinalgOps.td + DIALECT_NAME linalg + DEPENDS LinalgOdsGen) +add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonLinalgOps) + +add_mlir_dialect_python_bindings(MLIRBindingsPythonMemRefOps + TD_FILE MemRefOps.td + DIALECT_NAME memref) +add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonMemRefOps) + +add_mlir_dialect_python_bindings(MLIRBindingsPythonShapeOps + TD_FILE ShapeOps.td + DIALECT_NAME shape) +add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonShapeOps) + +add_mlir_dialect_python_bindings(MLIRBindingsPythonStandardOps + TD_FILE StandardOps.td + DIALECT_NAME std) +add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonStandardOps) + +add_mlir_dialect_python_bindings(MLIRBindingsPythonTensorOps + TD_FILE TensorOps.td + DIALECT_NAME tensor) +add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonTensorOps) + +################################################################################ +# Installation. +################################################################################ + +# Dialect sources are generated. Install separately. +# Note that __pycache__ directories may have been left by tests and other +# executions. And __init__.py is handled as a regular source file. 
+# TODO: Eliminate this glob install, instead adding INSTALL_COMPONENT to +# add_mlir_dialect_python_bindings and installing the precise file there. +install( + DIRECTORY ${PROJECT_BINARY_DIR}/python/mlir/dialects + DESTINATION python/mlir + COMPONENT MLIRBindingsPythonDialects + FILES_MATCHING PATTERN "_*_gen.py" + PATTERN "__pycache__" EXCLUDE + PATTERN "__init__.py" EXCLUDE +) + +if (NOT LLVM_ENABLE_IDE) + add_llvm_install_targets( + install-MLIRBindingsPythonDialects + DEPENDS MLIRBindingsPythonSources + COMPONENT MLIRBindingsPythonDialects) +endif() diff --git a/mlir/lib/Bindings/Python/GPUOps.td b/mlir/python/mlir/dialects/GPUOps.td similarity index 100% rename from mlir/lib/Bindings/Python/GPUOps.td rename to mlir/python/mlir/dialects/GPUOps.td diff --git a/mlir/lib/Bindings/Python/LinalgOps.td b/mlir/python/mlir/dialects/LinalgOps.td similarity index 100% rename from mlir/lib/Bindings/Python/LinalgOps.td rename to mlir/python/mlir/dialects/LinalgOps.td diff --git a/mlir/lib/Bindings/Python/MemRefOps.td b/mlir/python/mlir/dialects/MemRefOps.td similarity index 100% rename from mlir/lib/Bindings/Python/MemRefOps.td rename to mlir/python/mlir/dialects/MemRefOps.td diff --git a/mlir/lib/Bindings/Python/ShapeOps.td b/mlir/python/mlir/dialects/ShapeOps.td similarity index 100% rename from mlir/lib/Bindings/Python/ShapeOps.td rename to mlir/python/mlir/dialects/ShapeOps.td diff --git a/mlir/lib/Bindings/Python/StandardOps.td b/mlir/python/mlir/dialects/StandardOps.td similarity index 100% rename from mlir/lib/Bindings/Python/StandardOps.td rename to mlir/python/mlir/dialects/StandardOps.td diff --git a/mlir/lib/Bindings/Python/TensorOps.td b/mlir/python/mlir/dialects/TensorOps.td similarity index 100% rename from mlir/lib/Bindings/Python/TensorOps.td rename to mlir/python/mlir/dialects/TensorOps.td diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py similarity index 100% rename from 
mlir/lib/Bindings/Python/mlir/dialects/_builtin_ops_ext.py rename to mlir/python/mlir/dialects/_builtin_ops_ext.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/_linalg_ops_ext.py rename to mlir/python/mlir/dialects/_linalg_ops_ext.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/_ods_common.py rename to mlir/python/mlir/dialects/_ods_common.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/async_dialect/__init__.py b/mlir/python/mlir/dialects/async_dialect/__init__.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/async_dialect/__init__.py rename to mlir/python/mlir/dialects/async_dialect/__init__.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/async_dialect/passes/__init__.py b/mlir/python/mlir/dialects/async_dialect/passes/__init__.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/async_dialect/passes/__init__.py rename to mlir/python/mlir/dialects/async_dialect/passes/__init__.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/builtin.py b/mlir/python/mlir/dialects/builtin.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/builtin.py rename to mlir/python/mlir/dialects/builtin.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/gpu/__init__.py rename to mlir/python/mlir/dialects/gpu/__init__.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/gpu/passes/__init__.py b/mlir/python/mlir/dialects/gpu/passes/__init__.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/gpu/passes/__init__.py rename to 
mlir/python/mlir/dialects/gpu/passes/__init__.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/linalg/__init__.py rename to mlir/python/mlir/dialects/linalg/__init__.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/__init__.py b/mlir/python/mlir/dialects/linalg/opdsl/__init__.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/__init__.py rename to mlir/python/mlir/dialects/linalg/opdsl/__init__.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/dump_oplib.py b/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/dump_oplib.py rename to mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/__init__.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/__init__.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/__init__.py rename to mlir/python/mlir/dialects/linalg/opdsl/lang/__init__.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/affine.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/affine.py rename to mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/comprehension.py rename to mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py similarity index 100% 
rename from mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/config.py rename to mlir/python/mlir/dialects/linalg/opdsl/lang/config.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/dsl.py rename to mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/emitter.py rename to mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py rename to mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/types.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/types.py rename to mlir/python/mlir/dialects/linalg/opdsl/lang/types.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py rename to mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/__init__.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/__init__.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/__init__.py rename to mlir/python/mlir/dialects/linalg/opdsl/ops/__init__.py diff --git 
a/mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py rename to mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/linalg/passes/__init__.py b/mlir/python/mlir/dialects/linalg/passes/__init__.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/linalg/passes/__init__.py rename to mlir/python/mlir/dialects/linalg/passes/__init__.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/memref.py rename to mlir/python/mlir/dialects/memref.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/python_test.py rename to mlir/python/mlir/dialects/python_test.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/shape.py b/mlir/python/mlir/dialects/shape.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/shape.py rename to mlir/python/mlir/dialects/shape.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/std.py b/mlir/python/mlir/dialects/std.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/std.py rename to mlir/python/mlir/dialects/std.py diff --git a/mlir/lib/Bindings/Python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/dialects/tensor.py rename to mlir/python/mlir/dialects/tensor.py diff --git a/mlir/lib/Bindings/Python/mlir/execution_engine.py b/mlir/python/mlir/execution_engine.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/execution_engine.py rename to mlir/python/mlir/execution_engine.py diff 
--git a/mlir/lib/Bindings/Python/mlir/ir.py b/mlir/python/mlir/ir.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/ir.py rename to mlir/python/mlir/ir.py diff --git a/mlir/lib/Bindings/Python/mlir/passmanager.py b/mlir/python/mlir/passmanager.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/passmanager.py rename to mlir/python/mlir/passmanager.py diff --git a/mlir/lib/Bindings/Python/mlir/runtime/__init__.py b/mlir/python/mlir/runtime/__init__.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/runtime/__init__.py rename to mlir/python/mlir/runtime/__init__.py diff --git a/mlir/lib/Bindings/Python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/runtime/np_to_memref.py rename to mlir/python/mlir/runtime/np_to_memref.py diff --git a/mlir/lib/Bindings/Python/mlir/transforms/__init__.py b/mlir/python/mlir/transforms/__init__.py similarity index 100% rename from mlir/lib/Bindings/Python/mlir/transforms/__init__.py rename to mlir/python/mlir/transforms/__init__.py diff --git a/mlir/lib/Bindings/Python/requirements.txt b/mlir/python/requirements.txt similarity index 100% rename from mlir/lib/Bindings/Python/requirements.txt rename to mlir/python/requirements.txt From d9b73eb98ab66c748f88d7872293f289093a9643 Mon Sep 17 00:00:00 2001 From: Denys Shabalin Date: Thu, 6 May 2021 18:24:07 +0200 Subject: [PATCH 037/915] Fix array attribute in bindings for linalg.init_tensor Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D101998 --- mlir/python/mlir/dialects/_linalg_ops_ext.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py index 4714e69b3..0aea4e603 100644 --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -74,9 +74,9 @@ def __init__(self, result_type = 
RankedTensorType.get(sizes, element_type) static_size_ints = sizes - index_type = IndexType.get(context) + i64_type = IntegerType.get_signless(64) attributes["static_sizes"] = ArrayAttr.get( - [IntegerAttr.get(index_type, s) for s in static_size_ints], + [IntegerAttr.get(i64_type, s) for s in static_size_ints], context=context) op = self.build_generic(results=[result_type], operands=operands, From 545ddbf6090a271618a7141ab8159adea7c2f4d7 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sun, 9 May 2021 16:14:05 -0700 Subject: [PATCH 038/915] [mlir][CAPI] Add CAPI bindings for the sparse_tensor dialect. * Adds dialect registration, hand coded 'encoding' attribute and test. * An MLIR CAPI tablegen backend for attributes does not exist, and this is a relatively complicated case. I opted to hand code it in a canonical way for now, which will provide a reasonable blueprint for building out the tablegen version in the future. * Also added a (local) CMake function for declaring new CAPI tests, since it was getting repetitive/buggy. Differential Revision: https://reviews.llvm.org/D102141 --- mlir/include/mlir-c/Dialect/SparseTensor.h | 77 ++++++++++++++++++++++ mlir/lib/CAPI/Dialect/CMakeLists.txt | 30 +++++---- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 71 ++++++++++++++++++++ 3 files changed, 164 insertions(+), 14 deletions(-) create mode 100644 mlir/include/mlir-c/Dialect/SparseTensor.h create mode 100644 mlir/lib/CAPI/Dialect/SparseTensor.cpp diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h new file mode 100644 index 000000000..2615a1655 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -0,0 +1,77 @@ +//===-- mlir-c/Dialect/SparseTensor.h - C API for SparseTensor ----*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_SPARSE_TENSOR_H +#define MLIR_C_DIALECT_SPARSE_TENSOR_H + +#include "mlir-c/AffineMap.h" +#include "mlir-c/Registration.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor); + +/// Dimension level types that define sparse tensors: +/// - MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE - dimension is dense, every +/// entry is stored +/// - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED - dimension is sparse, +/// only nonzeros are stored. +/// - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON - dimension contains single +/// coordinate, no siblings. +/// +/// These correspond to SparseTensorEncodingAttr::DimLevelType in the C++ API. +/// If updating, keep them in sync and update the static_assert in the impl +/// file. +enum MlirSparseTensorDimLevelType { + MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE, + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED, + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON, +}; + +//===----------------------------------------------------------------------===// +// SparseTensorEncodingAttr +//===----------------------------------------------------------------------===// + +/// Checks whether the given attribute is a sparse_tensor.encoding attribute. +MLIR_CAPI_EXPORTED bool +mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr); + +/// Creates a sparse_tensor.encoding attribute with the given parameters. +MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet( + MlirContext ctx, intptr_t numDimLevelTypes, + enum MlirSparseTensorDimLevelType const *dimLevelTypes, + MlirAffineMap dimOrdering, int pointerBitWidth, int indexBitWidth); + +/// Returns the number of dim level types in a sparse_tensor.encoding attribute. 
+MLIR_CAPI_EXPORTED intptr_t +mlirSparseTensorEncodingGetNumDimLevelTypes(MlirAttribute attr); + +/// Returns a specified dim level type in a sparse_tensor.encoding attribute. +MLIR_CAPI_EXPORTED enum MlirSparseTensorDimLevelType +mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t pos); + +/// Returns the dimension ordering in a sparse_tensor.encoding attribute. +MLIR_CAPI_EXPORTED MlirAffineMap +mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr); + +/// Returns the pointer bit width in a sparse_tensor.encoding attribute. +MLIR_CAPI_EXPORTED int +mlirSparseTensorEncodingAttrGetPointerBitWidth(MlirAttribute attr); + +/// Returns the index bit width in a sparse_tensor.encoding attribute. +MLIR_CAPI_EXPORTED int +mlirSparseTensorEncodingAttrGetIndexBitWidth(MlirAttribute attr); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_SPARSE_TENSOR_H diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index dd9bd6f67..69371fd57 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -1,21 +1,8 @@ -# TODO: Make the check source feature optional as an argument on *_add_library. 
-set(LLVM_OPTIONAL_SOURCES - Async.cpp - AsyncPasses.cpp - GPU.cpp - GPUPasses.cpp - Linalg.cpp - LinalgPasses.cpp - SCF.cpp - Shape.cpp - Standard.cpp - Tensor.cpp -) - add_mlir_public_c_api_library(MLIRCAPIAsync Async.cpp AsyncPasses.cpp + PARTIAL_SOURCES_INTENDED DEPENDS MLIRAsyncPassIncGen @@ -30,6 +17,7 @@ add_mlir_public_c_api_library(MLIRCAPIGPU GPU.cpp GPUPasses.cpp + PARTIAL_SOURCES_INTENDED DEPENDS MLIRGPUPassIncGen @@ -43,6 +31,7 @@ add_mlir_public_c_api_library(MLIRCAPILinalg Linalg.cpp LinalgPasses.cpp + PARTIAL_SOURCES_INTENDED DEPENDS MLIRLinalgPassIncGen @@ -56,6 +45,7 @@ add_mlir_public_c_api_library(MLIRCAPILinalg add_mlir_public_c_api_library(MLIRCAPISCF SCF.cpp + PARTIAL_SOURCES_INTENDED LINK_LIBS PUBLIC MLIRCAPIIR MLIRSCF @@ -64,14 +54,25 @@ add_mlir_public_c_api_library(MLIRCAPISCF add_mlir_public_c_api_library(MLIRCAPIShape Shape.cpp + PARTIAL_SOURCES_INTENDED LINK_LIBS PUBLIC MLIRCAPIIR MLIRShape ) +add_mlir_public_c_api_library(MLIRCAPISparseTensor + SparseTensor.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRSparseTensor +) + add_mlir_public_c_api_library(MLIRCAPIStandard Standard.cpp + PARTIAL_SOURCES_INTENDED LINK_LIBS PUBLIC MLIRCAPIIR MLIRStandard @@ -80,6 +81,7 @@ add_mlir_public_c_api_library(MLIRCAPIStandard add_mlir_public_c_api_library(MLIRCAPITensor Tensor.cpp + PARTIAL_SOURCES_INTENDED LINK_LIBS PUBLIC MLIRCAPIIR MLIRTensor diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp new file mode 100644 index 000000000..f35c14af2 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -0,0 +1,71 @@ +//===- Tensor.cpp - C API for SparseTensor dialect ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/SparseTensor.h" +#include "mlir-c/IR.h" +#include "mlir/CAPI/AffineMap.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" +#include "mlir/Support/LLVM.h" + +using namespace llvm; +using namespace mlir::sparse_tensor; + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor, + mlir::sparse_tensor::SparseTensorDialect) + +// Ensure the C-API enums are int-castable to C++ equivalents. +static_assert( + static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) == + static_cast(SparseTensorEncodingAttr::DimLevelType::Dense) && + static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) == + static_cast( + SparseTensorEncodingAttr::DimLevelType::Compressed) && + static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) == + static_cast(SparseTensorEncodingAttr::DimLevelType::Singleton), + "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch"); + +bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { + return unwrap(attr).isa(); +} + +MlirAttribute mlirSparseTensorEncodingAttrGet( + MlirContext ctx, intptr_t numDimLevelTypes, + MlirSparseTensorDimLevelType const *dimLevelTypes, + MlirAffineMap dimOrdering, int pointerBitWidth, int indexBitWidth) { + SmallVector cppDimLevelTypes; + cppDimLevelTypes.resize(numDimLevelTypes); + for (intptr_t i = 0; i < numDimLevelTypes; ++i) + cppDimLevelTypes[i] = + static_cast(dimLevelTypes[i]); + return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppDimLevelTypes, + unwrap(dimOrdering), + pointerBitWidth, indexBitWidth)); +} + +MlirAffineMap mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr) { + return wrap(unwrap(attr).cast().getDimOrdering()); +} + +intptr_t mlirSparseTensorEncodingGetNumDimLevelTypes(MlirAttribute attr) { + return unwrap(attr).cast().getDimLevelType().size(); +} + 
+MlirSparseTensorDimLevelType +mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t pos) { + return static_cast( + unwrap(attr).cast().getDimLevelType()[pos]); +} + +int mlirSparseTensorEncodingAttrGetPointerBitWidth(MlirAttribute attr) { + return unwrap(attr).cast().getPointerBitWidth(); +} + +int mlirSparseTensorEncodingAttrGetIndexBitWidth(MlirAttribute attr) { + return unwrap(attr).cast().getIndexBitWidth(); +} From f164f876043ab9721e14b14c7ea002dfe3283f9a Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sun, 9 May 2021 18:09:09 -0700 Subject: [PATCH 039/915] [mlir][Python] Upstream the PybindAdaptors.h helpers and use it to implement sparse_tensor.encoding. * The PybindAdaptors.h file has been evolving across different sub-projects (npcomp, circt) and has been successfully used for out of tree python API interop/extensions and defining custom types. * Since sparse_tensor.encoding is the first in-tree custom attribute we are supporting, it seemed like the right time to upstream this header and use it to define the attribute in a way that we can support for both in-tree and out-of-tree use (prior, I had not wanted to upstream dead code which was not used in-tree). * Adapted the circt version of `mlir_type_subclass`, also providing an `mlir_attribute_subclass`. As we get a bit of mileage on this, I would like to transition the builtin types/attributes to this mechanism and delete the old in-tree only `PyConcreteType` and `PyConcreteAttribute` template helpers (which cannot work reliably out of tree as they depend on internals). * Added support for defaulting the MlirContext if none is passed so that we can support the same idioms as in-tree versions. There is quite a bit going on here and I can split it up if needed, but would prefer to keep the first use and the header together so sending out in one patch. 
Differential Revision: https://reviews.llvm.org/D102144 --- .../mlir/Bindings/Python/PybindAdaptors.h | 428 ++++++++++++++++++ mlir/lib/Bindings/Python/CMakeLists.txt | 1 + mlir/lib/Bindings/Python/DialectLinalg.cpp | 12 +- mlir/lib/Bindings/Python/DialectLinalg.h | 22 - .../Bindings/Python/DialectSparseTensor.cpp | 74 +++ mlir/lib/Bindings/Python/Dialects.h | 24 + mlir/lib/Bindings/Python/MainModule.cpp | 6 +- 7 files changed, 535 insertions(+), 32 deletions(-) create mode 100644 mlir/include/mlir/Bindings/Python/PybindAdaptors.h delete mode 100644 mlir/lib/Bindings/Python/DialectLinalg.h create mode 100644 mlir/lib/Bindings/Python/DialectSparseTensor.cpp create mode 100644 mlir/lib/Bindings/Python/Dialects.h diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h new file mode 100644 index 000000000..db8769d3c --- /dev/null +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -0,0 +1,428 @@ +//===- PybindAdaptors.h - Adaptors for interop with MLIR APIs -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This file contains adaptors for clients of the core MLIR Python APIs to +// interop via MLIR CAPI types. The facilities here do not depend on +// implementation details of the MLIR Python API and do not introduce C++-level +// dependencies with it (requiring only Python and CAPI-level dependencies). +// +// It is encouraged to be used both in-tree and out-of-tree. For in-tree use +// cases, it should be used for dialect implementations (versus relying on +// Pybind-based internals of the core libraries). 
+//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_PYBIND_ADAPTORS_H +#define MLIR_BINDINGS_PYTHON_PYBIND_ADAPTORS_H + +#include +#include +#include + +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/IR.h" + +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/Twine.h" + +namespace py = pybind11; + +// TODO: Move this to Interop.h and make it externally configurable/use it +// consistently to locate the "import mlir" top-level. +#define MLIR_PYTHON_PACKAGE_PREFIX "mlir." + +// Raw CAPI type casters need to be declared before use, so always include them +// first. +namespace pybind11 { +namespace detail { + +template +struct type_caster> : optional_caster> {}; + +/// Helper to convert a presumed MLIR API object to a capsule, accepting either +/// an explicit Capsule (which can happen when two C APIs are communicating +/// directly via Python) or indirectly by querying the MLIR_PYTHON_CAPI_PTR_ATTR +/// attribute (through which supported MLIR Python API objects export their +/// contained API pointer as a capsule). This is intended to be used from +/// type casters, which are invoked with a raw handle (unowned). The returned +/// object's lifetime may not extend beyond the apiObject handle without +/// explicitly having its refcount increased (i.e. on return). +static py::object mlirApiObjectToCapsule(py::handle apiObject) { + if (PyCapsule_CheckExact(apiObject.ptr())) + return py::reinterpret_borrow(apiObject); + return apiObject.attr(MLIR_PYTHON_CAPI_PTR_ATTR); +} + +// Note: Currently all of the following support cast from py::object to the +// Mlir* C-API type, but only a few light-weight, context-bound ones +// implicitly cast the other way because the use case has not yet emerged and +// ownership is unclear. + +/// Casts object <-> MlirAffineMap. 
+template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirAffineMap, _("MlirAffineMap")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToAffineMap(capsule.ptr()); + if (mlirAffineMapIsNull(value)) { + return false; + } + return !mlirAffineMapIsNull(value); + } + static handle cast(MlirAffineMap v, return_value_policy, handle) { + py::object capsule = + py::reinterpret_steal(mlirPythonAffineMapToCapsule(v)); + return py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir") + .attr("AffineMap") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + +/// Casts object <-> MlirAttribute. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirAttribute, _("MlirAttribute")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToAttribute(capsule.ptr()); + if (mlirAttributeIsNull(value)) { + return false; + } + return true; + } + static handle cast(MlirAttribute v, return_value_policy, handle) { + py::object capsule = + py::reinterpret_steal(mlirPythonAttributeToCapsule(v)); + return py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir") + .attr("Attribute") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + +/// Casts object -> MlirContext. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirContext, _("MlirContext")); + bool load(handle src, bool) { + if (src.is_none()) { + // Gets the current thread-bound context. + // TODO: This raises an error of "No current context" currently. + // Update the implementation to pretty-print the helpful error that the + // core implementations print in this case. 
+ src = py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir") + .attr("Context") + .attr("current"); + } + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToContext(capsule.ptr()); + if (mlirContextIsNull(value)) { + return false; + } + return true; + } +}; + +/// Casts object <-> MlirLocation. +// TODO: Coerce None to default MlirLocation. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirLocation, _("MlirLocation")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToLocation(capsule.ptr()); + if (mlirLocationIsNull(value)) { + return false; + } + return true; + } + static handle cast(MlirLocation v, return_value_policy, handle) { + py::object capsule = + py::reinterpret_steal(mlirPythonLocationToCapsule(v)); + return py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir") + .attr("Location") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + +/// Casts object <-> MlirModule. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirModule, _("MlirModule")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToModule(capsule.ptr()); + if (mlirModuleIsNull(value)) { + return false; + } + return true; + } + static handle cast(MlirModule v, return_value_policy, handle) { + py::object capsule = + py::reinterpret_steal(mlirPythonModuleToCapsule(v)); + return py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir") + .attr("Module") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object <-> MlirOperation. 
+template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirOperation, _("MlirOperation")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToOperation(capsule.ptr()); + if (mlirOperationIsNull(value)) { + return false; + } + return true; + } + static handle cast(MlirOperation v, return_value_policy, handle) { + if (v.ptr == nullptr) + return py::none(); + py::object capsule = + py::reinterpret_steal(mlirPythonOperationToCapsule(v)); + return py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir") + .attr("Operation") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object -> MlirPassManager. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirPassManager, _("MlirPassManager")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToPassManager(capsule.ptr()); + if (mlirPassManagerIsNull(value)) { + return false; + } + return true; + } +}; + +/// Casts object <-> MlirType. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirType, _("MlirType")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToType(capsule.ptr()); + if (mlirTypeIsNull(value)) { + return false; + } + return true; + } + static handle cast(MlirType t, return_value_policy, handle) { + py::object capsule = + py::reinterpret_steal(mlirPythonTypeToCapsule(t)); + return py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir") + .attr("Type") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + +} // namespace detail +} // namespace pybind11 + +namespace mlir { +namespace python { +namespace adaptors { + +/// Provides a facility like py::class_ for defining a new class in a scope, +/// but this allows extension of an arbitrary Python class, defining methods +/// on it is a similar way. 
Classes defined in this way are very similar to +/// if defined in Python in the usual way but use Pybind11 machinery to do +/// it. These are not "real" Pybind11 classes but pure Python classes with no +/// relation to a concrete C++ class. +/// +/// Derived from a discussion upstream: +/// https://github.com/pybind/pybind11/issues/1193 +/// (plus a fair amount of extra curricular poking) +/// TODO: If this proves useful, see about including it in pybind11. +class pure_subclass { +public: + pure_subclass(py::handle scope, const char *derivedClassName, + py::object superClass) { + py::object pyType = + py::reinterpret_borrow((PyObject *)&PyType_Type); + py::object metaclass = pyType(superClass); + py::dict attributes; + + thisClass = + metaclass(derivedClassName, py::make_tuple(superClass), attributes); + scope.attr(derivedClassName) = thisClass; + } + + template + pure_subclass &def(const char *name, Func &&f, const Extra &...extra) { + py::cpp_function cf( + std::forward(f), py::name(name), py::is_method(py::none()), + py::sibling(py::getattr(thisClass, name, py::none())), extra...); + thisClass.attr(cf.name()) = cf; + return *this; + } + + template + pure_subclass &def_property_readonly(const char *name, Func &&f, + const Extra &...extra) { + py::cpp_function cf( + std::forward(f), py::name(name), py::is_method(py::none()), + py::sibling(py::getattr(thisClass, name, py::none())), extra...); + auto builtinProperty = + py::reinterpret_borrow((PyObject *)&PyProperty_Type); + thisClass.attr(name) = builtinProperty(cf); + return *this; + } + + template + pure_subclass &def_staticmethod(const char *name, Func &&f, + const Extra &...extra) { + static_assert(!std::is_member_function_pointer::value, + "def_staticmethod(...) 
called with a non-static member " + "function pointer"); + py::cpp_function cf( + std::forward(f), py::name(name), py::scope(thisClass), + py::sibling(py::getattr(thisClass, name, py::none())), extra...); + thisClass.attr(cf.name()) = py::staticmethod(cf); + return *this; + } + + template + pure_subclass &def_classmethod(const char *name, Func &&f, + const Extra &...extra) { + static_assert(!std::is_member_function_pointer::value, + "def_classmethod(...) called with a non-static member " + "function pointer"); + py::cpp_function cf( + std::forward(f), py::name(name), py::scope(thisClass), + py::sibling(py::getattr(thisClass, name, py::none())), extra...); + thisClass.attr(cf.name()) = + py::reinterpret_borrow(PyClassMethod_New(cf.ptr())); + return *this; + } + +protected: + py::object superClass; + py::object thisClass; +}; + +/// Creates a custom subclass of mlir.ir.Attribute, implementing a casting +/// constructor and type checking methods. +class mlir_attribute_subclass : public pure_subclass { +public: + using IsAFunctionTy = bool (*)(MlirAttribute); + + /// Subclasses by looking up the super-class dynamically. + mlir_attribute_subclass(py::handle scope, const char *attrClassName, + IsAFunctionTy isaFunction) + : mlir_attribute_subclass( + scope, attrClassName, isaFunction, + py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir") + .attr("Attribute")) {} + + /// Subclasses with a provided mlir.ir.Attribute super-class. This must + /// be used if the subclass is being defined in the same extension module + /// as the mlir.ir class (otherwise, it will trigger a recursive + /// initialization). + mlir_attribute_subclass(py::handle scope, const char *typeClassName, + IsAFunctionTy isaFunction, py::object superClass) + : pure_subclass(scope, typeClassName, superClass) { + // Casting constructor. 
Note that defining an __init__ method is special + // and not yet generalized on pure_subclass (it requires a somewhat + // different cpp_function and other requirements on chaining to super + // __init__ make it more awkward to do generally). + std::string captureTypeName( + typeClassName); // As string in case if typeClassName is not static. + py::cpp_function initCf( + [superClass, isaFunction, captureTypeName](py::object self, + py::object otherType) { + MlirAttribute rawAttribute = py::cast(otherType); + if (!isaFunction(rawAttribute)) { + auto origRepr = py::repr(otherType).cast(); + throw std::invalid_argument( + (llvm::Twine("Cannot cast attribute to ") + captureTypeName + + " (from " + origRepr + ")") + .str()); + } + superClass.attr("__init__")(self, otherType); + }, + py::arg("cast_from_type"), py::is_method(py::none()), + "Casts the passed type to this specific sub-type."); + thisClass.attr("__init__") = initCf; + + // 'isinstance' method. + def_staticmethod( + "isinstance", + [isaFunction](MlirAttribute other) { return isaFunction(other); }, + py::arg("other_attribute")); + } +}; + +/// Creates a custom subclass of mlir.ir.Type, implementing a casting +/// constructor and type checking methods. +class mlir_type_subclass : public pure_subclass { +public: + using IsAFunctionTy = bool (*)(MlirType); + + /// Subclasses by looking up the super-class dynamically. + mlir_type_subclass(py::handle scope, const char *typeClassName, + IsAFunctionTy isaFunction) + : mlir_type_subclass( + scope, typeClassName, isaFunction, + py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir").attr("Type")) {} + + /// Subclasses with a provided mlir.ir.Type super-class. This must + /// be used if the subclass is being defined in the same extension module + /// as the mlir.ir class (otherwise, it will trigger a recursive + /// initialization). 
+ mlir_type_subclass(py::handle scope, const char *typeClassName, + IsAFunctionTy isaFunction, py::object superClass) + : pure_subclass(scope, typeClassName, superClass) { + // Casting constructor. Note that defining an __init__ method is special + // and not yet generalized on pure_subclass (it requires a somewhat + // different cpp_function and other requirements on chaining to super + // __init__ make it more awkward to do generally). + std::string captureTypeName( + typeClassName); // As string in case if typeClassName is not static. + py::cpp_function initCf( + [superClass, isaFunction, captureTypeName](py::object self, + py::object otherType) { + MlirType rawType = py::cast(otherType); + if (!isaFunction(rawType)) { + auto origRepr = py::repr(otherType).cast(); + throw std::invalid_argument((llvm::Twine("Cannot cast type to ") + + captureTypeName + " (from " + + origRepr + ")") + .str()); + } + superClass.attr("__init__")(self, otherType); + }, + py::arg("cast_from_type"), py::is_method(py::none()), + "Casts the passed type to this specific sub-type."); + thisClass.attr("__init__") = initCf; + + // 'isinstance' method. 
+ def_staticmethod( + "isinstance", + [isaFunction](MlirType other) { return isaFunction(other); }, + py::arg("other_type")); + } +}; + +} // namespace adaptors +} // namespace python +} // namespace mlir + +#endif // MLIR_BINDINGS_PYTHON_PYBIND_ADAPTORS_H diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt index a2e972dc1..7dc1f64b4 100644 --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/CMakeLists.txt @@ -9,6 +9,7 @@ add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir python SOURCES DialectLinalg.cpp + DialectSparseTensor.cpp MainModule.cpp IRAffine.cpp IRAttributes.cpp diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp index 849a0039a..dfac96db7 100644 --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -6,20 +6,19 @@ // //===----------------------------------------------------------------------===// +#include "Dialects.h" #include "IRModule.h" #include "mlir-c/Dialect/Linalg.h" #include "mlir-c/IR.h" -#include +// TODO: Port this to operate only on the public PybindAdaptors.h +#include "PybindUtils.h" namespace py = pybind11; using namespace mlir; using namespace mlir::python; -namespace mlir { -namespace python { - -void populateDialectLinalgSubmodule(py::module &m) { +void mlir::python::populateDialectLinalgSubmodule(py::module m) { m.def( "fill_builtin_region", [](PyDialectDescriptor &dialect, PyOperation &op, py::list captures) { @@ -34,6 +33,3 @@ void populateDialectLinalgSubmodule(py::module &m) { "Fill the region for `op`, which is assumed to be a builtin named Linalg " "op."); } - -} // namespace python -} // namespace mlir diff --git a/mlir/lib/Bindings/Python/DialectLinalg.h b/mlir/lib/Bindings/Python/DialectLinalg.h deleted file mode 100644 index 3735dbf6f..000000000 --- a/mlir/lib/Bindings/Python/DialectLinalg.h +++ /dev/null @@ -1,22 +0,0 @@ -//===- DialectLinalg.h - 
Linalg dialect submodule of pybind module --------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_BINDINGS_PYTHON_DIALECTLINALG_H -#define MLIR_BINDINGS_PYTHON_DIALECTLINALG_H - -#include "PybindUtils.h" - -namespace mlir { -namespace python { - -void populateDialectLinalgSubmodule(pybind11::module &m); - -} // namespace python -} // namespace mlir - -#endif // MLIR_BINDINGS_PYTHON_DIALECTLINALG_H diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp new file mode 100644 index 000000000..faf240e1a --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -0,0 +1,74 @@ +//===- DialectSparseTensor.cpp - 'sparse_tensor' dialect submodule --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Dialects.h" +#include "mlir-c/Dialect/SparseTensor.h" +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" + +namespace py = pybind11; +using namespace llvm; +using namespace mlir; +using namespace mlir::python::adaptors; + +void mlir::python::populateDialectSparseTensorSubmodule( + py::module m, const py::module &irModule) { + auto attributeClass = irModule.attr("Attribute"); + + py::enum_(m, "DimLevelType") + .value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) + .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) + .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON); + + mlir_attribute_subclass(m, "EncodingAttr", + mlirAttributeIsASparseTensorEncodingAttr, + attributeClass) + .def_classmethod( + "get", + [](py::object cls, + std::vector dimLevelTypes, + llvm::Optional dimOrdering, int pointerBitWidth, + int indexBitWidth, MlirContext context) { + return cls(mlirSparseTensorEncodingAttrGet( + context, dimLevelTypes.size(), dimLevelTypes.data(), + dimOrdering ? 
*dimOrdering : MlirAffineMap{nullptr}, + pointerBitWidth, indexBitWidth)); + }, + py::arg("cls"), py::arg("dim_level_types"), py::arg("dim_ordering"), + py::arg("pointer_bit_width"), py::arg("index_bit_width"), + py::arg("context") = py::none(), + "Gets a sparse_tensor.encoding from parameters.") + .def_property_readonly( + "dim_level_types", + [](MlirAttribute self) { + std::vector ret; + for (int i = 0, + e = mlirSparseTensorEncodingGetNumDimLevelTypes(self); + i < e; ++i) + ret.push_back( + mlirSparseTensorEncodingAttrGetDimLevelType(self, i)); + return ret; + }) + .def_property_readonly( + "dim_ordering", + [](MlirAttribute self) -> llvm::Optional { + MlirAffineMap ret = + mlirSparseTensorEncodingAttrGetDimOrdering(self); + if (mlirAffineMapIsNull(ret)) + return {}; + return ret; + }) + .def_property_readonly( + "pointer_bit_width", + [](MlirAttribute self) { + return mlirSparseTensorEncodingAttrGetPointerBitWidth(self); + }) + .def_property_readonly("index_bit_width", [](MlirAttribute self) { + return mlirSparseTensorEncodingAttrGetIndexBitWidth(self); + }); +} diff --git a/mlir/lib/Bindings/Python/Dialects.h b/mlir/lib/Bindings/Python/Dialects.h new file mode 100644 index 000000000..301d53927 --- /dev/null +++ b/mlir/lib/Bindings/Python/Dialects.h @@ -0,0 +1,24 @@ +//===- Dialects.h - Declaration for dialect submodule factories -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_DIALECTS_H +#define MLIR_BINDINGS_PYTHON_DIALECTS_H + +#include + +namespace mlir { +namespace python { + +void populateDialectLinalgSubmodule(pybind11::module m); +void populateDialectSparseTensorSubmodule(pybind11::module m, + const pybind11::module &irModule); + +} // namespace python +} // namespace mlir + +#endif // MLIR_BINDINGS_PYTHON_DIALECTS_H diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 60c282d1d..6e861c2f2 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -10,7 +10,7 @@ #include "PybindUtils.h" -#include "DialectLinalg.h" +#include "Dialects.h" #include "ExecutionEngine.h" #include "Globals.h" #include "IRModule.h" @@ -98,8 +98,10 @@ PYBIND11_MODULE(_mlir, m) { m.def_submodule("execution_engine", "MLIR JIT Execution Engine"); populateExecutionEngineSubmodule(executionEngineModule); - // Define and populate Linalg submodule. + // Define and populate dialect submodules. auto dialectsModule = m.def_submodule("dialects"); auto linalgModule = dialectsModule.def_submodule("linalg"); populateDialectLinalgSubmodule(linalgModule); + populateDialectSparseTensorSubmodule( + dialectsModule.def_submodule("sparse_tensor"), irModule); } From f56ebc6cad8af0b58beee803cbbf5b0d54e374a3 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 10 May 2021 17:42:24 +0000 Subject: [PATCH 040/915] [mlir][Python] Re-export cext sparse_tensor module to the public namespace. * This was left out of the previous commit accidentally. 
Differential Revision: https://reviews.llvm.org/D102183 --- mlir/python/mlir/_cext_loader.py | 5 ++++- mlir/python/mlir/dialects/sparse_tensor.py | 7 +++++++ 2 files changed, 11 insertions(+), 1 deletion(-) create mode 100644 mlir/python/mlir/dialects/sparse_tensor.py diff --git a/mlir/python/mlir/_cext_loader.py b/mlir/python/mlir/_cext_loader.py index 35847efa9..3a9cde380 100644 --- a/mlir/python/mlir/_cext_loader.py +++ b/mlir/python/mlir/_cext_loader.py @@ -45,7 +45,10 @@ def _reexport_cext(cext_module_name, target_module_name): """ import sys target_module = sys.modules[target_module_name] - source_module = getattr(_cext, cext_module_name) + submodule_names = cext_module_name.split(".") + source_module = _cext + for submodule_name in submodule_names: + source_module = getattr(source_module, submodule_name) for attr_name in dir(source_module): if not attr_name.startswith("__"): setattr(target_module, attr_name, getattr(source_module, attr_name)) diff --git a/mlir/python/mlir/dialects/sparse_tensor.py b/mlir/python/mlir/dialects/sparse_tensor.py new file mode 100644 index 000000000..687d1e6cd --- /dev/null +++ b/mlir/python/mlir/dialects/sparse_tensor.py @@ -0,0 +1,7 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._cext_loader import _reexport_cext +_reexport_cext("dialects.sparse_tensor", __name__) +del _reexport_cext From bf360be66d772fb7f336c5f5679f27621ed9cde0 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 10 May 2021 18:03:40 +0000 Subject: [PATCH 041/915] [mlir][Python] Finish adding RankedTensorType support for encoding. 
Differential Revision: https://reviews.llvm.org/D102184 --- mlir/include/mlir-c/BuiltinTypes.h | 4 ++++ mlir/lib/Bindings/Python/IRTypes.cpp | 16 +++++++++++++--- mlir/lib/CAPI/IR/BuiltinTypes.cpp | 4 ++++ 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index 7d45452af..a677d4d36 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -203,6 +203,10 @@ MLIR_CAPI_EXPORTED MlirType mlirRankedTensorTypeGetChecked( MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute encoding); +/// Gets the 'encoding' attribute from the ranked tensor type, returning a null +/// attribute if none. +MLIR_CAPI_EXPORTED MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type); + /// Creates an unranked tensor type with the given element type in the same /// context as the element type. The type is owned by the context. MLIR_CAPI_EXPORTED MlirType mlirUnrankedTensorTypeGet(MlirType elementType); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index b6875c76e..568cca160 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -338,10 +338,11 @@ class PyRankedTensorType c.def_static( "get", [](std::vector shape, PyType &elementType, + llvm::Optional &encodingAttr, DefaultingPyLocation loc) { - MlirAttribute encodingAttr = mlirAttributeGetNull(); MlirType t = mlirRankedTensorTypeGetChecked( - loc, shape.size(), shape.data(), elementType, encodingAttr); + loc, shape.size(), shape.data(), elementType, + encodingAttr ? encodingAttr->get() : mlirAttributeGetNull()); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. 
if (mlirTypeIsNull(t)) { @@ -355,8 +356,17 @@ class PyRankedTensorType } return PyRankedTensorType(elementType.getContext(), t); }, - py::arg("shape"), py::arg("element_type"), py::arg("loc") = py::none(), + py::arg("shape"), py::arg("element_type"), + py::arg("encoding") = py::none(), py::arg("loc") = py::none(), "Create a ranked tensor type"); + c.def_property_readonly( + "encoding", + [](PyRankedTensorType &self) -> llvm::Optional { + MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get()); + if (mlirAttributeIsNull(encoding)) + return llvm::None; + return PyAttribute(self.getContext(), encoding); + }); } }; diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 1e5fa8a32..d978f17b9 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -206,6 +206,10 @@ MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank, unwrap(elementType), unwrap(encoding))); } +MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type) { + return wrap(unwrap(type).cast().getEncoding()); +} + MlirType mlirUnrankedTensorTypeGet(MlirType elementType) { return wrap(UnrankedTensorType::get(unwrap(elementType))); } From 5b54ed08f91c5ae253da7b7490396247a254e9c9 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Mon, 10 May 2021 12:56:15 -0700 Subject: [PATCH 042/915] [mlir][linalg] remove the -now- obsolete sparse support in linalg All glue and clutter in the linalg ops has been replaced by proper sparse tensor type encoding. This code is no longer needed. Thanks to ntv@ for giving us a temporary home in linalg. So long, and thanks for all the fish. 
Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D102098 --- .../mlir/dialects/linalg/opdsl/lang/emitter.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 4a037025d..85c77d52f 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -89,12 +89,10 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, for am in AffineMap.compress_unused_symbols(op_config.indexing_maps, Context.current)]) iterator_types_attr = ArrayAttr.get( [StringAttr.get(s) for s in op_config.iterator_types]) - # TODO: Add support for sparse operands once there is a stable interface. - sparse_attr = None return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, capture_arg_mapping, indexing_maps_attr, - iterator_types_attr, sparse_attr) + iterator_types_attr) def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, @@ -102,7 +100,7 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, outs: Sequence[Value] = (), captures: Sequence[Value] = ()): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - capture_arg_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \ + capture_arg_mapping, indexing_maps_attr, iterator_types_attr = \ prepare_common_structured_op(op_config, *ins, outs = outs, captures=captures) @@ -113,8 +111,7 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, indexing_maps=indexing_maps_attr, iterator_types=iterator_types_attr, doc=None, # TODO: Make optional. - library_call=None, # TODO: Make optional. - sparse=sparse_attr) # TODO: Make optional. + library_call=None) # TODO: Make optional. # Construct the body. 
block_arg_names = _get_tensor_def_names(*in_arg_defs, *out_arg_defs) @@ -141,7 +138,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, outs: Sequence[Value] = (), captures: Sequence[Value] = ()): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - capture_arg_mapping, indexing_maps_attr, iterator_types_attr, sparse_attr = \ + capture_arg_mapping, indexing_maps_attr, iterator_types_attr = \ prepare_common_structured_op(op_config, *ins, outs = outs, captures = captures) @@ -351,8 +348,8 @@ def _get_tensor_def_names( def _add_type_mapping(name: str, type: Type, type_mapping: Dict[str, Type]): if name in type_mapping: if type_mapping[name] != type: - raise ValueError(f"Cannot overwrite type mapping {name} = " - f"{type_mapping[name]} by type {type}") + raise ValueError(f"Cannot overwrite type mapping {name} = " + f"{type_mapping[name]} by type {type}") type_mapping[name] = type def _is_floating_point_type(t: Type) -> bool: From a62aa4b227b138a51191e250bab431edb0ef0661 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Wed, 12 May 2021 14:51:16 -0700 Subject: [PATCH 043/915] [mlir][sparse][capi][python] add sparse tensor passes First set of "boilerplate" to get sparse tensor passes available through CAPI and Python. 
Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D102362 --- mlir/include/mlir-c/Dialect/SparseTensor.h | 2 ++ mlir/lib/Bindings/Python/CMakeLists.txt | 8 ++++++ .../Bindings/Python/SparseTensorPasses.cpp | 22 ++++++++++++++++ mlir/lib/CAPI/Dialect/CMakeLists.txt | 2 ++ mlir/lib/CAPI/Dialect/SparseTensorPasses.cpp | 26 +++++++++++++++++++ mlir/python/mlir/dialects/sparse_tensor.py | 5 ++++ 6 files changed, 65 insertions(+) create mode 100644 mlir/lib/Bindings/Python/SparseTensorPasses.cpp create mode 100644 mlir/lib/CAPI/Dialect/SparseTensorPasses.cpp diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index 2615a1655..16d932c16 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -74,4 +74,6 @@ mlirSparseTensorEncodingAttrGetIndexBitWidth(MlirAttribute attr); } #endif +#include "mlir/Dialect/SparseTensor/Transforms/Passes.capi.h.inc" + #endif // MLIR_C_DIALECT_SPARSE_TENSOR_H diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt index 7dc1f64b4..575b9dbbd 100644 --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/CMakeLists.txt @@ -33,6 +33,14 @@ add_mlir_python_extension(MLIRAsyncPassesBindingsPythonExtension _mlirAsyncPasse ) add_dependencies(MLIRBindingsPythonExtension MLIRAsyncPassesBindingsPythonExtension) +add_mlir_python_extension(MLIRSparseTensorPassesBindingsPythonExtension _mlirSparseTensorPasses + INSTALL_DIR + python + SOURCES + SparseTensorPasses.cpp +) +add_dependencies(MLIRBindingsPythonExtension MLIRSparseTensorPassesBindingsPythonExtension) + add_mlir_python_extension(MLIRGPUPassesBindingsPythonExtension _mlirGPUPasses INSTALL_DIR python diff --git a/mlir/lib/Bindings/Python/SparseTensorPasses.cpp b/mlir/lib/Bindings/Python/SparseTensorPasses.cpp new file mode 100644 index 000000000..2a8e2b802 --- /dev/null +++ 
b/mlir/lib/Bindings/Python/SparseTensorPasses.cpp @@ -0,0 +1,22 @@ +//===- SparseTensorPasses.cpp - Pybind module for the SparseTensor passes -===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/SparseTensor.h" + +#include + +// ----------------------------------------------------------------------------- +// Module initialization. +// ----------------------------------------------------------------------------- + +PYBIND11_MODULE(_mlirSparseTensorPasses, m) { + m.doc() = "MLIR SparseTensor Dialect Passes"; + + // Register all SparseTensor passes on load. + mlirRegisterSparseTensorPasses(); +} diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index 69371fd57..053fce30d 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -62,11 +62,13 @@ add_mlir_public_c_api_library(MLIRCAPIShape add_mlir_public_c_api_library(MLIRCAPISparseTensor SparseTensor.cpp + SparseTensorPasses.cpp PARTIAL_SOURCES_INTENDED LINK_LIBS PUBLIC MLIRCAPIIR MLIRSparseTensor + MLIRSparseTensorTransforms ) add_mlir_public_c_api_library(MLIRCAPIStandard diff --git a/mlir/lib/CAPI/Dialect/SparseTensorPasses.cpp b/mlir/lib/CAPI/Dialect/SparseTensorPasses.cpp new file mode 100644 index 000000000..5b2ba4ca7 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/SparseTensorPasses.cpp @@ -0,0 +1,26 @@ +//===- SparseTensorPasses.cpp - C API for SparseTensor Dialect Passes -----===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/CAPI/Pass.h" +#include "mlir/Dialect/SparseTensor/Transforms/Passes.h" +#include "mlir/Pass/Pass.h" + +// Must include the declarations as they carry important visibility attributes. +#include "mlir/Dialect/SparseTensor/Transforms/Passes.capi.h.inc" + +using namespace mlir; + +#ifdef __cplusplus +extern "C" { +#endif + +#include "mlir/Dialect/SparseTensor/Transforms/Passes.capi.cpp.inc" + +#ifdef __cplusplus +} +#endif diff --git a/mlir/python/mlir/dialects/sparse_tensor.py b/mlir/python/mlir/dialects/sparse_tensor.py index 687d1e6cd..59fd86021 100644 --- a/mlir/python/mlir/dialects/sparse_tensor.py +++ b/mlir/python/mlir/dialects/sparse_tensor.py @@ -3,5 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._cext_loader import _reexport_cext +from .._cext_loader import _load_extension + _reexport_cext("dialects.sparse_tensor", __name__) +_cextSparseTensorPasses = _load_extension("_mlirSparseTensorPasses") + del _reexport_cext +del _load_extension From 3f52df7c6e437ebc2544fcdbc275d0c87810851f Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Wed, 5 May 2021 08:14:31 +0530 Subject: [PATCH 044/915] [MLIR][PYTHON] Provide opt level for ExecutionEngine Python binding Provide an option to specify optimization level when creating an ExecutionEngine via the MLIR JIT Python binding. Not only is the specified optimization level used for code generation, but all LLVM optimization passes at the optimization level are also run prior to machine code generation (akin to the mlir-cpu-runner tool). Default opt level continues to remain at level two (-O2). Contributions in part from Prashant Kumar as well. 
Differential Revision: https://reviews.llvm.org/D102551 --- mlir/include/mlir-c/ExecutionEngine.h | 9 ++++-- mlir/lib/Bindings/Python/ExecutionEngine.cpp | 11 +++++--- .../CAPI/ExecutionEngine/ExecutionEngine.cpp | 28 ++++++++++++++++--- 3 files changed, 37 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h index 4a5d6ad9f..289e8f73d 100644 --- a/mlir/include/mlir-c/ExecutionEngine.h +++ b/mlir/include/mlir-c/ExecutionEngine.h @@ -36,9 +36,12 @@ DEFINE_C_API_STRUCT(MlirExecutionEngine, void); /// expected to be "translatable" to LLVM IR (only contains operations in /// dialects that implement the `LLVMTranslationDialectInterface`). The module /// ownership stays with the client and can be destroyed as soon as the call -/// returns. -/// TODO: figure out options (optimization level, etc.). -MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreate(MlirModule op); +/// returns. `optLevel` is the optimization level to be used for transformation +/// and code generation. LLVM passes at `optLevel` are run before code +/// generation. +/// TODO: figure out other options. +MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreate(MlirModule op, + int optLevel); /// Destroy an ExecutionEngine instance. 
MLIR_CAPI_EXPORTED void mlirExecutionEngineDestroy(MlirExecutionEngine jit); diff --git a/mlir/lib/Bindings/Python/ExecutionEngine.cpp b/mlir/lib/Bindings/Python/ExecutionEngine.cpp index b5c8dde75..38cf6b2ca 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngine.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngine.cpp @@ -59,17 +59,20 @@ void mlir::python::populateExecutionEngineSubmodule(py::module &m) { // Mapping of the top-level PassManager //---------------------------------------------------------------------------- py::class_(m, "ExecutionEngine") - .def(py::init<>([](PyModule &module) { + .def(py::init<>([](PyModule &module, int optLevel) { MlirExecutionEngine executionEngine = - mlirExecutionEngineCreate(module.get()); + mlirExecutionEngineCreate(module.get(), optLevel); if (mlirExecutionEngineIsNull(executionEngine)) throw std::runtime_error( "Failure while creating the ExecutionEngine."); return new PyExecutionEngine(executionEngine); }), + py::arg("module"), py::arg("opt_level") = 2, "Create a new ExecutionEngine instance for the given Module. The " - "module must " - "contain only dialects that can be translated to LLVM.") + "module must contain only dialects that can be translated to LLVM. 
" + "Perform transformations and code generation at the optimization " + "level `opt_level` if specified, or otherwise at the default " + "level of two (-O2).") .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyExecutionEngine::getCapsule) .def("_testing_release", &PyExecutionEngine::release, diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index 36f24ed88..dfde38aee 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -10,22 +10,42 @@ #include "mlir/CAPI/ExecutionEngine.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" +#include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "llvm/ExecutionEngine/Orc/Mangling.h" #include "llvm/Support/TargetSelect.h" using namespace mlir; -extern "C" MlirExecutionEngine mlirExecutionEngineCreate(MlirModule op) { - static bool init_once = [] { +extern "C" MlirExecutionEngine mlirExecutionEngineCreate(MlirModule op, + int optLevel) { + static bool initOnce = [] { llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); return true; }(); - (void)init_once; + (void)initOnce; mlir::registerLLVMDialectTranslation(*unwrap(op)->getContext()); - auto jitOrError = ExecutionEngine::create(unwrap(op)); + + auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); + if (!tmBuilderOrError) { + llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n"; + return MlirExecutionEngine{nullptr}; + } + auto tmOrError = tmBuilderOrError->createTargetMachine(); + if (!tmOrError) { + llvm::errs() << "Failed to create a TargetMachine for the host\n"; + return MlirExecutionEngine{nullptr}; + } + + // Create a transformer to run all LLVM optimization passes at the + // specified optimization level. 
+ auto llvmOptLevel = static_cast(optLevel); + auto transformer = mlir::makeLLVMPassesTransformer( + /*passes=*/{}, llvmOptLevel, /*targetMachine=*/tmOrError->get()); + auto jitOrError = ExecutionEngine::create( + unwrap(op), /*llvmModuleBuilder=*/{}, transformer, llvmOptLevel); if (!jitOrError) { consumeError(jitOrError.takeError()); return MlirExecutionEngine{nullptr}; From e5e2a1ae23005c52d99da21f10a51d0ddc1c5d4e Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 17 May 2021 10:14:02 +0000 Subject: [PATCH 045/915] Add `mlirModuleFromOperation` to C API At the moment `MlirModule`s can be converted to `MlirOperation`s, but not the other way around (at least not without going around the C API). This makes it impossible to e.g. run passes over a `ModuleOp` created through `mlirOperationCreate`. Reviewed By: nicolasvasilache, mehdi_amini Differential Revision: https://reviews.llvm.org/D102497 --- mlir/include/mlir-c/IR.h | 4 ++++ mlir/lib/CAPI/IR/IR.cpp | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 1b243165c..638eea9b8 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -209,6 +209,10 @@ MLIR_CAPI_EXPORTED void mlirModuleDestroy(MlirModule module); /// Views the module as a generic operation. MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module); +/// Views the generic operation as a module. +/// The returned module is null when the input operation was not a ModuleOp. +MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op); + //===----------------------------------------------------------------------===// // Operation state. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 4e2183516..ebabd6899 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -181,6 +181,10 @@ MlirOperation mlirModuleGetOperation(MlirModule module) { return wrap(unwrap(module).getOperation()); } +MlirModule mlirModuleFromOperation(MlirOperation op) { + return wrap(dyn_cast(unwrap(op))); +} + //===----------------------------------------------------------------------===// // Operation state API. //===----------------------------------------------------------------------===// From 2828dd39991508af70cef099902d1d2d55e860ee Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Wed, 19 May 2021 13:10:28 +0000 Subject: [PATCH 046/915] [mir][Python][linalg] Support OpDSL extensions in C++. The patch extends the yaml code generation to support the following new OpDSL constructs: - captures - constants - iteration index accesses - predefined types These changes have been introduced by revision https://reviews.llvm.org/D101364. Differential Revision: https://reviews.llvm.org/D102075 --- .../linalg/opdsl/lang/comprehension.py | 46 ++++++++---- .../mlir/dialects/linalg/opdsl/lang/config.py | 75 ++++++++++--------- .../dialects/linalg/opdsl/lang/emitter.py | 66 ++++++++-------- .../dialects/linalg/opdsl/lang/scalar_expr.py | 34 +++++---- .../linalg/opdsl/ops/core_named_ops.py | 60 +++++++++++---- 5 files changed, 164 insertions(+), 117 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 9b93d33b3..2ac0641a3 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -8,7 +8,7 @@ represent actual op definitions (i.e. YAML). 
""" -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union from mlir import ir as _ir @@ -36,8 +36,8 @@ def _get_all_dim_defs(self) -> Set[DimDef]: results = set() def visit_dim_def(dim_def): - if isinstance(dim_def, DimDef): - results.add(dim_def) + if isinstance(dim_def, DimDef): + results.add(dim_def) def visit_affine_exprs(expr): if isinstance(expr, TensorUse): @@ -52,23 +52,29 @@ def visit_affine_exprs(expr): def collect_uses(self, uses: Set["TensorUse"]): """Collects all TensorUses reachable through this expression.""" + def visit_tensor_use(expr): if isinstance(expr, TensorUse): uses.add(expr) + self.visit_tensor_exprs(visit_tensor_use) def collect_indices(self, indices: Set["index"]): """Collects all index accesses reachable through this expression.""" + def visit_index(expr): if isinstance(expr, index): indices.add(expr) + self.visit_tensor_exprs(visit_index) def collect_captures(self, captures: Set["CaptureDef"]): """Collects all CaptureDefs reachable through this expression.""" + def visit_capture_def(expr): if isinstance(expr, CaptureDef): captures.add(expr) + self.visit_tensor_exprs(visit_capture_def) def __add__(self, rhs: "TensorExpression") -> "TensorExpression": @@ -159,8 +165,8 @@ def attach(self, index: int, tensor_name: str, owner: "LinalgOpDef"): def __getitem__(self, dims) -> TensorUse: assert self.owner, "TensorDef is not attached to an op" - state = AffineBuildState(global_state=self.owner._affine_state, - allow_new_symbols=False) + state = AffineBuildState( + global_state=self.owner._affine_state, allow_new_symbols=False) if not isinstance(dims, tuple): dims = (dims,) # Handle single subscript case. # Special case: (None) is a 0d-scalar use. 
@@ -196,6 +202,7 @@ def __repr__(self): return (f"{self.tensor_name}:TensorDef({output}{repr(self.type_var)}, " f"shape={self.shape})") + class CaptureDef(TensorExpression): """Defines an SSA value captured by the operation. @@ -226,6 +233,7 @@ def to_scalar_expression(self) -> ScalarExpression: def __repr__(self): return (f"{self.capture_name}:CaptureDef({repr(self.type_var)})") + class Comprehension: """Represents a single comprehension.""" @@ -334,23 +342,27 @@ def visit_tensor_exprs(self, callback): def __repr__(self): return f"{repr(self.prim)}({', '.join(repr(a) for a in self.args)})" + class const(TensorExpression): """Returns the given constant floating point or integer value.""" - def __init__(self, type_var: TypeVar, value: Any): - if not isinstance(type_var, TypeVar): - raise ValueError(f"const requires a TypeVar. Got: {repr(type_var)}") - if not (isinstance(value, float) or isinstance(value, int)): - raise ValueError(f"const requires int or float. Got: {type(value)}") - self.type_var = type_var - self.value = value + def __init__(self, value: Any): + with _ir.Context(): + if isinstance(value, float): + self.value = str(_ir.FloatAttr.get_f64(float(value))) + elif isinstance(value, int): + self.value = str( + _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value))) + else: + raise ValueError(f"const requires int or float. Got: {type(value)}") def to_scalar_expression(self) -> ScalarExpression: - return ScalarConst(self.type_var, self.value).expr() + return ScalarConst(self.value).expr() def __repr__(self): return f"const({self.type_var}, {self.value})" + class index(TensorExpression): """Returns the iteration index for a given dimension name. @@ -358,7 +370,7 @@ class index(TensorExpression): domain of the operation. 
""" - def __init__(self, dim : DimDef): + def __init__(self, dim: DimDef): self.dim_def = dim self.dim = -1 @@ -433,7 +445,8 @@ class OpMetadataDef(YAMLObject): """Metadata about the op (generally not behavior impacting).""" yaml_tag = "!LinalgOpMetadata" - def __init__(self, name: str, cpp_class_name: Optional[str], doc: Optional[str]): + def __init__(self, name: str, cpp_class_name: Optional[str], + doc: Optional[str]): self.name = name self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name self.doc = doc @@ -457,7 +470,8 @@ def __init__(self, name: str, cpp_class_name: Optional[str] = None, doc: Optional[str] = None): - self.metadata = OpMetadataDef(name=name, cpp_class_name=cpp_class_name, doc=doc) + self.metadata = OpMetadataDef( + name=name, cpp_class_name=cpp_class_name, doc=doc) self.registered_tensors = dict() # type: Dict[str, TensorDef] self.registered_captures = dict() # type: Dict[str, CaptureDef] self.comprehensions = list() # type: List[Comprehension] diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py index a67d18cc3..9026e2030 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -11,7 +11,7 @@ to helpers on the comprehension objects themselves. 
""" -from typing import Any, Dict, Optional +from typing import Dict, Optional from mlir import ir as _ir @@ -70,6 +70,7 @@ def to_yaml_custom_dict(self): def __repr__(self): return f"Def({self.tensor_def}, shape_map={self.shape_map}, indexing_map={self.indexing_map})" + class CaptureDefConfig(YAMLObject): """Wrapper around a CaptureDef.""" yaml_tag = "LinalgCaptureDef" @@ -113,8 +114,7 @@ def to_yaml_custom_dict(self): class LinalgStructuredOpConfig(YAMLObject): - """Configuration for metadata sufficient to construct a linalg single - contraction named op.""" + """Configuration for metadata sufficient to construct a linalg named op.""" yaml_tag = "!LinalgStructuredOpConfig" @@ -156,8 +156,8 @@ def __init__(self, for cuse in self.uses.values(): cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map) for cdef in self.tensor_args.values(): - cdef.shape_map = self._normalize_affine_map(cdef.shape_map, - with_dims=False) + cdef.shape_map = self._normalize_affine_map( + cdef.shape_map, with_dims=False) # Now for each write use, propagate the indexing maps from the use to the # tensor, ensuring that there are not conflicts. @@ -198,8 +198,8 @@ def __init__(self, for index in collected_indices: if index.dim_def.dimname not in self.affine_state.all_dims: raise ValueError( - f"The dimension {index.dim.dimname} is not part of the iteration " - f"domain {self.affine_state.all_dims}") + f"The dimension {index.dim.dimname} is not part of the iteration " + f"domain {self.affine_state.all_dims}") index.resolve_dimension_name(self.affine_state) # Generate the scalar assignments (used to build a body). 
@@ -210,18 +210,21 @@ def __init__(self, @property def ordered_tensor_args(self) -> Sequence[TensorDefConfig]: - return sorted(self.tensor_args.values(), - key=lambda tdc: tdc.tensor_def.registered_index) + return sorted( + self.tensor_args.values(), + key=lambda tdc: tdc.tensor_def.registered_index) @property def ordered_tensor_uses(self) -> Sequence[TensorUseConfig]: - return sorted(self.uses.values(), - key=lambda tuc: tuc.tensor_use.tensor_def.registered_index) + return sorted( + self.uses.values(), + key=lambda tuc: tuc.tensor_use.tensor_def.registered_index) @property def ordered_capture_args(self) -> Sequence[CaptureDefConfig]: - return sorted(self.capture_args.values(), - key=lambda cdc: cdc.capture_def.registered_index) + return sorted( + self.capture_args.values(), + key=lambda cdc: cdc.capture_def.registered_index) @property def ordered_dims(self) -> Sequence[Tuple[str, int]]: @@ -252,15 +255,14 @@ def add_tensor_arg(self, tensor_def: TensorDef): if tensor_def in self.tensor_args: return with self.context: - local_state = AffineBuildState(global_state=self.affine_state, - allow_new_dims=False) + local_state = AffineBuildState( + global_state=self.affine_state, allow_new_dims=False) exprs = [] for expr in tensor_def.shape: exprs.append(expr.build(state=local_state)) assert local_state.local_dim_count == 0 - indexing_map = _ir.AffineMap.get(dim_count=0, - symbol_count=local_state.symbol_count, - exprs=exprs) + indexing_map = _ir.AffineMap.get( + dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs) def_config = TensorDefConfig(tensor_def, indexing_map) self.tensor_args[tensor_def] = def_config @@ -269,15 +271,16 @@ def add_use(self, tensor_use: TensorUse): if tensor_use in self.uses: return with self.context: - local_state = AffineBuildState(global_state=self.affine_state, - allow_new_symbols=False) + local_state = AffineBuildState( + global_state=self.affine_state, allow_new_symbols=False) exprs = [] for expr in tensor_use.indices: 
exprs.append(expr.build(state=local_state)) assert local_state.local_symbol_count == 0 - indexing_map = _ir.AffineMap.get(dim_count=local_state.dim_count, - symbol_count=local_state.symbol_count, - exprs=exprs) + indexing_map = _ir.AffineMap.get( + dim_count=local_state.dim_count, + symbol_count=local_state.symbol_count, + exprs=exprs) use_config = TensorUseConfig(tensor_use, indexing_map) self.uses[tensor_use] = use_config @@ -299,16 +302,15 @@ def _normalize_affine_map(self, exprs=list(affine_map.results)) def to_yaml_custom_dict(self): - self_dict = dict( - args=self.ordered_tensor_args, - captures=self.ordered_capture_args, - # TODO: Refactor the hierarchy internally when supporting more - # than static (preserving this serialized form). - indexing_maps=LinalgIndexingMapsConfig( - static_indexing_maps=self.indexing_maps), - iterator_types=self.iterator_types, - assignments=self.assignments, - ) + self_dict = dict(args=self.ordered_tensor_args) + if self.ordered_capture_args: + self_dict["captures"] = self.ordered_capture_args + # TODO: Refactor the hierarchy internally when supporting more + # than static (preserving this serialized form). 
+ self_dict["indexing_maps"] = LinalgIndexingMapsConfig( + static_indexing_maps=self.indexing_maps) + self_dict["iterator_types"] = self.iterator_types + self_dict["assignments"] = self.assignments return self_dict def __repr__(self): @@ -359,9 +361,10 @@ def from_linalg_op_def( assert len( tc_op_def.comprehensions) == 1, "Only one comprehension supported" return [ - LinalgOpConfig(tc_op_def.metadata, - structured_op=LinalgStructuredOpConfig( - tc_op_def.comprehensions[0], context)), + LinalgOpConfig( + tc_op_def.metadata, + structured_op=LinalgStructuredOpConfig(tc_op_def.comprehensions[0], + context)), ] def __repr__(self): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 85c77d52f..5538a9e42 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -2,7 +2,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Any, Dict, Sequence +from typing import Dict, Sequence from mlir.ir import * from mlir.dialects import linalg @@ -19,16 +19,17 @@ "emit_named_structured_op", ] -def isa(cls : Type, ty : Type): + +def isa(cls: Type, ty: Type): try: cls(ty) return True except ValueError: return False + def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, - *ins: Value, - outs: Sequence[Value], + *ins: Value, outs: Sequence[Value], captures: Sequence[Value]): all_arg_defs = op_config.ordered_tensor_args in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "input"] @@ -82,11 +83,13 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, # Emit the generic op. # TODO: Support emission of pure memref form. - indexing_maps_attr = ArrayAttr.get( - [AffineMapAttr.get(am) - # TODO: linalg verification does not currently allow symbols. - # Compress them for now. 
- for am in AffineMap.compress_unused_symbols(op_config.indexing_maps, Context.current)]) + indexing_maps_attr = ArrayAttr.get([ + AffineMapAttr.get(am) + # TODO: linalg verification does not currently allow symbols. + # Compress them for now. + for am in AffineMap.compress_unused_symbols(op_config.indexing_maps, + Context.current) + ]) iterator_types_attr = ArrayAttr.get( [StringAttr.get(s) for s in op_config.iterator_types]) @@ -144,7 +147,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, # If we get here, there must exist a builtin class `op_class_name`. ctx = Context.current - fully_qualified_name = 'linalg.' + op_name + fully_qualified_name = "linalg." + op_name if (not ctx.is_registered_operation(fully_qualified_name) or not op_class_name in linalg.__dict__.keys()): raise NotImplementedError( @@ -156,7 +159,8 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, # Note: mlir-linalg-ods-yaml-gen.cpp uses a special linalg.memoized_indexing_maps # attribute that the non-yaml path does not. The non-yaml path hardcodes the # indexing_maps in C++ directly. - named_op.operation.attributes["linalg.memoized_indexing_maps"] = indexing_maps_attr + named_op.operation.attributes[ + "linalg.memoized_indexing_maps"] = indexing_maps_attr # iterator_types are hardcoded in C++ both in the yaml and non-yaml path. 
if len(result_types) == 1: @@ -168,8 +172,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, class _BodyBuilder: """Constructs a structured op body by evaluating assignments.""" - def __init__(self, - type_mapping: Dict[str, Type], + def __init__(self, type_mapping: Dict[str, Type], block_arg_mapping: Dict[str, Value], capture_arg_mapping: Dict[str, Value]): self.type_mapping = type_mapping @@ -195,12 +198,16 @@ def expression(self, expr: ScalarExpression) -> Value: try: return self.capture_arg_mapping[expr.scalar_capture.capture] except KeyError: - raise ValueError(f"Capture {expr.scalar_capture.capture} is not bound for " - f"this structured op.") + raise ValueError( + f"Capture {expr.scalar_capture.capture} is not bound for " + f"this structured op.") elif expr.scalar_const: - return self.constant(expr.scalar_const.type_var.name, expr.scalar_const.value) + value_attr = Attribute.parse(expr.scalar_const.value) + return std.ConstantOp(value_attr.type, value_attr).result elif expr.scalar_index: - return self.index(expr.scalar_index.dim) + dim_attr = IntegerAttr.get( + IntegerType.get_signless(64), expr.scalar_index.dim) + return linalg.IndexOp(IndexType.get(), dim_attr).result elif expr.scalar_apply: try: fn = getattr(self, f"_eval_{expr.scalar_apply.fn_name}") @@ -217,25 +224,6 @@ def expression(self, expr: ScalarExpression) -> Value: return self.cast(expr.symbolic_cast.to_type.name, operand_value) raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") - def constant(self, type_var_name: str, value: Any) -> Value: - try: - type = self.type_mapping[type_var_name] - except KeyError: - raise ValueError(f"Unbound type variable '{type_var_name}' (" - f"expected one of {self.type_mappings.keys()}") - try: - if(_is_floating_point_type(type)): - return std.ConstantOp(type, FloatAttr.get(type, float(value))).result - elif(_is_integer_type(type)): - return std.ConstantOp(type, IntegerAttr.get(type, int(value))).result - except 
ValueError: - raise ValueError(f"Unable to cast value {value} to type {type}") - raise NotImplementedError(f"Unimplemented constant type {type}") - - def index(self, dim: int) -> Value: - dim_attr = IntegerAttr.get(IntegerType.get_signless(64), dim) - return linalg.IndexOp(IndexType.get(), dim_attr).result - def cast(self, type_var_name: str, operand: Value) -> Value: try: to_type = self.type_mapping[type_var_name] @@ -248,6 +236,7 @@ def cast(self, type_var_name: str, operand: Value) -> Value: return self._cast_to_integer(to_type, operand) elif _is_floating_point_type(to_type): return self._cast_to_floating_point(to_type, operand) + def _cast_to_integer(self, to_type: Type, operand: Value) -> Value: to_width = IntegerType(to_type).width operand_type = operand.type @@ -345,6 +334,7 @@ def _get_tensor_def_names( *tensor_def_configs: TensorDefConfig) -> Sequence[str]: return [tdc.tensor_def.tensor_name for tdc in tensor_def_configs] + def _add_type_mapping(name: str, type: Type, type_mapping: Dict[str, Type]): if name in type_mapping: if type_mapping[name] != type: @@ -352,18 +342,22 @@ def _add_type_mapping(name: str, type: Type, type_mapping: Dict[str, Type]): f"{type_mapping[name]} by type {type}") type_mapping[name] = type + def _is_floating_point_type(t: Type) -> bool: # TODO: Create a FloatType in the Python API and implement the switch # there. return (F64Type.isinstance(t) or F32Type.isinstance(t) or F16Type.isinstance(t) or BF16Type.isinstance(t)) + def _is_integer_type(t: Type) -> bool: return IntegerType.isinstance(t) + def _is_index_type(t: Type) -> bool: return IndexType.isinstance(t) + def _get_floating_point_width(t: Type) -> int: # TODO: Create a FloatType in the Python API and implement the switch # there. 
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py index bb1938d71..2cc426b62 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py @@ -13,7 +13,7 @@ can be easily consumed from the C++ side, not necessarily for ergonomics. """ -from typing import Any, Optional, Sequence +from typing import Optional, Sequence from .yaml_helper import * from .types import * @@ -56,6 +56,7 @@ def expr(self) -> "ScalarExpression": def __repr__(self): return f"(ScalarArg({self.arg})" + class ScalarCapture: """A type of ScalarExpression that references a named capture.""" @@ -68,23 +69,24 @@ def expr(self) -> "ScalarExpression": def __repr__(self): return f"(ScalarCapture({self.capture})" + class ScalarConst: """A type of ScalarExpression representing a constant.""" - def __init__(self, type_var: TypeVar, value: Any): - self.type_var = type_var + def __init__(self, value: str): self.value = value def expr(self) -> "ScalarExpression": return ScalarExpression(scalar_const=self) def __repr__(self): - return f"(ScalarConst({self.type_var}, {self.value})" + return f"(ScalarConst({self.value})" + class ScalarIndex: """A type of ScalarExpression accessing an iteration index.""" - def __init__(self, dim : int): + def __init__(self, dim: int): self.dim = dim def expr(self) -> "ScalarExpression": @@ -93,9 +95,9 @@ def expr(self) -> "ScalarExpression": def __repr__(self): return f"(ScalarIndex({self.dim})" + class ScalarSymbolicCast: - """A type of ScalarExpression that symbolically casts an operand to a TypeVar. 
- """ + """A type of ScalarExpression that symbolically casts an operand to a TypeVar.""" def __init__(self, to_type: TypeVar, operand: "ScalarExpression"): self.to_type = to_type @@ -142,25 +144,27 @@ def __init__(self, def to_yaml_custom_dict(self): if self.scalar_apply: - return dict(scalar_apply=dict( - fn_name=self.scalar_apply.fn_name, - operands=list(self.scalar_apply.operands), - )) + return dict( + scalar_apply=dict( + fn_name=self.scalar_apply.fn_name, + operands=list(self.scalar_apply.operands), + )) elif self.scalar_arg: return dict(scalar_arg=self.scalar_arg.arg) elif self.scalar_capture: return dict(scalar_capture=self.scalar_capture.capture) elif self.scalar_const: - return dict(scalar_const=dict(type_var=self.scalar_const.type_var.name, - attributes=[self.scalar_const.value])) + return dict(scalar_const=self.scalar_const.value) elif self.scalar_index: return dict(scalar_index=self.scalar_index.dim) elif self.symbolic_cast: # Note that even though operands must be arity 1, we write it the # same way as for apply because it allows handling code to be more # generic vs having a special form. 
- return dict(symbolic_cast=dict(type_var=self.symbolic_cast.to_type.name, - operands=[self.symbolic_cast.operand])) + return dict( + symbolic_cast=dict( + type_var=self.symbolic_cast.to_type.name, + operands=[self.symbolic_cast.operand])) else: raise ValueError(f"Unexpected ScalarExpression type: {self}") diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index b52a0e2d6..ad7996345 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -7,9 +7,10 @@ @linalg_structured_op -def matmul(A=TensorDef(T1, S.M, S.K), - B=TensorDef(T2, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True)): +def matmul( + A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True)): """Performs a matrix multiplication of two 2D inputs. Numeric casting is performed on the operands to the inner multiply, promoting @@ -20,9 +21,10 @@ def matmul(A=TensorDef(T1, S.M, S.K), @linalg_structured_op -def batch_matmul(A=TensorDef(T1, Batch, S.M, S.K), - B=TensorDef(T2, Batch, S.K, S.N), - C=TensorDef(U, Batch, S.M, S.N, output=True)): +def batch_matmul( + A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K, S.N), + C=TensorDef(U, Batch, S.M, S.N, output=True)): """Performs a batched matrix multiplication of two 3D inputs. Numeric casting is performed on the operands to the inner multiply, promoting @@ -33,9 +35,10 @@ def batch_matmul(A=TensorDef(T1, Batch, S.M, S.K), @linalg_structured_op -def matvec(A=TensorDef(T1, S.M, S.N), - y=TensorDef(T2, S.N), - x=TensorDef(U, S.M, output=True)): +def matvec( + A=TensorDef(T1, S.M, S.N), + y=TensorDef(T2, S.N), + x=TensorDef(U, S.M, output=True)): """Performs a matrix-vector multiplication. 
Numeric casting is performed on the operands to the inner multiply, promoting @@ -46,9 +49,10 @@ def matvec(A=TensorDef(T1, S.M, S.N), @linalg_structured_op -def vecmat(y=TensorDef(T1, S.M), - A=TensorDef(T2, S.M, S.N), - x=TensorDef(U, S.N, output=True)): +def vecmat( + y=TensorDef(T1, S.M), + A=TensorDef(T2, S.M, S.N), + x=TensorDef(U, S.N, output=True)): """Performs a vector-matrix multiplication. Numeric casting is performed on the operands to the inner multiply, promoting @@ -59,8 +63,8 @@ def vecmat(y=TensorDef(T1, S.M), @linalg_structured_op -def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, - output=True)): +def dot( + A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)): """Performs a dot product of two vectors to a scalar result. Numeric casting is performed on the operands to the inner multiply, promoting @@ -68,3 +72,31 @@ def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, """ implements(ContractionOpInterface) C[None] += cast(U, A[D.m]) * cast(U, B[D.m]) + + +@linalg_structured_op +def fill_rng_2d(O=TensorDef(T, S.M, S.N, output=True)): + """Fills the output tensor with pseudo random numbers. + + The operation generations pseudo random numbers using a linear congruential + generator. It provides no guarantees regarding the distribution of the + generated random numbers. Instead of generating the random numbers + sequentially, it instantiates one random number generator per data element + and runs them in parallel. The seed operand and the indices of the data + element seed the random number generation. The min and max operands limit + the range of the generated random numbers. + + Note: The captures are hard-coded till there is capture support on the C++ + side. 
+ """ + min = cast(F64, const(-1000)) + max = cast(F64, const(+1000)) + seed = cast(I32, const(42)) + multiplier = cast(I32, const(1103515245)) + increment = cast(I32, const(12345)) + rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment + rand2 = (cast(I32, index(D.n)) + rand1) * multiplier + increment + inv_range = cast(F64, const(2.3283064e-10)) + offset = cast(F64, const(2147483647)) + scaling = (max - min) * inv_range + O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) * scaling + min) From 80a1a0dd18daa8dfafe40fa2476a99e23c7eda7d Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Wed, 19 May 2021 11:58:42 -0700 Subject: [PATCH 047/915] [mlir][CAPI] Expose [u]int8 DenseElementsAttr. Also, fix a small typo where the "unsigned" splat variants were not being created with an unsigned type. Differential Revision: https://reviews.llvm.org/D102797 --- mlir/include/mlir-c/BuiltinAttributes.h | 16 ++++++++++++ mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 34 +++++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index c85825c8d..247de5cc0 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -313,6 +313,10 @@ mlirDenseElementsAttrSplatGet(MlirType shapedType, MlirAttribute element); MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType, bool element); MLIR_CAPI_EXPORTED MlirAttribute +mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType, uint8_t element); +MLIR_CAPI_EXPORTED MlirAttribute +mlirDenseElementsAttrInt8SplatGet(MlirType shapedType, int8_t element); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType, uint32_t element); MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType, int32_t element); @@ -330,6 +334,10 @@ mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType, double element); /// data element 
type. MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrBoolGet( MlirType shapedType, intptr_t numElements, const int *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt8Get( + MlirType shapedType, intptr_t numElements, const uint8_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt8Get( + MlirType shapedType, intptr_t numElements, const int8_t *elements); MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt32Get( MlirType shapedType, intptr_t numElements, const uint32_t *elements); MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt32Get( @@ -364,6 +372,10 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr); MLIR_CAPI_EXPORTED int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr); +MLIR_CAPI_EXPORTED int8_t +mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr); +MLIR_CAPI_EXPORTED uint8_t +mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr); MLIR_CAPI_EXPORTED int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr); MLIR_CAPI_EXPORTED uint32_t @@ -383,6 +395,10 @@ mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr); /// contained by the given dense elements attribute. 
MLIR_CAPI_EXPORTED bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, + intptr_t pos); +MLIR_CAPI_EXPORTED uint8_t +mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED uint32_t diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 7580786de..93a6eff99 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -341,6 +341,16 @@ MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType, return wrap( DenseElementsAttr::get(unwrap(shapedType).cast(), element)); } +MlirAttribute mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType, + uint8_t element) { + return wrap( + DenseElementsAttr::get(unwrap(shapedType).cast(), element)); +} +MlirAttribute mlirDenseElementsAttrInt8SplatGet(MlirType shapedType, + int8_t element) { + return wrap( + DenseElementsAttr::get(unwrap(shapedType).cast(), element)); +} MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType, uint32_t element) { return wrap( @@ -390,6 +400,16 @@ static MlirAttribute getDenseAttribute(MlirType shapedType, llvm::makeArrayRef(elements, numElements))); } +MlirAttribute mlirDenseElementsAttrUInt8Get(MlirType shapedType, + intptr_t numElements, + const uint8_t *elements) { + return getDenseAttribute(shapedType, numElements, elements); +} +MlirAttribute mlirDenseElementsAttrInt8Get(MlirType shapedType, + intptr_t numElements, + const int8_t *elements) { + return getDenseAttribute(shapedType, numElements, elements); +} MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType, intptr_t numElements, const uint32_t *elements) { @@ -452,6 +472,12 @@ MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) { int 
mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) { return unwrap(attr).cast().getSplatValue(); } +int8_t mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr) { + return unwrap(attr).cast().getSplatValue(); +} +uint8_t mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr) { + return unwrap(attr).cast().getSplatValue(); +} int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) { return unwrap(attr).cast().getSplatValue(); } @@ -482,6 +508,14 @@ bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { return *(unwrap(attr).cast().getValues().begin() + pos); } +int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) { + return *(unwrap(attr).cast().getValues().begin() + + pos); +} +uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) { + return *(unwrap(attr).cast().getValues().begin() + + pos); +} int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) { return *(unwrap(attr).cast().getValues().begin() + pos); From 35a67d5adc37c2ded81deeadaef2450a103d5a04 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 20 May 2021 17:51:53 +0900 Subject: [PATCH 048/915] [mlir] Add Python bindings for vector dialect Also add a minimal test case for vector.print. 
Differential Revision: https://reviews.llvm.org/D102826 --- mlir/python/mlir/dialects/CMakeLists.txt | 5 +++++ mlir/python/mlir/dialects/VectorOps.td | 15 +++++++++++++++ mlir/python/mlir/dialects/vector.py | 5 +++++ 3 files changed, 25 insertions(+) create mode 100644 mlir/python/mlir/dialects/VectorOps.td create mode 100644 mlir/python/mlir/dialects/vector.py diff --git a/mlir/python/mlir/dialects/CMakeLists.txt b/mlir/python/mlir/dialects/CMakeLists.txt index 31a4ee55b..cad3bb710 100644 --- a/mlir/python/mlir/dialects/CMakeLists.txt +++ b/mlir/python/mlir/dialects/CMakeLists.txt @@ -45,6 +45,11 @@ add_mlir_dialect_python_bindings(MLIRBindingsPythonTensorOps DIALECT_NAME tensor) add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonTensorOps) +add_mlir_dialect_python_bindings(MLIRBindingsPythonVectorOps + TD_FILE VectorOps.td + DIALECT_NAME vector) +add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonVectorOps) + ################################################################################ # Installation. ################################################################################ diff --git a/mlir/python/mlir/dialects/VectorOps.td b/mlir/python/mlir/dialects/VectorOps.td new file mode 100644 index 000000000..b06668bdf --- /dev/null +++ b/mlir/python/mlir/dialects/VectorOps.td @@ -0,0 +1,15 @@ +//===-- VectorOps.td - Entry point for VectorOps bind ------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_VECTOR_OPS +#define PYTHON_BINDINGS_VECTOR_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/Vector/VectorOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/vector.py b/mlir/python/mlir/dialects/vector.py new file mode 100644 index 000000000..610c0b204 --- /dev/null +++ b/mlir/python/mlir/dialects/vector.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._vector_ops_gen import * From 93e8d92aefa83f43397f8c2e7c245a37daef0f2f Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 20 May 2021 15:05:05 +0000 Subject: [PATCH 049/915] [mlir][Linalg] NFC - Drop Linalg EDSC usage Drop the Linalg dialect EDSC subdirectory and update all uses. 
Differential Revision: https://reviews.llvm.org/D102848 --- mlir/lib/CAPI/Dialect/Linalg.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 6f6e090d7..21e4e2ce8 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -8,7 +8,6 @@ #include "mlir-c/Dialect/Linalg.h" #include "mlir/CAPI/Registration.h" -#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" using namespace mlir; @@ -38,12 +37,11 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, for (auto t : linalgOp.getShapedOperandTypes()) argTypes.push_back(getElementTypeOrSelf(t)); - OpBuilder b(op->getContext()); + ImplicitLocOpBuilder b(op->getLoc(), op->getContext()); Region ®ion = op->getRegion(0); Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes); b.setInsertionPointToStart(body); - mlir::edsc::ScopedContext scope(b, op->getLoc()); - fun(*body, captures); + fun(b, *body, captures); } MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect) From dfed13646804634082cafedd3b1ad4eb44aa97ff Mon Sep 17 00:00:00 2001 From: John Demme Date: Sun, 23 May 2021 20:37:55 -0700 Subject: [PATCH 050/915] [MLIR] [Python] Add Operation.parent Attribute to get the parent operation of an operation. 
Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D102981 --- mlir/lib/Bindings/Python/IRCore.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index d11edb1c6..3f08522b9 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2113,6 +2113,10 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("successors") = py::none(), py::arg("regions") = 0, py::arg("loc") = py::none(), py::arg("ip") = py::none(), kOperationCreateDocstring) + .def_property_readonly("parent", + [](PyOperation &self) { + return self.getParentOperation().getObject(); + }) .def("erase", &PyOperation::erase) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule) From 928ecd4e141e85f26f897643ca3fc2ddb27598b9 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 24 May 2021 16:41:38 +0000 Subject: [PATCH 051/915] Enable MLIR Python bindings for TOSA. 
Differential Revision: https://reviews.llvm.org/D103035 --- mlir/python/mlir/dialects/CMakeLists.txt | 5 +++++ mlir/python/mlir/dialects/TosaOps.td | 15 +++++++++++++++ mlir/python/mlir/dialects/tosa.py | 5 +++++ 3 files changed, 25 insertions(+) create mode 100644 mlir/python/mlir/dialects/TosaOps.td create mode 100644 mlir/python/mlir/dialects/tosa.py diff --git a/mlir/python/mlir/dialects/CMakeLists.txt b/mlir/python/mlir/dialects/CMakeLists.txt index cad3bb710..5eeb6d628 100644 --- a/mlir/python/mlir/dialects/CMakeLists.txt +++ b/mlir/python/mlir/dialects/CMakeLists.txt @@ -45,6 +45,11 @@ add_mlir_dialect_python_bindings(MLIRBindingsPythonTensorOps DIALECT_NAME tensor) add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonTensorOps) +add_mlir_dialect_python_bindings(MLIRBindingsPythonTosaOps + TD_FILE TosaOps.td + DIALECT_NAME tosa) +add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonTosaOps) + add_mlir_dialect_python_bindings(MLIRBindingsPythonVectorOps TD_FILE VectorOps.td DIALECT_NAME vector) diff --git a/mlir/python/mlir/dialects/TosaOps.td b/mlir/python/mlir/dialects/TosaOps.td new file mode 100644 index 000000000..d906bad7c --- /dev/null +++ b/mlir/python/mlir/dialects/TosaOps.td @@ -0,0 +1,15 @@ +//===-- TosaOps.td - Entry point for TosaOps bind ----------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_TOSA_OPS +#define PYTHON_BINDINGS_TOSA_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/Tosa/IR/TosaOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/tosa.py b/mlir/python/mlir/dialects/tosa.py new file mode 100644 index 000000000..aebda742f --- /dev/null +++ b/mlir/python/mlir/dialects/tosa.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._tosa_ops_gen import * From 8b129ea2b1053fc003bd986f79f283a1c92e4c93 Mon Sep 17 00:00:00 2001 From: George Date: Mon, 24 May 2021 11:52:41 -0700 Subject: [PATCH 052/915] Surface clone APIs in CAPI Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D102987 --- mlir/include/mlir-c/IR.h | 4 ++++ mlir/lib/CAPI/IR/IR.cpp | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 638eea9b8..b0866385c 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -326,6 +326,10 @@ mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags); /// - Result type inference is enabled and cannot be performed. MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreate(MlirOperationState *state); +/// Creates a deep copy of an operation. The operation is not inserted and +/// ownership is transferred to the caller. +MLIR_CAPI_EXPORTED MlirOperation mlirOperationClone(MlirOperation op); + /// Takes an operation owned by the caller and destroys it. 
MLIR_CAPI_EXPORTED void mlirOperationDestroy(MlirOperation op); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index ebabd6899..2721efde3 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -313,6 +313,10 @@ MlirOperation mlirOperationCreate(MlirOperationState *state) { return result; } +MlirOperation mlirOperationClone(MlirOperation op) { + return wrap(unwrap(op)->clone()); +} + void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } bool mlirOperationEqual(MlirOperation op, MlirOperation other) { From 88640bd6a245f772d15323fbbac0f9343981c05d Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Wed, 26 May 2021 12:44:33 -0700 Subject: [PATCH 053/915] [mlir][python] Provide "all passes" registration module in Python Currently, passes are registered on a per-dialect basis, which provides the smallest footprint obviously. But for prototyping and experimentation, a convenience "all passes" module is provided, which registers all known MLIR passes in one run. Usage in Python: import mlir.all_passes_registration Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D103130 --- mlir/include/mlir-c/Registration.h | 3 +++ .../Bindings/Python/AllPassesRegistration.cpp | 22 +++++++++++++++++++ mlir/lib/Bindings/Python/CMakeLists.txt | 10 ++++++++- mlir/lib/CAPI/Registration/Registration.cpp | 3 +++ .../mlir/all_passes_registration/__init__.py | 8 +++++++ 5 files changed, 45 insertions(+), 1 deletion(-) create mode 100644 mlir/lib/Bindings/Python/AllPassesRegistration.cpp create mode 100644 mlir/python/mlir/all_passes_registration/__init__.py diff --git a/mlir/include/mlir-c/Registration.h b/mlir/include/mlir-c/Registration.h index 4cfc96719..e8604329f 100644 --- a/mlir/include/mlir-c/Registration.h +++ b/mlir/include/mlir-c/Registration.h @@ -60,6 +60,9 @@ MLIR_CAPI_EXPORTED void mlirRegisterAllDialects(MlirContext context); /// Register all translations to LLVM IR for dialects that can support it. 
MLIR_CAPI_EXPORTED void mlirRegisterAllLLVMTranslations(MlirContext context); +/// Register all compiler passes of MLIR. +MLIR_CAPI_EXPORTED void mlirRegisterAllPasses(); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/Bindings/Python/AllPassesRegistration.cpp b/mlir/lib/Bindings/Python/AllPassesRegistration.cpp new file mode 100644 index 000000000..f595b20ba --- /dev/null +++ b/mlir/lib/Bindings/Python/AllPassesRegistration.cpp @@ -0,0 +1,22 @@ +//===- AllPassesRegistration.cpp - Pybind module to register all passes ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Registration.h" + +#include + +// ----------------------------------------------------------------------------- +// Module initialization. +// ----------------------------------------------------------------------------- + +PYBIND11_MODULE(_mlirAllPassesRegistration, m) { + m.doc() = "MLIR All Passes Convenience Module"; + + // Register all passes on load. 
+ mlirRegisterAllPasses(); +} diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt index 575b9dbbd..173cf48c0 100644 --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/CMakeLists.txt @@ -25,6 +25,14 @@ add_dependencies(MLIRBindingsPythonExtension MLIRCoreBindingsPythonExtension) add_subdirectory(Transforms) add_subdirectory(Conversions) +add_mlir_python_extension(MLIRAllPassesRegistrationBindingsPythonExtension _mlirAllPassesRegistration + INSTALL_DIR + python + SOURCES + AllPassesRegistration.cpp +) +add_dependencies(MLIRBindingsPythonExtension MLIRAllPassesRegistrationBindingsPythonExtension) + add_mlir_python_extension(MLIRAsyncPassesBindingsPythonExtension _mlirAsyncPasses INSTALL_DIR python @@ -37,7 +45,7 @@ add_mlir_python_extension(MLIRSparseTensorPassesBindingsPythonExtension _mlirSpa INSTALL_DIR python SOURCES - SparseTensorPasses.cpp + SparseTensorPasses.cpp ) add_dependencies(MLIRBindingsPythonExtension MLIRSparseTensorPassesBindingsPythonExtension) diff --git a/mlir/lib/CAPI/Registration/Registration.cpp b/mlir/lib/CAPI/Registration/Registration.cpp index dea782453..4ac300d1f 100644 --- a/mlir/lib/CAPI/Registration/Registration.cpp +++ b/mlir/lib/CAPI/Registration/Registration.cpp @@ -10,6 +10,7 @@ #include "mlir/CAPI/IR.h" #include "mlir/InitAllDialects.h" +#include "mlir/InitAllPasses.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" void mlirRegisterAllDialects(MlirContext context) { @@ -21,3 +22,5 @@ void mlirRegisterAllDialects(MlirContext context) { void mlirRegisterAllLLVMTranslations(MlirContext context) { mlir::registerLLVMDialectTranslation(*unwrap(context)); } + +void mlirRegisterAllPasses() { mlir::registerAllPasses(); } diff --git a/mlir/python/mlir/all_passes_registration/__init__.py b/mlir/python/mlir/all_passes_registration/__init__.py new file mode 100644 index 000000000..cf3367cfe --- /dev/null +++ 
b/mlir/python/mlir/all_passes_registration/__init__.py @@ -0,0 +1,8 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._cext_loader import _load_extension + +_cextAllPasses = _load_extension("_mlirAllPassesRegistration") +del _load_extension From 94cdd541e59bf09a73f57edee456d070d54e5ba4 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Thu, 27 May 2021 13:33:32 -0700 Subject: [PATCH 054/915] [mlir][capi] fix build issue with "all passes" registration Some builds exposed missing dependences on trafo/conv passes. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D103283 --- mlir/lib/CAPI/Registration/CMakeLists.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mlir/lib/CAPI/Registration/CMakeLists.txt b/mlir/lib/CAPI/Registration/CMakeLists.txt index 417140ac0..b4a8650b3 100644 --- a/mlir/lib/CAPI/Registration/CMakeLists.txt +++ b/mlir/lib/CAPI/Registration/CMakeLists.txt @@ -1,5 +1,7 @@ # Dialect registration. get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) add_mlir_public_c_api_library(MLIRCAPIRegistration Registration.cpp @@ -7,4 +9,6 @@ add_mlir_public_c_api_library(MLIRCAPIRegistration MLIRCAPIIR MLIRLLVMToLLVMIRTranslation ${dialect_libs} + ${translation_libs} + ${conversion_libs} ) From 411fa6cf803d81c33d23357a8f18ccb8a81171c0 Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Thu, 3 Jun 2021 15:23:14 +0000 Subject: [PATCH 055/915] [mlir][linalg] Cleanup LinalgOp usage in capi. Replace the uses of deprecated Structured Op Interface methods in Linalg.cpp. This patch is based on https://reviews.llvm.org/D103394. 
Differential Revision: https://reviews.llvm.org/D103619 --- mlir/lib/CAPI/Dialect/Linalg.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 21e4e2ce8..be0d54488 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -34,8 +34,8 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, SmallVector argTypes; auto linalgOp = cast(op); - for (auto t : linalgOp.getShapedOperandTypes()) - argTypes.push_back(getElementTypeOrSelf(t)); + for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) + argTypes.push_back(getElementTypeOrSelf(opOperand->get().getType())); ImplicitLocOpBuilder b(op->getLoc(), op->getContext()); Region ®ion = op->getRegion(0); From b54cf1423d3d70d727b70fd391c7786cf0a21257 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Sat, 5 Jun 2021 11:38:31 -0700 Subject: [PATCH 056/915] [Core] Add Twine support for StringAttr and Identifier. NFC. This is both more efficient and more ergonomic than going through an std::string, e.g. when using llvm::utostr and in string concat cases. Unfortunately we can't just overload ::get(). This causes an ambiguity because both twine and stringref implicitly convert from std::string. 
Differential Revision: https://reviews.llvm.org/D103754 --- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 93a6eff99..3ec1b73c0 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -70,9 +70,8 @@ MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, SmallVector attributes; attributes.reserve(numElements); for (intptr_t i = 0; i < numElements; ++i) - attributes.emplace_back( - Identifier::get(unwrap(elements[i].name), unwrap(ctx)), - unwrap(elements[i].attribute)); + attributes.emplace_back(unwrap(elements[i].name), + unwrap(elements[i].attribute)); return wrap(DictionaryAttr::get(unwrap(ctx), attributes)); } From cb20614736c4c1c9d18e5e8789c585ab3267d01c Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 10 Jun 2021 19:00:34 +0200 Subject: [PATCH 057/915] [mlir] Provide minimal Python bindings for the math dialect Reviewed By: ulysseB Differential Revision: https://reviews.llvm.org/D104045 --- mlir/python/mlir/dialects/CMakeLists.txt | 5 +++++ mlir/python/mlir/dialects/MathOps.td | 15 +++++++++++++++ mlir/python/mlir/dialects/math.py | 5 +++++ 3 files changed, 25 insertions(+) create mode 100644 mlir/python/mlir/dialects/MathOps.td create mode 100644 mlir/python/mlir/dialects/math.py diff --git a/mlir/python/mlir/dialects/CMakeLists.txt b/mlir/python/mlir/dialects/CMakeLists.txt index 5eeb6d628..3c0434475 100644 --- a/mlir/python/mlir/dialects/CMakeLists.txt +++ b/mlir/python/mlir/dialects/CMakeLists.txt @@ -25,6 +25,11 @@ add_mlir_dialect_python_bindings(MLIRBindingsPythonLinalgOps DEPENDS LinalgOdsGen) add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonLinalgOps) +add_mlir_dialect_python_bindings(MLIRBindingsPythonMathOps + TD_FILE MathOps.td + DIALECT_NAME math) +add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonMathOps) + 
add_mlir_dialect_python_bindings(MLIRBindingsPythonMemRefOps TD_FILE MemRefOps.td DIALECT_NAME memref) diff --git a/mlir/python/mlir/dialects/MathOps.td b/mlir/python/mlir/dialects/MathOps.td new file mode 100644 index 000000000..03d1fdef0 --- /dev/null +++ b/mlir/python/mlir/dialects/MathOps.td @@ -0,0 +1,15 @@ +//===-- MathOps.td - Entry point for MathOps bindings ------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_MATH_OPS +#define PYTHON_BINDINGS_MATH_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/Math/IR/MathOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/math.py b/mlir/python/mlir/dialects/math.py new file mode 100644 index 000000000..f082bf461 --- /dev/null +++ b/mlir/python/mlir/dialects/math.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._math_ops_gen import * From fa5edf9c21f65fa3cdcff789d88671bd76b2ff40 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Sat, 5 Jun 2021 19:54:34 +0530 Subject: [PATCH 058/915] [MLIR] Execution engine python binding support for shared libraries Add support to Python bindings for the MLIR execution engine to load a specified list of shared libraries - for eg. to use MLIR runtime utility libraries. 
Differential Revision: https://reviews.llvm.org/D104009 --- mlir/include/mlir-c/ExecutionEngine.h | 9 ++++++--- mlir/lib/Bindings/Python/ExecutionEngine.cpp | 14 ++++++++++---- mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp | 14 ++++++++++---- 3 files changed, 26 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h index 289e8f73d..bb454529b 100644 --- a/mlir/include/mlir-c/ExecutionEngine.h +++ b/mlir/include/mlir-c/ExecutionEngine.h @@ -38,10 +38,13 @@ DEFINE_C_API_STRUCT(MlirExecutionEngine, void); /// ownership stays with the client and can be destroyed as soon as the call /// returns. `optLevel` is the optimization level to be used for transformation /// and code generation. LLVM passes at `optLevel` are run before code -/// generation. +/// generation. The number and array of paths corresponding to shared libraries +/// that will be loaded are specified via `numPaths` and `sharedLibPaths` +/// respectively. /// TODO: figure out other options. -MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreate(MlirModule op, - int optLevel); +MLIR_CAPI_EXPORTED MlirExecutionEngine +mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, + const MlirStringRef *sharedLibPaths); /// Destroy an ExecutionEngine instance. 
MLIR_CAPI_EXPORTED void mlirExecutionEngineDestroy(MlirExecutionEngine jit); diff --git a/mlir/lib/Bindings/Python/ExecutionEngine.cpp b/mlir/lib/Bindings/Python/ExecutionEngine.cpp index 38cf6b2ca..089c29507 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngine.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngine.cpp @@ -59,20 +59,26 @@ void mlir::python::populateExecutionEngineSubmodule(py::module &m) { // Mapping of the top-level PassManager //---------------------------------------------------------------------------- py::class_(m, "ExecutionEngine") - .def(py::init<>([](PyModule &module, int optLevel) { - MlirExecutionEngine executionEngine = - mlirExecutionEngineCreate(module.get(), optLevel); + .def(py::init<>([](PyModule &module, int optLevel, + const std::vector &sharedLibPaths) { + llvm::SmallVector libPaths; + for (const std::string &path : sharedLibPaths) + libPaths.push_back({path.c_str(), path.length()}); + MlirExecutionEngine executionEngine = mlirExecutionEngineCreate( + module.get(), optLevel, libPaths.size(), libPaths.data()); if (mlirExecutionEngineIsNull(executionEngine)) throw std::runtime_error( "Failure while creating the ExecutionEngine."); return new PyExecutionEngine(executionEngine); }), py::arg("module"), py::arg("opt_level") = 2, + py::arg("shared_libs") = py::list(), "Create a new ExecutionEngine instance for the given Module. The " "module must contain only dialects that can be translated to LLVM. " "Perform transformations and code generation at the optimization " "level `opt_level` if specified, or otherwise at the default " - "level of two (-O2).") + "level of two (-O2). 
Load a list of libraries specified in " + "`shared_libs`.") .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyExecutionEngine::getCapsule) .def("_testing_release", &PyExecutionEngine::release, diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index dfde38aee..42bacd967 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -17,8 +17,9 @@ using namespace mlir; -extern "C" MlirExecutionEngine mlirExecutionEngineCreate(MlirModule op, - int optLevel) { +extern "C" MlirExecutionEngine +mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, + const MlirStringRef *sharedLibPaths) { static bool initOnce = [] { llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmPrinter(); @@ -39,13 +40,18 @@ extern "C" MlirExecutionEngine mlirExecutionEngineCreate(MlirModule op, return MlirExecutionEngine{nullptr}; } + SmallVector libPaths; + for (unsigned i = 0; i < static_cast(numPaths); ++i) + libPaths.push_back(sharedLibPaths[i].data); + // Create a transformer to run all LLVM optimization passes at the // specified optimization level. auto llvmOptLevel = static_cast(optLevel); auto transformer = mlir::makeLLVMPassesTransformer( /*passes=*/{}, llvmOptLevel, /*targetMachine=*/tmOrError->get()); - auto jitOrError = ExecutionEngine::create( - unwrap(op), /*llvmModuleBuilder=*/{}, transformer, llvmOptLevel); + auto jitOrError = + ExecutionEngine::create(unwrap(op), /*llvmModuleBuilder=*/{}, transformer, + llvmOptLevel, libPaths); if (!jitOrError) { consumeError(jitOrError.takeError()); return MlirExecutionEngine{nullptr}; From 74f784377b2c017a7100146cfe6f83c23f38e44f Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Tue, 15 Jun 2021 08:35:10 +0000 Subject: [PATCH 059/915] [mlir][linalg][python] Adapt the OpDSL to use scalars. 
The patch replaces the existing capture functionality by scalar operands that have been introduced by https://reviews.llvm.org/D104109. Scalar operands behave as tensor operands except for the fact that they are not indexed. As a result ScalarDefs can be accessed directly as no indexing expression is needed. The patch only updates the OpDSL. The C++ side is updated by a follow up patch. Differential Revision: https://reviews.llvm.org/D104220 --- .../linalg/opdsl/lang/comprehension.py | 178 ++++++++--------- .../mlir/dialects/linalg/opdsl/lang/config.py | 188 ++++++++---------- .../mlir/dialects/linalg/opdsl/lang/dsl.py | 33 ++- .../dialects/linalg/opdsl/lang/emitter.py | 105 ++++------ .../dialects/linalg/opdsl/lang/scalar_expr.py | 28 +-- .../linalg/opdsl/ops/core_named_ops.py | 12 +- 6 files changed, 234 insertions(+), 310 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 2ac0641a3..fe067d694 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -8,7 +8,7 @@ represent actual op definitions (i.e. YAML). 
""" -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Set, Tuple from mlir import ir as _ir @@ -50,7 +50,7 @@ def visit_affine_exprs(expr): self.visit_tensor_exprs(visit_affine_exprs) return results - def collect_uses(self, uses: Set["TensorUse"]): + def collect_tensor_uses(self, uses: Set["TensorUse"]): """Collects all TensorUses reachable through this expression.""" def visit_tensor_use(expr): @@ -68,14 +68,14 @@ def visit_index(expr): self.visit_tensor_exprs(visit_index) - def collect_captures(self, captures: Set["CaptureDef"]): - """Collects all CaptureDefs reachable through this expression.""" + def collect_scalar_uses(self, uses: Set["ScalarDef"]): + """Collects all ScalarDefs reachable through this expression.""" - def visit_capture_def(expr): - if isinstance(expr, CaptureDef): - captures.add(expr) + def visit_scalar_def(expr): + if isinstance(expr, ScalarDef): + uses.add(expr) - self.visit_tensor_exprs(visit_capture_def) + self.visit_tensor_exprs(visit_scalar_def) def __add__(self, rhs: "TensorExpression") -> "TensorExpression": return PrimFn.add(self, rhs) @@ -101,19 +101,19 @@ class TensorUse(TensorExpression): TensorDef.__setitem__ """ - def __init__(self, tensor_def: "TensorDef", indices: Sequence[AffineExprDef]): - self.tensor_def = tensor_def + def __init__(self, operand_def: "OperandDef", + indices: Sequence[AffineExprDef]): + self.operand_def = operand_def self.indices = tuple(indices) def to_scalar_expression(self) -> ScalarExpression: - assert self.tensor_def.tensor_name is not None - return ScalarArg(self.tensor_def.tensor_name).expr() + return ScalarArg(self.tensor_name).expr() @property def tensor_name(self) -> str: - n = self.tensor_def.tensor_name - assert n is not None, "TensorDef not attached" - return n + name = self.operand_def.name + assert name is not None, "TensorDef not attached" + return name def __iadd__(self, rhs: TensorExpression) -> 
TensorExpression: return ReduceFn.add(*self._compute_reduce_dims(rhs))(rhs) @@ -133,40 +133,57 @@ def __repr__(self): return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]" -class TensorDef: - """Bookkeeping of a single registered tensor, held in dict by name.""" +class OperandDef: + """Definition of a Tensor or Scalar operand passed to an operation.""" - def __init__(self, - type_var: TypeVar, - *shape: AffineExprDef, - indexing_map: Optional[_ir.AffineMap] = None, - output: bool = False): + def __init__(self, type_var: TypeVar, shape: Sequence[AffineExprDef], + scalar: bool, output: bool): if not isinstance(type_var, TypeVar): - raise ValueError(f"TensorDef requires a TypeVar. Got: {repr(type_var)}") + raise ValueError(f"OperandDef requires a TypeVar. Got: {repr(type_var)}") self.owner = None # type: Optional["LinalgOpDef"] self.type_var = type_var self.shape = shape - self.indexing_map = indexing_map + self.scalar = scalar self.output = output - self.tensor_name = None # type: Optional[str] + self.name = None # type: Optional[str] self.registered_index = -1 # type: int - @property - def rank(self) -> int: - """The rank of the tensor.""" - return len(self.shape) - - def attach(self, index: int, tensor_name: str, owner: "LinalgOpDef"): + def attach(self, index: int, name: str, owner: "LinalgOpDef"): if self.owner: - raise ValueError(f"TensorDef already registered with op: {self}") + raise ValueError(f"OperandDef already registered with op: {self}") self.registered_index = index - self.tensor_name = tensor_name + self.name = name self.owner = owner + def __hash__(self): + return hash(id(self)) + + def __repr__(self): + output = "OUTPUT " if self.output else "" + scalar = "SCALAR " if self.scalar else "" + return (f"{self.name}:OperandDef({output}{scalar}" + f"{repr(self.type_var)}, shape={self.shape})") + + +class TensorDef: + """Tensor operand definition. 
+ + Tensor operands are indexed using the associated indexing_map when forwarded + to the body of the structured op. A unique name identifies the tensor operands + and an index determines their position in the operation's parameter list. + """ + + def __init__(self, + type_var: TypeVar, + *shape: AffineExprDef, + output: bool = False): + self.operand_def = OperandDef(type_var, shape, False, output) + def __getitem__(self, dims) -> TensorUse: - assert self.owner, "TensorDef is not attached to an op" + assert self.operand_def.owner, "TensorDef is not attached to an op" state = AffineBuildState( - global_state=self.owner._affine_state, allow_new_symbols=False) + global_state=self.operand_def.owner._affine_state, + allow_new_symbols=False) if not isinstance(dims, tuple): dims = (dims,) # Handle single subscript case. # Special case: (None) is a 0d-scalar use. @@ -179,7 +196,7 @@ def __getitem__(self, dims) -> TensorUse: raise KeyError( "A TensorDef can only be subscripted by a tuple of affine dims") exprs.append(expr_def) - return TensorUse(self, exprs) + return TensorUse(self.operand_def, exprs) def __setitem__(self, dims, value): """Creates a new 1:1 comprehension by binding this tensor to an expression. @@ -192,46 +209,28 @@ def __setitem__(self, dims, value): f"Got: {repr(value)}") use = self[dims] comp = Comprehension((use, value)) - self.owner.comprehensions.append(comp) + self.operand_def.owner.comprehensions.append(comp) - def __hash__(self): - return hash(id(self)) - def __repr__(self): - output = "OUTPUT " if self.output else "" - return (f"{self.tensor_name}:TensorDef({output}{repr(self.type_var)}, " - f"shape={self.shape})") - - -class CaptureDef(TensorExpression): - """Defines an SSA value captured by the operation. +class ScalarDef(TensorExpression): + """Scalar operand definition. - The captured SSA values are not indexed by the indexing_maps of the - structured op (as opposed to memrefs and tensors). 
A unique name - identifies the captures and an index determines their position the - operation's parameter list. + Scalar operands are forwarded to the body of the structured op as they are. + A unique name identifies the scalars and an index determines their position in + the operation's parameter list. """ def __init__(self, type_var: TypeVar): - if not isinstance(type_var, TypeVar): - raise ValueError(f"CaptureDef requires a TypeVar. Got: {repr(type_var)}") - self.owner = None # type: Optional["LinalgOpDef"] - self.type_var = type_var - self.capture_name = None # type: Optional[str] - self.registered_index = -1 # type: int + self.operand_def = OperandDef(type_var, (), True, False) - def attach(self, index: int, capture_name: str, owner: "LinalgOpDef"): - if self.owner: - raise ValueError(f"CaptureDef already registered with op: {self}") - self.registered_index = index - self.capture_name = capture_name - self.owner = owner + @property + def scalar_name(self) -> str: + name = self.operand_def.name + assert name is not None, "ScalarDef not attached" + return name def to_scalar_expression(self) -> ScalarExpression: - return ScalarCapture(self.capture_name).expr() - - def __repr__(self): - return (f"{self.capture_name}:CaptureDef({repr(self.type_var)})") + return ScalarArg(self.scalar_name).expr() class Comprehension: @@ -472,43 +471,34 @@ def __init__(self, doc: Optional[str] = None): self.metadata = OpMetadataDef( name=name, cpp_class_name=cpp_class_name, doc=doc) - self.registered_tensors = dict() # type: Dict[str, TensorDef] - self.registered_captures = dict() # type: Dict[str, CaptureDef] + self.registered_operands = dict() # type: Dict[str, OperandDef] self.comprehensions = list() # type: List[Comprehension] self._affine_state = AffineBuildState() @property - def inputs(self) -> Sequence[TensorDef]: - return [t for t in self.registered_tensors.values() if not t.output] + def outputs(self) -> Sequence[OperandDef]: + return [ + operand for operand in 
self.registered_operands.values() + if operand.output + ] - @property - def outputs(self) -> Sequence[TensorDef]: - return [t for t in self.registered_tensors.values() if t.output] - - def add_tensor(self, tensor_name: str, tensor: TensorDef): - """Registers a tensor.""" - if tensor_name in self.registered_tensors: - raise ValueError(f"Tensor {tensor_name} is already registered " - f"to {self.registered_tensors['tensor_name']}") - tensor.attach(len(self.registered_tensors), tensor_name, self) - self.registered_tensors[tensor_name] = tensor - - def add_capture(self, capture_name: str, capture: CaptureDef): - """Registers a capture.""" - if capture_name in self.registered_captures: - raise ValueError(f"Capture {capture_name} is already registered " - f"to {self.registered_captures['capture_name']}") - capture.attach(len(self.registered_captures), capture_name, self) - self.registered_captures[capture_name] = capture + def add_operand(self, name: str, operand: OperandDef): + """Registers an operand.""" + if name in self.registered_operands: + raise ValueError(f"The operand {name} is already registered " + f"to {self.registered_operands['name']}") + if not operand.output and self.outputs: + raise ValueError(f"The operand {name} is an input registered after " + f"the output {self.outputs[-1]}") + operand.attach(len(self.registered_operands), name, self) + self.registered_operands[name] = operand def __repr__(self): lines = [ f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name}," ] - for name, tensor in self.registered_tensors.items(): - lines.append(f" {tensor}") - for name, capture in self.registered_captures.items(): - lines.append(f" {capture}") + for name, operand in self.registered_operands.items(): + lines.append(f" {operand}") if self.comprehensions: lines[-1] += " {" for comprehension in self.comprehensions: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py index 
9026e2030..6dd86334b 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -18,11 +18,7 @@ from .comprehension import * from .yaml_helper import * -__all__ = [ - "LinalgStructuredOpConfig", - "LinalgOpConfig", - "TensorDefConfig", -] +__all__ = ["LinalgStructuredOpConfig", "LinalgOpConfig", "OperandDefConfig"] def _serialize_affine_map(affine_map: _ir.AffineMap) -> str: @@ -43,49 +39,42 @@ def __repr__(self): return f"Use({self.tensor_use}, indexing_map={self.indexing_map})" -class TensorDefConfig(YAMLObject): - """Wrapper around a TensorDef with additional context-bound state.""" - yaml_tag = "LinalgTensorDef" +class OperandDefConfig(YAMLObject): + """Wrapper containing an operand definition with additional state.""" + yaml_tag = "!LinalgOperandDefConfig" - def __init__(self, tensor_def: TensorDef, shape_map: _ir.AffineMap): - self.tensor_def = tensor_def - self.shape_map = shape_map + def __init__(self, + operand_def: OperandDef, + shape_map: Optional[_ir.AffineMap] = None): + self.operand_def = operand_def + self.shape_map = shape_map # type: Optional[_ir.AffineMap] self.indexing_map = None # type: Optional[_ir.AffineMap] @property - def usage(self) -> str: - if self.tensor_def.output: - return "output" - else: - return "input" - - def to_yaml_custom_dict(self): - return dict( - name=self.tensor_def.tensor_name, - usage=self.usage, - shape=_serialize_affine_map(self.shape_map), - element_type_var=self.tensor_def.type_var.name, - ) - - def __repr__(self): - return f"Def({self.tensor_def}, shape_map={self.shape_map}, indexing_map={self.indexing_map})" + def name(self) -> str: + return self.operand_def.name + @property + def type_var(self) -> TypeVar: + return self.operand_def.type_var -class CaptureDefConfig(YAMLObject): - """Wrapper around a CaptureDef.""" - yaml_tag = "LinalgCaptureDef" - - def __init__(self, capture_def: CaptureDef): - self.capture_def = capture_def + @property + def 
usage(self) -> str: + if self.operand_def.output: + return "output" + return "input" def to_yaml_custom_dict(self): - return dict( - name=self.capture_def.capture_name, - type_var=self.capture_def.type_var.name, - ) + self_dict = dict(name=self.name) + self_dict["usage"] = self.usage + if not self.operand_def.scalar: + self_dict["shape"] = _serialize_affine_map(self.shape_map) + self_dict["type_var"] = self.type_var.name + return self_dict def __repr__(self): - return f"Def({self.capture_def})" + return (f"OperandDefConfig({self.operand_def}, " + f"shape_map={self.shape_map}, indexing_map={self.indexing_map})") class LinalgIndexingMapsConfig(YAMLObject): @@ -124,67 +113,73 @@ def __init__(self, self.context = context if context is not None else _ir.Context() self.affine_state = AffineBuildState() self.writes = list() # type: List[Tuple[TensorUse, TensorExpression]] - self.tensor_args = dict() # type: Dict[TensorDef, TensorDefConfig] - self.capture_args = dict() # type: Dict[CaptureDef, CaptureDefConfig] + self.operands = dict() # type: Dict[OperandDef, OperandDefConfig] self.uses = dict() # type: Dict[TensorUse, TensorUseConfig] # Compute the ordered set of writes and collect the tensor, capture, and # index uses. - collected_uses = set() - collected_captures = set() + collected_tensor_uses = set() + collected_scalar_uses = set() collected_indices = set() for write_use, read_use in zip(comprehension.definitions, comprehension.values): self.writes.append((write_use, read_use)) for write_use, read_use in self.writes: - collected_uses.add(write_use) - read_use.collect_uses(collected_uses) - read_use.collect_captures(collected_captures) + collected_tensor_uses.add(write_use) + read_use.collect_tensor_uses(collected_tensor_uses) + read_use.collect_scalar_uses(collected_scalar_uses) read_use.collect_indices(collected_indices) # Need to add all definitions before uses, so process twice. 
- for use in collected_uses: - self.add_tensor_arg(use.tensor_def) - for capture in collected_captures: - self.add_capture_arg(capture) - for use in collected_uses: - self.add_use(use) + for use in collected_tensor_uses: + self.add_operand(use.operand_def) + for use in collected_scalar_uses: + self.add_operand(use.operand_def) + for use in collected_tensor_uses: + self.add_tensor_use(use) # Now normalize all defs and uses indexing maps now that full count of # dims and symbols are known. for cuse in self.uses.values(): cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map) - for cdef in self.tensor_args.values(): - cdef.shape_map = self._normalize_affine_map( - cdef.shape_map, with_dims=False) + for cdef in self.operands.values(): + if not cdef.operand_def.scalar: + cdef.shape_map = self._normalize_affine_map( + cdef.shape_map, with_dims=False) # Now for each write use, propagate the indexing maps from the use to the # tensor, ensuring that there are not conflicts. for write_use, _ in self.writes: - write_tensor_def = self.tensor_args[write_use.tensor_def] - if write_tensor_def.indexing_map: + write_tensor_config = self.operands[write_use.operand_def] + if write_tensor_config.indexing_map: raise ValueError( - f"Unexpected multi-write to a single tensor: {write_tensor_def}") - write_tensor_def.indexing_map = self.uses[write_use].indexing_map + f"Unexpected multi-write to a single tensor: {write_tensor_config}") + write_tensor_config.indexing_map = self.uses[write_use].indexing_map # For each read use, propagate the indexing maps from the use to the # tensor, ensuring that there are not conflicts. 
for _, read_expr in self.writes: read_uses = set() # type: Set[TensorUse] - read_expr.collect_uses(read_uses) + read_expr.collect_tensor_uses(read_uses) for read_use in read_uses: - read_tensor_def = self.tensor_args[read_use.tensor_def] - if (read_tensor_def.indexing_map and - read_tensor_def.indexing_map != self.uses[read_use].indexing_map): + read_operand_config = self.operands[read_use.operand_def] + if (read_operand_config.indexing_map and + read_operand_config.indexing_map != + self.uses[read_use].indexing_map): raise ValueError( f"Unexpected multi-read of a tensor with different accesses:" - f"{read_tensor_def} vs {read_use}") - read_tensor_def.indexing_map = self.uses[read_use].indexing_map + f"{read_operand_config} vs {read_use}") + read_operand_config.indexing_map = self.uses[read_use].indexing_map + + # Set the indexing map of all scalar uses to the empty map. + for operand_config in self.operands.values(): + if operand_config.operand_def.scalar: + operand_config.indexing_map = self._create_empty_affine_map() # Sanity check that all defs have an indexing map. - assert all(d.indexing_map for d in self.tensor_args.values()), ( - f"Missing indexing map on TensorDef: {self.tensor_args}") + assert all(d.indexing_map for d in self.operands.values()), ( + f"Missing indexing map on OperandConfigDef: {self.operands}") # Collect reduction dims and ensure all the same. 
all_reduction_dims = set(comprehension.all_reduction_dims) @@ -209,22 +204,10 @@ def __init__(self, ] @property - def ordered_tensor_args(self) -> Sequence[TensorDefConfig]: + def ordered_operands(self) -> Sequence[OperandDefConfig]: return sorted( - self.tensor_args.values(), - key=lambda tdc: tdc.tensor_def.registered_index) - - @property - def ordered_tensor_uses(self) -> Sequence[TensorUseConfig]: - return sorted( - self.uses.values(), - key=lambda tuc: tuc.tensor_use.tensor_def.registered_index) - - @property - def ordered_capture_args(self) -> Sequence[CaptureDefConfig]: - return sorted( - self.capture_args.values(), - key=lambda cdc: cdc.capture_def.registered_index) + self.operands.values(), + key=lambda operand: operand.operand_def.registered_index) @property def ordered_dims(self) -> Sequence[Tuple[str, int]]: @@ -238,7 +221,7 @@ def ordered_dims(self) -> Sequence[Tuple[str, int]]: @property def indexing_maps(self) -> Sequence[_ir.AffineMap]: - return [use.indexing_map for use in self.ordered_tensor_uses] + return [d.indexing_map for d in self.ordered_operands] @property def iterator_types(self) -> Sequence[str]: @@ -251,23 +234,25 @@ def get_type(symbolic_name, position): return [get_type(*dim) for dim in self.ordered_dims] - def add_tensor_arg(self, tensor_def: TensorDef): - if tensor_def in self.tensor_args: + def add_operand(self, operand_def: OperandDef): + if operand_def in self.operands: + return + if operand_def.scalar: + self.operands[operand_def] = OperandDefConfig(operand_def) return with self.context: local_state = AffineBuildState( global_state=self.affine_state, allow_new_dims=False) exprs = [] - for expr in tensor_def.shape: + for expr in operand_def.shape: exprs.append(expr.build(state=local_state)) assert local_state.local_dim_count == 0 - indexing_map = _ir.AffineMap.get( + shape_map = _ir.AffineMap.get( dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs) + def_config = OperandDefConfig(operand_def, shape_map) + 
self.operands[operand_def] = def_config - def_config = TensorDefConfig(tensor_def, indexing_map) - self.tensor_args[tensor_def] = def_config - - def add_use(self, tensor_use: TensorUse): + def add_tensor_use(self, tensor_use: TensorUse): if tensor_use in self.uses: return with self.context: @@ -285,11 +270,13 @@ def add_use(self, tensor_use: TensorUse): use_config = TensorUseConfig(tensor_use, indexing_map) self.uses[tensor_use] = use_config - def add_capture_arg(self, capture_def: CaptureDef): - if capture_def in self.capture_args: - return - def_config = CaptureDefConfig(capture_def) - self.capture_args[capture_def] = def_config + def _create_empty_affine_map(self) -> _ir.AffineMap: + """Create an affine map with an empty range.""" + with self.context: + return _ir.AffineMap.get( + dim_count=self.affine_state.dim_count, + symbol_count=self.affine_state.symbol_count, + exprs=list()) def _normalize_affine_map(self, affine_map: _ir.AffineMap, @@ -302,9 +289,7 @@ def _normalize_affine_map(self, exprs=list(affine_map.results)) def to_yaml_custom_dict(self): - self_dict = dict(args=self.ordered_tensor_args) - if self.ordered_capture_args: - self_dict["captures"] = self.ordered_capture_args + self_dict = dict(args=self.ordered_operands) # TODO: Refactor the hierarchy internally when supporting more # than static (preserving this serialized form). 
self_dict["indexing_maps"] = LinalgIndexingMapsConfig( @@ -315,11 +300,8 @@ def to_yaml_custom_dict(self): def __repr__(self): lines = [f"LinalgGenericOpConfig(reduction_dims={self.reduction_dims},"] - lines.append("tensor_args=[") - for def_config in self.ordered_tensor_args: - lines.append(f" {repr(def_config)}") - lines.append("], capture_args=[") - for def_config in self.ordered_capture_args: + lines.append("operands=[") + for def_config in self.ordered_operands: lines.append(f" {repr(def_config)}") lines.append("], indexing_maps=[") for m in self.indexing_maps: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py index 428eadfe0..191b1b34f 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -53,8 +53,8 @@ def __call__(self, *args, emit_generic: bool = False, **kwargs): False, a named form is emitted (which must have been built in to the compiler). """ - op_configs = LinalgOpConfig.from_linalg_op_def(self.model, - context=ir.Context.current) + op_configs = LinalgOpConfig.from_linalg_op_def( + self.model, context=ir.Context.current) if len(op_configs) != 1: # TODO: Support composite ops. @@ -63,8 +63,9 @@ def __call__(self, *args, emit_generic: bool = False, **kwargs): ctx = ir.Context.current linalgDialect = ctx.get_dialect_descriptor("linalg") - fully_qualified_name = 'linalg.' + self.op_name - emit_generic = (emit_generic or not ctx.is_registered_operation(fully_qualified_name)) + fully_qualified_name = "linalg." 
+ self.op_name + emit_generic = ( + emit_generic or not ctx.is_registered_operation(fully_qualified_name)) op_config = op_configs[0] if op_config.structured_op: @@ -72,9 +73,9 @@ def __call__(self, *args, emit_generic: bool = False, **kwargs): return emit_generic_structured_op(op_config.structured_op, *args, **kwargs) else: - return emit_named_structured_op( - op_config.structured_op, self.op_name, - self.model.metadata.cpp_class_name, *args, **kwargs) + return emit_named_structured_op(op_config.structured_op, self.op_name, + self.model.metadata.cpp_class_name, + *args, **kwargs) raise NotImplementedError( f"Emission of linalg op type not supported: {op_config}") @@ -86,9 +87,8 @@ def linalg_structured_op(dsl_func=None, op_class_name=None) -> DefinedOpCallable: if dsl_func is None: # Curry the keyword args in for delayed application. - return functools.partial(tc_def_op, - op_name=op_name, - op_class_name=op_class_name) + return functools.partial( + tc_def_op, op_name=op_name, op_class_name=op_class_name) # Determine default names by introspecting the function. if op_name is None: op_name = dsl_func.__name__ @@ -96,9 +96,8 @@ def linalg_structured_op(dsl_func=None, # Camel case it. op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op" - tc_model = LinalgOpDef(name=op_name, - cpp_class_name=op_class_name, - doc=inspect.getdoc(dsl_func)) + tc_model = LinalgOpDef( + name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func)) # Extract arguments and TensorDefs from the signature. 
dsl_func_args = list() @@ -106,12 +105,12 @@ def linalg_structured_op(dsl_func=None, for param_name, param in sig.parameters.items(): param_default = param.default if isinstance(param_default, TensorDef): - tc_model.add_tensor(param_name, param_default) - elif isinstance(param_default, CaptureDef): - tc_model.add_capture(param_name, param_default) + tc_model.add_operand(param_name, param_default.operand_def) + elif isinstance(param_default, ScalarDef): + tc_model.add_operand(param_name, param_default.operand_def) else: raise ValueError(f"@tc_def_op function parameters must be defaulted as " - f"TensorDef(...) or CaptureDef(...): Found {param_name}" + f"TensorDef(...) or ScalarDef(...): Found {param_name}" f": {param_default}") dsl_func_args.append(param_default) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 5538a9e42..2b8b91050 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -29,20 +29,15 @@ def isa(cls: Type, ty: Type): def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, - *ins: Value, outs: Sequence[Value], - captures: Sequence[Value]): - all_arg_defs = op_config.ordered_tensor_args + *ins: Value, outs: Sequence[Value]): + all_arg_defs = op_config.ordered_operands in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "input"] out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "output"] - capture_arg_defs = op_config.ordered_capture_args # Verify outs and captures are sequences. if not isinstance(outs, Sequence): raise ValueError(f"Expected named argument outs to have type Sequence " f"but got {type(outs)}") - if not isinstance(captures, Sequence): - raise ValueError(f"Expected named argument captures to have type Sequence " - f"but got {type(outs)}") # Arity validation. 
if len(ins) != len(in_arg_defs): @@ -51,9 +46,6 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, if outs and len(outs) != len(out_arg_defs): raise ValueError(f"Expected {len(out_arg_defs)} outputs but got " f"{len(outs)} for {op_config}") - if captures and len(captures) != len(capture_arg_defs): - raise ValueError(f"Expected {len(capture_arg_defs)} captures but got " - f"{len(captures)} for {op_config}") outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins, out_arg_defs, outs) @@ -68,18 +60,10 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, type_mapping["I64"] = IntegerType.get_signless(64) # Extract type vars for input/output based types. - for arg_def, arg_element_type in zip( - in_arg_defs + out_arg_defs, - _get_shaped_element_types_from_values(*ins, *outs)): - _add_type_mapping(arg_def.tensor_def.type_var.name, arg_element_type, - type_mapping) - - # Extract type vars for captures and compute capture argument mapping. - capture_arg_mapping = dict() # type: Dict[str, Value] - for arg_def, capture_value in zip(capture_arg_defs, captures): - _add_type_mapping(arg_def.capture_def.type_var.name, capture_value.type, - type_mapping) - capture_arg_mapping[arg_def.capture_def.capture_name] = capture_value + block_arg_types = list() # type: List[Type] + for arg_def, arg_element_type in zip(in_arg_defs + out_arg_defs, + _get_types_from_values(*ins, *outs)): + _add_type_mapping(arg_def, arg_element_type, type_mapping, block_arg_types) # Emit the generic op. # TODO: Support emission of pure memref form. 
@@ -94,18 +78,16 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, [StringAttr.get(s) for s in op_config.iterator_types]) return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, - type_mapping, capture_arg_mapping, indexing_maps_attr, - iterator_types_attr) + type_mapping, indexing_maps_attr, iterator_types_attr, + block_arg_types) def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, - outs: Sequence[Value] = (), - captures: Sequence[Value] = ()): + outs: Sequence[Value] = ()): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - capture_arg_mapping, indexing_maps_attr, iterator_types_attr = \ - prepare_common_structured_op(op_config, *ins, outs = outs, - captures=captures) + indexing_maps_attr, iterator_types_attr, block_arg_types = \ + prepare_common_structured_op(op_config, *ins, outs = outs) generic_op = linalg.GenericOp( result_tensors=result_types, @@ -117,16 +99,14 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, library_call=None) # TODO: Make optional. # Construct the body. 
- block_arg_names = _get_tensor_def_names(*in_arg_defs, *out_arg_defs) - block_arg_types = _get_shaped_element_types_from_values(*ins, *outs) + block_arg_names = _get_operand_def_names(*in_arg_defs, *out_arg_defs) block = generic_op.regions[0].blocks.append(*block_arg_types) block_arg_mapping = dict(zip(block_arg_names, block.arguments)) with InsertionPoint(block): - body_builder = _BodyBuilder(type_mapping, block_arg_mapping, - capture_arg_mapping) + body_builder = _BodyBuilder(type_mapping, block_arg_mapping) for assignment in op_config.assignments: body_builder.assign(assignment) - body_builder.yield_outputs(*_get_tensor_def_names(*out_arg_defs)) + body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs)) if len(result_types) == 1: return generic_op.result @@ -138,12 +118,10 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str, op_class_name: str, *ins: Value, - outs: Sequence[Value] = (), - captures: Sequence[Value] = ()): + outs: Sequence[Value] = ()): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - capture_arg_mapping, indexing_maps_attr, iterator_types_attr = \ - prepare_common_structured_op(op_config, *ins, outs = outs, - captures = captures) + indexing_maps_attr, iterator_types_attr, block_arg_types = \ + prepare_common_structured_op(op_config, *ins, outs = outs) # If we get here, there must exist a builtin class `op_class_name`. 
ctx = Context.current @@ -173,11 +151,9 @@ class _BodyBuilder: """Constructs a structured op body by evaluating assignments.""" def __init__(self, type_mapping: Dict[str, Type], - block_arg_mapping: Dict[str, Value], - capture_arg_mapping: Dict[str, Value]): + block_arg_mapping: Dict[str, Value]): self.type_mapping = type_mapping self.block_arg_mapping = block_arg_mapping - self.capture_arg_mapping = capture_arg_mapping self.yield_mapping = dict() # type: Dict[str, Value] def assign(self, assignment: ScalarAssign): @@ -194,13 +170,6 @@ def expression(self, expr: ScalarExpression) -> Value: except KeyError: raise ValueError(f"Argument {expr.scalar_arg.arg} is not bound for " f"this structured op.") - elif expr.scalar_capture: - try: - return self.capture_arg_mapping[expr.scalar_capture.capture] - except KeyError: - raise ValueError( - f"Capture {expr.scalar_capture.capture} is not bound for " - f"this structured op.") elif expr.scalar_const: value_attr = Attribute.parse(expr.scalar_const.value) return std.ConstantOp(value_attr.type, value_attr).result @@ -229,7 +198,7 @@ def cast(self, type_var_name: str, operand: Value) -> Value: to_type = self.type_mapping[type_var_name] except KeyError: raise ValueError(f"Unbound type variable '{type_var_name}' (" - f"expected one of {self.type_mappings.keys()}") + f"expected one of {self.type_mapping.keys()}") if operand.type == to_type: return operand if _is_integer_type(to_type): @@ -300,9 +269,9 @@ def _eval_mul(self, lhs: Value, rhs: Value) -> Value: def _infer_structured_outs(op_config: LinalgStructuredOpConfig, - in_arg_defs: Sequence[TensorDefConfig], + in_arg_defs: Sequence[OperandDefConfig], ins: Sequence[Value], - out_arg_defs: Sequence[TensorDefConfig], + out_arg_defs: Sequence[OperandDefConfig], outs: Sequence[Value]): """Infers implicit outs and output types. 
@@ -319,28 +288,34 @@ def _infer_structured_outs(op_config: LinalgStructuredOpConfig, "structured ops") -def _get_shaped_element_types_from_values(*values: Value) -> Sequence[Type]: +def _get_types_from_values(*values: Value) -> Sequence[Type]: types = [] for v in values: - try: - t = ShapedType(v.type) - except Exception as e: - raise ValueError(f"Expected ShapedType but got {v}") from e - types.append(t.element_type) + types.append(v.type) return types -def _get_tensor_def_names( - *tensor_def_configs: TensorDefConfig) -> Sequence[str]: - return [tdc.tensor_def.tensor_name for tdc in tensor_def_configs] +def _get_operand_def_names(*operand_configs: OperandDefConfig) -> Sequence[str]: + return [odc.operand_def.name for odc in operand_configs] -def _add_type_mapping(name: str, type: Type, type_mapping: Dict[str, Type]): +def _add_type_mapping(operand_config: OperandDefConfig, operand_type: Type, + type_mapping: Dict[str, Type], + block_arg_types: Sequence[Type]): + element_or_self_type = operand_type + # Get the element type for tensor operands and the type itself for scalars. 
+ if operand_config.operand_def.shape: + try: + element_or_self_type = ShapedType(operand_type).element_type + except Exception as e: + raise ValueError(f"Expected ShapedType but got {operand_type}") from e + name = operand_config.type_var.name if name in type_mapping: - if type_mapping[name] != type: + if type_mapping[name] != element_or_self_type: raise ValueError(f"Cannot overwrite type mapping {name} = " - f"{type_mapping[name]} by type {type}") - type_mapping[name] = type + f"{type_mapping[name]} by type {element_or_self_type}") + type_mapping[name] = element_or_self_type + block_arg_types.append(element_or_self_type) def _is_floating_point_type(t: Type) -> bool: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py index 2cc426b62..48627bfab 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py @@ -22,7 +22,6 @@ "ScalarAssign", "ScalarApplyFn", "ScalarArg", - "ScalarCapture", "ScalarConst", "ScalarIndex", "ScalarExpression", @@ -57,19 +56,6 @@ def __repr__(self): return f"(ScalarArg({self.arg})" -class ScalarCapture: - """A type of ScalarExpression that references a named capture.""" - - def __init__(self, capture: str): - self.capture = capture - - def expr(self) -> "ScalarExpression": - return ScalarExpression(scalar_capture=self) - - def __repr__(self): - return f"(ScalarCapture({self.capture})" - - class ScalarConst: """A type of ScalarExpression representing a constant.""" @@ -116,7 +102,6 @@ class ScalarExpression(YAMLObject): Can be one of: - ScalarApplyFn - ScalarArg - - ScalarCapture - ScalarConst - ScalarIndex - ScalarSymbolicCast @@ -126,18 +111,15 @@ class ScalarExpression(YAMLObject): def __init__(self, scalar_apply: Optional[ScalarApplyFn] = None, scalar_arg: Optional[ScalarArg] = None, - scalar_capture: Optional[ScalarCapture] = None, scalar_const: Optional[ScalarConst] = None, 
scalar_index: Optional[ScalarIndex] = None, symbolic_cast: Optional[ScalarSymbolicCast] = None): - if (bool(scalar_apply) + bool(scalar_arg) + bool(scalar_capture) + - bool(scalar_const) + bool(scalar_index) + bool(symbolic_cast)) != 1: - raise ValueError( - "One of 'scalar_apply', 'scalar_arg', 'scalar_capture', 'scalar_const', " - "'scalar_index', 'symbolic_cast' must be specified") + if (bool(scalar_apply) + bool(scalar_arg) + bool(scalar_const) + + bool(scalar_index) + bool(symbolic_cast)) != 1: + raise ValueError("One of 'scalar_apply', 'scalar_arg', 'scalar_const', " + "'scalar_index', 'symbolic_cast' must be specified") self.scalar_apply = scalar_apply self.scalar_arg = scalar_arg - self.scalar_capture = scalar_capture self.scalar_const = scalar_const self.scalar_index = scalar_index self.symbolic_cast = symbolic_cast @@ -151,8 +133,6 @@ def to_yaml_custom_dict(self): )) elif self.scalar_arg: return dict(scalar_arg=self.scalar_arg.arg) - elif self.scalar_capture: - return dict(scalar_capture=self.scalar_capture.capture) elif self.scalar_const: return dict(scalar_const=self.scalar_const.value) elif self.scalar_index: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index ad7996345..c6586824a 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -75,7 +75,11 @@ def dot( @linalg_structured_op -def fill_rng_2d(O=TensorDef(T, S.M, S.N, output=True)): +def fill_rng_2d( + min=ScalarDef(F64), + max=ScalarDef(F64), + seed=ScalarDef(I32), + O=TensorDef(T, S.M, S.N, output=True)): """Fills the output tensor with pseudo random numbers. The operation generations pseudo random numbers using a linear congruential @@ -85,13 +89,7 @@ def fill_rng_2d(O=TensorDef(T, S.M, S.N, output=True)): and runs them in parallel. 
The seed operand and the indices of the data element seed the random number generation. The min and max operands limit the range of the generated random numbers. - - Note: The captures are hard-coded till there is capture support on the C++ - side. """ - min = cast(F64, const(-1000)) - max = cast(F64, const(+1000)) - seed = cast(I32, const(42)) multiplier = cast(I32, const(1103515245)) increment = cast(I32, const(12345)) rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment From c361814f160601ad4e6aedfe2bf5a57d5d99ccfc Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Tue, 22 Jun 2021 06:27:28 +0000 Subject: [PATCH 060/915] [mlir][linalg] Adapt FillOp to use a scalar operand. Adapt the FillOp definition to use a scalar operand instead of a capture. This patch is a follow up to https://reviews.llvm.org/D104109. As the input operands are in front of the output operands the patch changes the internal operand order of the FillOp. The pretty printed version of the operation remains unchanged though. The patch also adapts the linalg to standard lowering to ensure the c signature of the FillOp remains unchanged as well. 
Differential Revision: https://reviews.llvm.org/D104121 --- mlir/python/mlir/dialects/_linalg_ops_ext.py | 42 ++++++++++---------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py index 0aea4e603..c7ddfb962 100644 --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -10,7 +10,7 @@ from _mlir.dialects.linalg import fill_builtin_region -def isa(cls : Type, ty : Type): +def isa(cls: Type, ty: Type): try: cls(ty) return True @@ -21,23 +21,19 @@ def isa(cls : Type, ty : Type): class FillOp: """Extends the linalg.fill op.""" - def __init__(self, - output: Value, - value: Value, - *, - loc=None, - ip=None): + def __init__(self, output: Value, value: Value, *, loc=None, ip=None): results = [] if isa(RankedTensorType, output.type): results = [output.type] - op = self.build_generic(results=results, - operands=[output, value], - attributes=None, - loc=loc, - ip=ip) + op = self.build_generic( + results=results, + operands=[value, output], + attributes=None, + loc=loc, + ip=ip) OpView.__init__(self, op) linalgDialect = Context.current.get_dialect_descriptor("linalg") - fill_builtin_region(linalgDialect, self.operation, [value]) + fill_builtin_region(linalgDialect, self.operation, []) # TODO: self.result is None. When len(results) == 1 we expect it to be # results[0] as per _linalg_ops_gen.py. 
This seems like an orthogonal bug # in the generator of _linalg_ops_gen.py where we have: @@ -78,11 +74,12 @@ def __init__(self, attributes["static_sizes"] = ArrayAttr.get( [IntegerAttr.get(i64_type, s) for s in static_size_ints], context=context) - op = self.build_generic(results=[result_type], - operands=operands, - attributes=attributes, - loc=loc, - ip=ip) + op = self.build_generic( + results=[result_type], + operands=operands, + attributes=attributes, + loc=loc, + ip=ip) OpView.__init__(self, op) @@ -91,10 +88,11 @@ class StructuredOpMixin: def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None): super().__init__( - self.build_generic(results=list(results), - operands=[list(inputs), list(outputs)], - loc=loc, - ip=ip)) + self.build_generic( + results=list(results), + operands=[list(inputs), list(outputs)], + loc=loc, + ip=ip)) def select_opview_mixin(parent_opview_cls): From 1d5282c4d1c2ed65cb98494179843b7bc3c5bfec Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Thu, 24 Jun 2021 09:21:12 +0000 Subject: [PATCH 061/915] [mlir][linalg][python] Add attribute support to the OpDSL. Extend the OpDSL with index attributes. After tensors and scalars, index attributes are the third operand type. An index attribute represents a compile-time constant that is limited to index expressions. A use cases are the strides and dilations defined by convolution and pooling operations. The patch only updates the OpDSL. The C++ yaml codegen is updated by a followup patch. 
Differential Revision: https://reviews.llvm.org/D104711 --- mlir/include/mlir-c/AffineMap.h | 7 ++ mlir/lib/Bindings/Python/IRAffine.cpp | 8 ++ mlir/lib/CAPI/IR/AffineMap.cpp | 9 ++ .../linalg/opdsl/lang/comprehension.py | 84 ++++++++++++----- .../mlir/dialects/linalg/opdsl/lang/config.py | 94 ++++++++++++------- .../mlir/dialects/linalg/opdsl/lang/dsl.py | 31 +++--- .../dialects/linalg/opdsl/lang/emitter.py | 90 +++++++++++++----- .../linalg/opdsl/ops/core_named_ops.py | 13 +++ 8 files changed, 239 insertions(+), 97 deletions(-) diff --git a/mlir/include/mlir-c/AffineMap.h b/mlir/include/mlir-c/AffineMap.h index e35b7cc6b..7359b9691 100644 --- a/mlir/include/mlir-c/AffineMap.h +++ b/mlir/include/mlir-c/AffineMap.h @@ -169,6 +169,13 @@ mlirAffineMapGetMajorSubMap(MlirAffineMap affineMap, intptr_t numResults); MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapGetMinorSubMap(MlirAffineMap affineMap, intptr_t numResults); +/// Apply AffineExpr::replace(`map`) to each of the results and return a new +/// new AffineMap with the new results and the specified number of dims and +/// symbols. +MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapReplace( + MlirAffineMap affineMap, MlirAffineExpr expression, + MlirAffineExpr replacement, intptr_t numResultDims, intptr_t numResultSyms); + /// Returns the simplified affine map resulting from dropping the symbols that /// do not appear in any of the individual maps in `affineMaps`. 
/// Asserts that all maps in `affineMaps` are normalized to the same number of diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index 5d3b790b3..0a2a5666a 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -654,6 +654,14 @@ void mlir::python::populateIRAffine(py::module &m) { mlirAffineMapGetMinorSubMap(self, nResults); return PyAffineMap(self.getContext(), affineMap); }) + .def("replace", + [](PyAffineMap &self, PyAffineExpr &expression, + PyAffineExpr &replacement, intptr_t numResultDims, + intptr_t numResultSyms) { + MlirAffineMap affineMap = mlirAffineMapReplace( + self, expression, replacement, numResultDims, numResultSyms); + return PyAffineMap(self.getContext(), affineMap); + }) .def_property_readonly( "is_permutation", [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) diff --git a/mlir/lib/CAPI/IR/AffineMap.cpp b/mlir/lib/CAPI/IR/AffineMap.cpp index e0c07afc3..85557bc57 100644 --- a/mlir/lib/CAPI/IR/AffineMap.cpp +++ b/mlir/lib/CAPI/IR/AffineMap.cpp @@ -138,6 +138,15 @@ MlirAffineMap mlirAffineMapGetMinorSubMap(MlirAffineMap affineMap, return wrap(unwrap(affineMap).getMinorSubMap(numResults)); } +MlirAffineMap mlirAffineMapReplace(MlirAffineMap affineMap, + MlirAffineExpr expression, + MlirAffineExpr replacement, + intptr_t numResultDims, + intptr_t numResultSyms) { + return wrap(unwrap(affineMap).replace(unwrap(expression), unwrap(replacement), + numResultDims, numResultSyms)); +} + void mlirAffineMapCompressUnusedSymbols( MlirAffineMap *affineMaps, intptr_t size, void *result, void (*populateResult)(void *res, intptr_t idx, MlirAffineMap m)) { diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index fe067d694..2b2f57248 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -9,6 
+9,7 @@ """ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple +from enum import Enum from mlir import ir as _ir @@ -133,18 +134,31 @@ def __repr__(self): return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]" +class OperandKind(Enum): + InputTensor = 0 + Scalar = 1 + OutputTensor = 2 + Attribute = 3 + + class OperandDef: - """Definition of a Tensor or Scalar operand passed to an operation.""" + """Definition of an operand passed to an operation. + + Keep the meta information of Tensor, Scalar, and Attribute operands and + provide the shared registration functionality. + """ - def __init__(self, type_var: TypeVar, shape: Sequence[AffineExprDef], - scalar: bool, output: bool): + def __init__(self, + kind: OperandKind, + type_var: TypeVar, + size_exprs: Optional[Sequence[AffineExprDef]] = None): if not isinstance(type_var, TypeVar): - raise ValueError(f"OperandDef requires a TypeVar. Got: {repr(type_var)}") + raise ValueError( + f"OperandDef requires a TypeVar but got {repr(type_var)}") self.owner = None # type: Optional["LinalgOpDef"] self.type_var = type_var - self.shape = shape - self.scalar = scalar - self.output = output + self.size_exprs = size_exprs + self.kind = kind self.name = None # type: Optional[str] self.registered_index = -1 # type: int @@ -159,10 +173,8 @@ def __hash__(self): return hash(id(self)) def __repr__(self): - output = "OUTPUT " if self.output else "" - scalar = "SCALAR " if self.scalar else "" - return (f"{self.name}:OperandDef({output}{scalar}" - f"{repr(self.type_var)}, shape={self.shape})") + return (f"{self.name}:OperandDef(kind={self.kind.name}, " + f"type={repr(self.type_var)}, size_exprs={self.size_exprs})") class TensorDef: @@ -170,14 +182,17 @@ class TensorDef: Tensor operands are indexed using the associated indexing_map when forwarded to the body of the structured op. A unique name identifies the tensor operands - and an index determines their position in the operation's parameter list. 
+ and an index determines their position in the operation's parameter list. A + tensor definition takes type, a shape, and an optional flag to mark output + tensors. """ def __init__(self, type_var: TypeVar, *shape: AffineExprDef, output: bool = False): - self.operand_def = OperandDef(type_var, shape, False, output) + kind = OperandKind.OutputTensor if output else OperandKind.InputTensor + self.operand_def = OperandDef(kind, type_var, size_exprs=shape) def __getitem__(self, dims) -> TensorUse: assert self.operand_def.owner, "TensorDef is not attached to an op" @@ -221,7 +236,7 @@ class ScalarDef(TensorExpression): """ def __init__(self, type_var: TypeVar): - self.operand_def = OperandDef(type_var, (), True, False) + self.operand_def = OperandDef(OperandKind.Scalar, type_var) @property def scalar_name(self) -> str: @@ -233,6 +248,22 @@ def to_scalar_expression(self) -> ScalarExpression: return ScalarArg(self.scalar_name).expr() +class AttributeDef: + """Index Attribute definition. + + Index attributes provide a way to define and set symbols that can be used in + indexing expressions. Every attribute specifies a tuple of symbols that at + compile-time are replaced by integer values. + """ + yaml_tag = "!LinalgAttributeDef" + + def __init__(self, *sizes: SymbolDef): + if any(not isinstance(size, SymbolDef) for size in sizes): + raise ValueError(f"AttributeDef requires sizes of type SymbolDef but got " + f"{type(sizes)}") + self.operand_def = OperandDef(OperandKind.Attribute, I64, size_exprs=sizes) + + class Comprehension: """Represents a single comprehension.""" @@ -303,7 +334,7 @@ class ReduceFnType: def __init__(self, operator: PrimFnType, *reduce_dims: DimDef): """Initializes the ReduceFn with a primitive function and dims.""" if not isinstance(operator, PrimFnType): - raise ValueError(f"Reduce expected a Prim operator. 
Got: {operator}") + raise ValueError(f"Reduce expected a Prim operator but got {operator}") self.operator = operator self.reduce_dims = tuple(reduce_dims) @@ -353,7 +384,7 @@ def __init__(self, value: Any): self.value = str( _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value))) else: - raise ValueError(f"const requires int or float. Got: {type(value)}") + raise ValueError(f"const requires int or float but got {type(value)}") def to_scalar_expression(self) -> ScalarExpression: return ScalarConst(self.value).expr() @@ -475,21 +506,22 @@ def __init__(self, self.comprehensions = list() # type: List[Comprehension] self._affine_state = AffineBuildState() - @property - def outputs(self) -> Sequence[OperandDef]: - return [ - operand for operand in self.registered_operands.values() - if operand.output - ] - def add_operand(self, name: str, operand: OperandDef): """Registers an operand.""" if name in self.registered_operands: raise ValueError(f"The operand {name} is already registered " f"to {self.registered_operands['name']}") - if not operand.output and self.outputs: - raise ValueError(f"The operand {name} is an input registered after " - f"the output {self.outputs[-1]}") + # Ensure output tensors are registered after input tensors and scalars and + # attributes are registered after all other operand types. 
+ registered_kinds = [ + operand.kind.value for operand in self.registered_operands.values() + ] + if registered_kinds: + maximum = max(registered_kinds) + if maximum > operand.kind.value and maximum > OperandKind.Scalar.value: + raise ValueError( + f"The operand {name} of kind {operand.kind.name} is registered " + f"after an operand of kind {OperandKind(maximum).name}") operand.attach(len(self.registered_operands), name, self) self.registered_operands[name] = operand diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py index 6dd86334b..773bd8763 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -45,9 +45,11 @@ class OperandDefConfig(YAMLObject): def __init__(self, operand_def: OperandDef, - shape_map: Optional[_ir.AffineMap] = None): + shape_map: Optional[_ir.AffineMap] = None, + attribute_map: Optional[_ir.AffineMap] = None): self.operand_def = operand_def self.shape_map = shape_map # type: Optional[_ir.AffineMap] + self.attribute_map = attribute_map # type: Optional[_ir.AffineMap] self.indexing_map = None # type: Optional[_ir.AffineMap] @property @@ -60,21 +62,25 @@ def type_var(self) -> TypeVar: @property def usage(self) -> str: - if self.operand_def.output: - return "output" - return "input" + if self.operand_def.kind == OperandKind.Attribute: + return "IndexAttribute" + if self.operand_def.kind == OperandKind.OutputTensor: + return "OutputOperand" + return "InputOperand" def to_yaml_custom_dict(self): - self_dict = dict(name=self.name) - self_dict["usage"] = self.usage - if not self.operand_def.scalar: - self_dict["shape"] = _serialize_affine_map(self.shape_map) - self_dict["type_var"] = self.type_var.name + self_dict = dict( + name=self.name, usage=self.usage, type_var=self.type_var.name) + if self.shape_map: + self_dict["shape_map"] = _serialize_affine_map(self.shape_map) + if self.attribute_map: + 
self_dict["attribute_map"] = _serialize_affine_map(self.attribute_map) return self_dict def __repr__(self): return (f"OperandDefConfig({self.operand_def}, " - f"shape_map={self.shape_map}, indexing_map={self.indexing_map})") + f"shape_map={self.shape_map}, attribute_map={self.attribute_map}, " + f"indexing_map={self.indexing_map})") class LinalgIndexingMapsConfig(YAMLObject): @@ -109,6 +115,7 @@ class LinalgStructuredOpConfig(YAMLObject): def __init__(self, comprehension: Comprehension, + registered_operands: Sequence[OperandDef], context: Optional[_ir.Context] = None): self.context = context if context is not None else _ir.Context() self.affine_state = AffineBuildState() @@ -131,22 +138,33 @@ def __init__(self, read_use.collect_scalar_uses(collected_scalar_uses) read_use.collect_indices(collected_indices) - # Need to add all definitions before uses, so process twice. + # Collect all attribute definitions + collected_attr_defs = list() + for operand in registered_operands: + if operand.kind == OperandKind.Attribute: + collected_attr_defs.append(operand) + + # Add all definitions before uses, so process twice. for use in collected_tensor_uses: self.add_operand(use.operand_def) for use in collected_scalar_uses: self.add_operand(use.operand_def) + for definition in collected_attr_defs: + self.add_operand(definition) for use in collected_tensor_uses: self.add_tensor_use(use) - # Now normalize all defs and uses indexing maps now that full count of - # dims and symbols are known. + # Normalize all shape and indexing maps now that full count of dims and + # symbols are known. 
for cuse in self.uses.values(): cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map) - for cdef in self.operands.values(): - if not cdef.operand_def.scalar: - cdef.shape_map = self._normalize_affine_map( - cdef.shape_map, with_dims=False) + for operand_config in self.operands.values(): + if operand_config.shape_map: + operand_config.shape_map = self._normalize_affine_map( + operand_config.shape_map, with_dims=False) + if operand_config.attribute_map: + operand_config.attribute_map = self._normalize_affine_map( + operand_config.attribute_map, with_dims=False) # Now for each write use, propagate the indexing maps from the use to the # tensor, ensuring that there are not conflicts. @@ -174,12 +192,16 @@ def __init__(self, # Set the indexing map of all scalar uses to the empty map. for operand_config in self.operands.values(): - if operand_config.operand_def.scalar: - operand_config.indexing_map = self._create_empty_affine_map() + if operand_config.operand_def.kind == OperandKind.Scalar: + operand_config.indexing_map = self._get_scalar_map() - # Sanity check that all defs have an indexing map. - assert all(d.indexing_map for d in self.operands.values()), ( - f"Missing indexing map on OperandConfigDef: {self.operands}") + # Check all registered tensor and scalar operands have an indexing map. + for operand in registered_operands: + if operand.kind == OperandKind.Attribute: + continue + if not (operand in self.operands and self.operands[operand].indexing_map): + raise ValueError(f"Failed to compute an indexing map for operand " + f"{operand.name}") # Collect reduction dims and ensure all the same. all_reduction_dims = set(comprehension.all_reduction_dims) @@ -189,7 +211,7 @@ def __init__(self, f"dims. Got: {all_reduction_dims}") self.reduction_dims = next(iter(all_reduction_dims)) - # Check the index dimension exists and resolve + # Check the index dimension exists and resolve. 
for index in collected_indices: if index.dim_def.dimname not in self.affine_state.all_dims: raise ValueError( @@ -221,7 +243,7 @@ def ordered_dims(self) -> Sequence[Tuple[str, int]]: @property def indexing_maps(self) -> Sequence[_ir.AffineMap]: - return [d.indexing_map for d in self.ordered_operands] + return [o.indexing_map for o in self.ordered_operands if o.indexing_map] @property def iterator_types(self) -> Sequence[str]: @@ -237,20 +259,24 @@ def get_type(symbolic_name, position): def add_operand(self, operand_def: OperandDef): if operand_def in self.operands: return - if operand_def.scalar: + if operand_def.kind == OperandKind.Scalar: self.operands[operand_def] = OperandDefConfig(operand_def) return with self.context: local_state = AffineBuildState( global_state=self.affine_state, allow_new_dims=False) exprs = [] - for expr in operand_def.shape: + for expr in operand_def.size_exprs: exprs.append(expr.build(state=local_state)) assert local_state.local_dim_count == 0 - shape_map = _ir.AffineMap.get( + affine_map = _ir.AffineMap.get( dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs) - def_config = OperandDefConfig(operand_def, shape_map) - self.operands[operand_def] = def_config + if operand_def.kind == OperandKind.Attribute: + self.operands[operand_def] = OperandDefConfig( + operand_def, attribute_map=affine_map) + else: + self.operands[operand_def] = OperandDefConfig( + operand_def, shape_map=affine_map) def add_tensor_use(self, tensor_use: TensorUse): if tensor_use in self.uses: @@ -261,7 +287,6 @@ def add_tensor_use(self, tensor_use: TensorUse): exprs = [] for expr in tensor_use.indices: exprs.append(expr.build(state=local_state)) - assert local_state.local_symbol_count == 0 indexing_map = _ir.AffineMap.get( dim_count=local_state.dim_count, symbol_count=local_state.symbol_count, @@ -270,8 +295,8 @@ def add_tensor_use(self, tensor_use: TensorUse): use_config = TensorUseConfig(tensor_use, indexing_map) self.uses[tensor_use] = use_config - def 
_create_empty_affine_map(self) -> _ir.AffineMap: - """Create an affine map with an empty range.""" + def _get_scalar_map(self) -> _ir.AffineMap: + """Create an empty affine map used to index a scalar.""" with self.context: return _ir.AffineMap.get( dim_count=self.affine_state.dim_count, @@ -345,8 +370,9 @@ def from_linalg_op_def( return [ LinalgOpConfig( tc_op_def.metadata, - structured_op=LinalgStructuredOpConfig(tc_op_def.comprehensions[0], - context)), + structured_op=LinalgStructuredOpConfig( + tc_op_def.comprehensions[0], + tc_op_def.registered_operands.values(), context)), ] def __repr__(self): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py index 191b1b34f..6dbda1bb7 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -44,15 +44,20 @@ def __init__(self, op_name: str, model: LinalgOpDef): self.op_name = op_name self.model = model - def __call__(self, *args, emit_generic: bool = False, **kwargs): + def __call__(self, *ins: ir.Value, outs: Sequence[ir.Value], **kwargs): """Emits the corresponding op definition as IR. Most arguments are passed through to the underlying emitter. The following - are interpreted here: + keyword argument is interpreted here: emit_generic: Emits a generic form as appropriate (default True). If False, a named form is emitted (which must have been built in to the compiler). 
""" + emit_generic = kwargs.pop("emit_generic", False) + if not isinstance(emit_generic, bool): + raise ValueError(f"The named argument 'emit_generic' needs to be " + f" of type bool but got {type(emit_generic)}") + op_configs = LinalgOpConfig.from_linalg_op_def( self.model, context=ir.Context.current) @@ -70,12 +75,16 @@ def __call__(self, *args, emit_generic: bool = False, **kwargs): op_config = op_configs[0] if op_config.structured_op: if emit_generic: - return emit_generic_structured_op(op_config.structured_op, *args, - **kwargs) + return emit_generic_structured_op( + op_config.structured_op, *ins, outs=outs, **kwargs) else: - return emit_named_structured_op(op_config.structured_op, self.op_name, - self.model.metadata.cpp_class_name, - *args, **kwargs) + return emit_named_structured_op( + op_config.structured_op, + self.op_name, + self.model.metadata.cpp_class_name, + *ins, + outs=outs, + **kwargs) raise NotImplementedError( f"Emission of linalg op type not supported: {op_config}") @@ -104,14 +113,12 @@ def linalg_structured_op(dsl_func=None, sig = inspect.signature(dsl_func) for param_name, param in sig.parameters.items(): param_default = param.default - if isinstance(param_default, TensorDef): - tc_model.add_operand(param_name, param_default.operand_def) - elif isinstance(param_default, ScalarDef): + if isinstance(param_default, (TensorDef, ScalarDef, AttributeDef)): tc_model.add_operand(param_name, param_default.operand_def) else: raise ValueError(f"@tc_def_op function parameters must be defaulted as " - f"TensorDef(...) or ScalarDef(...): Found {param_name}" - f": {param_default}") + f"TensorDef(...), ScalarDef(...), or AttributeDef(...): " + f"Found {param_name}: {param_default}") dsl_func_args.append(param_default) # Invoke the DSL func to finish populating the model. 
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 2b8b91050..f6fb0cc7d 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -13,6 +13,7 @@ from .scalar_expr import * from .config import * +import numpy as np __all__ = [ "emit_generic_structured_op", @@ -29,12 +30,14 @@ def isa(cls: Type, ty: Type): def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, - *ins: Value, outs: Sequence[Value]): + *ins: Value, outs: Sequence[Value], + **attrs: Sequence[int]): all_arg_defs = op_config.ordered_operands - in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "input"] - out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "output"] + in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "InputOperand"] + out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "OutputOperand"] + attr_arg_defs = [arg for arg in all_arg_defs if arg.usage == "IndexAttribute"] - # Verify outs and captures are sequences. + # Verify outs is a sequence. if not isinstance(outs, Sequence): raise ValueError(f"Expected named argument outs to have type Sequence " f"but got {type(outs)}") @@ -47,6 +50,40 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, raise ValueError(f"Expected {len(out_arg_defs)} outputs but got " f"{len(outs)} for {op_config}") + # Compute a replacement list for all attribute symbols. 
+ expressions = [] # type: Sequence[AffineExpr] + replacements = [] # type: Sequence[AffineExpr] + for attr in attr_arg_defs: + if attr.name not in attrs: + raise ValueError(f"Expected named argument for the attribute {attr.name}") + attribute_values = attrs.get(attr.name) + if not all(isinstance(value, int) for value in attribute_values): + raise ValueError(f"Attribute {attr.name} needs to be of type " + f"Sequence[int] but got {type(attribute_values)}") + results = attr.attribute_map.results # type: AffineExprList + if len(attribute_values) != len(results): + raise ValueError(f"Attribute {attr.name} has length {len(results)} " + f"but got {len(attribute_values)} values") + for expr, value in zip(results, attribute_values): + expressions.append(expr) + replacements.append(AffineConstantExpr.get(value)) + + # Replace all index attribute symbols by their value. + # TODO: Add support for shape symbols. + indexing_maps = [] # type: Sequence[AffineMap] + for curr in op_config.indexing_maps: + for expression, replacement in zip(expressions, replacements): + curr = curr.replace(expression, replacement, curr.n_dims, curr.n_symbols) + indexing_maps.append(curr) + + # TODO: Linalg verification does not currently allow symbols. + # Compress them for now and verify none are left. + indexing_maps = AffineMap.compress_unused_symbols(indexing_maps, + Context.current) + if any(indexing_map.n_symbols != 0 for indexing_map in indexing_maps): + raise ValueError(f"Expected indexing_maps to use no symbols after " + f"replacement and compression but got {indexing_maps}") + outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins, out_arg_defs, outs) @@ -67,27 +104,28 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, # Emit the generic op. # TODO: Support emission of pure memref form. - indexing_maps_attr = ArrayAttr.get([ - AffineMapAttr.get(am) - # TODO: linalg verification does not currently allow symbols. - # Compress them for now. 
- for am in AffineMap.compress_unused_symbols(op_config.indexing_maps, - Context.current) - ]) + indexing_maps_attr = ArrayAttr.get( + [AffineMapAttr.get(am) for am in indexing_maps]) iterator_types_attr = ArrayAttr.get( [StringAttr.get(s) for s in op_config.iterator_types]) + # Compute a dictionary storing all index attributes. + index_attributes = {} # type: Dict[str, DenseElementAttr] + for attr in attr_arg_defs: + attribute_values = attrs.get(attr.name) + array = np.array(attribute_values, dtype=np.int64) + index_attributes[attr.name] = DenseElementsAttr.get(array) + return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, indexing_maps_attr, iterator_types_attr, - block_arg_types) + index_attributes, block_arg_types) -def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, - *ins: Value, - outs: Sequence[Value] = ()): +def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, + outs: Sequence[Value], **attrs: Sequence[int]): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - indexing_maps_attr, iterator_types_attr, block_arg_types = \ - prepare_common_structured_op(op_config, *ins, outs = outs) + indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \ + prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) generic_op = linalg.GenericOp( result_tensors=result_types, @@ -114,14 +152,12 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, return generic_op.results -def emit_named_structured_op(op_config: LinalgStructuredOpConfig, - op_name: str, - op_class_name: str, - *ins: Value, - outs: Sequence[Value] = ()): +def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str, + op_class_name: str, *ins: Value, + outs: Sequence[Value], **attrs: Sequence[int]): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - indexing_maps_attr, iterator_types_attr, block_arg_types = \ - 
prepare_common_structured_op(op_config, *ins, outs = outs) + indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \ + prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) # If we get here, there must exist a builtin class `op_class_name`. ctx = Context.current @@ -141,6 +177,10 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, "linalg.memoized_indexing_maps"] = indexing_maps_attr # iterator_types are hardcoded in C++ both in the yaml and non-yaml path. + # Additionally set all named attributes. + for name, value in index_attributes.items(): + named_op.operation.attributes[name] = value + if len(result_types) == 1: return named_op.result else: @@ -304,7 +344,7 @@ def _add_type_mapping(operand_config: OperandDefConfig, operand_type: Type, block_arg_types: Sequence[Type]): element_or_self_type = operand_type # Get the element type for tensor operands and the type itself for scalars. - if operand_config.operand_def.shape: + if operand_config.shape_map: try: element_or_self_type = ShapedType(operand_type).element_type except Exception as e: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index c6586824a..fe8bfc501 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -74,6 +74,19 @@ def dot( C[None] += cast(U, A[D.m]) * cast(U, B[D.m]) +@linalg_structured_op +def depthwise_conv_2d_input_nhwc_filter_hwc_poly( + I=TensorDef(T1, S.N, S.IH, S.IW, S.C), + K=TensorDef(T2, S.KH, S.KW, S.C), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + """A depth-wise 2-D convolution operation.""" + O[D.n, D.oh, D.ow, D.c] += cast( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, + D.c]) * cast(U, K[D.kh, D.kw, D.c]) + + @linalg_structured_op def fill_rng_2d( 
min=ScalarDef(F64), From 16edc653c40ed3fc8dc472f9dfe5a347eb00dc64 Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Thu, 24 Jun 2021 13:45:18 +0000 Subject: [PATCH 062/915] [mlir][linalg][python] Add shape-only tensor support to OpDSL. Add an index_dim annotation to specify the shape to loop mapping of shape-only tensors. A shape-only tensor is not accessed within the body of the operation but is required to span the iteration space of certain operations such as pooling. Differential Revision: https://reviews.llvm.org/D104767 --- .../linalg/opdsl/lang/comprehension.py | 21 ++++++++++++--- .../mlir/dialects/linalg/opdsl/lang/config.py | 27 ++++++++++++++++++- .../linalg/opdsl/ops/core_named_ops.py | 22 ++++++++++++++- 3 files changed, 64 insertions(+), 6 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 2b2f57248..e89885e97 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -151,13 +151,15 @@ class OperandDef: def __init__(self, kind: OperandKind, type_var: TypeVar, - size_exprs: Optional[Sequence[AffineExprDef]] = None): + size_exprs: Optional[Sequence[AffineExprDef]] = None, + index_dims: Optional[Sequence[DimDef]] = None): if not isinstance(type_var, TypeVar): raise ValueError( f"OperandDef requires a TypeVar but got {repr(type_var)}") self.owner = None # type: Optional["LinalgOpDef"] self.type_var = type_var self.size_exprs = size_exprs + self.index_dims = index_dims self.kind = kind self.name = None # type: Optional[str] self.registered_index = -1 # type: int @@ -174,7 +176,8 @@ def __hash__(self): def __repr__(self): return (f"{self.name}:OperandDef(kind={self.kind.name}, " - f"type={repr(self.type_var)}, size_exprs={self.size_exprs})") + f"type={repr(self.type_var)}, size_exprs={self.size_exprs}), " + f"index_dims={self.index_dims})") class TensorDef: @@
-184,15 +187,25 @@ class TensorDef: to the body of the structured op. A unique name identifies the tensor operands and an index determines their position in the operation's parameter list. A tensor definition takes type, a shape, and an optional flag to mark output - tensors. + tensors. Additionally, a tuple of index dimensions may be used to map the + tensor to the loop dimensions of the operation. This mapping is needed to + compute the indexing map of shape-only tensors that have no uses. """ def __init__(self, type_var: TypeVar, *shape: AffineExprDef, + index_dims: Optional[Sequence[DimDef]] = None, output: bool = False): + if index_dims and len(shape) != len(index_dims): + raise ValueError(f"Expected the shape rank {len(shape)} to match the " + f"number of index_dims {len(index_dims)}") + if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims): + raise ValueError(f"TensorDef requires index dims of type DimDef but " + f"got {type(index_dims)}") kind = OperandKind.OutputTensor if output else OperandKind.InputTensor - self.operand_def = OperandDef(kind, type_var, size_exprs=shape) + self.operand_def = OperandDef( + kind, type_var, size_exprs=shape, index_dims=index_dims) def __getitem__(self, dims) -> TensorUse: assert self.operand_def.owner, "TensorDef is not attached to an op" diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py index 773bd8763..78e6f1d6a 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -138,12 +138,18 @@ def __init__(self, read_use.collect_scalar_uses(collected_scalar_uses) read_use.collect_indices(collected_indices) - # Collect all attribute definitions + # Collect all attribute definitions. 
collected_attr_defs = list() for operand in registered_operands: if operand.kind == OperandKind.Attribute: collected_attr_defs.append(operand) + # Collect all tensors with manual indexing annotation. + collected_index_defs = list() + for operand in registered_operands: + if operand.index_dims: + collected_index_defs.append(operand) + # Add all definitions before uses, so process twice. for use in collected_tensor_uses: self.add_operand(use.operand_def) @@ -151,6 +157,10 @@ def __init__(self, self.add_operand(use.operand_def) for definition in collected_attr_defs: self.add_operand(definition) + for definition in collected_index_defs: + if definition not in self.operands: + self.add_operand(definition) + self.add_indexed_operand(definition) for use in collected_tensor_uses: self.add_tensor_use(use) @@ -158,6 +168,9 @@ def __init__(self, # symbols are known. for cuse in self.uses.values(): cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map) + for definition in collected_index_defs: + self.operands[definition].indexing_map = self._normalize_affine_map( + self.operands[definition].indexing_map) for operand_config in self.operands.values(): if operand_config.shape_map: operand_config.shape_map = self._normalize_affine_map( @@ -278,6 +291,18 @@ def add_operand(self, operand_def: OperandDef): self.operands[operand_def] = OperandDefConfig( operand_def, shape_map=affine_map) + def add_indexed_operand(self, operand_def: OperandDef): + with self.context: + local_state = AffineBuildState( + global_state=self.affine_state, allow_new_symbols=False) + exprs = [] + for expr in operand_def.index_dims: + exprs.append(expr.build(state=local_state)) + self.operands[operand_def].indexing_map = _ir.AffineMap.get( + dim_count=local_state.dim_count, + symbol_count=local_state.symbol_count, + exprs=exprs) + def add_tensor_use(self, tensor_use: TensorUse): if tensor_use in self.uses: return diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py 
b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index fe8bfc501..253fca4b4 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -81,12 +81,32 @@ def depthwise_conv_2d_input_nhwc_filter_hwc_poly( O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), strides=AttributeDef(S.SH, S.SW), dilations=AttributeDef(S.DH, S.DW)): - """A depth-wise 2-D convolution operation.""" + """Performs depth-wise 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ O[D.n, D.oh, D.ow, D.c] += cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) * cast(U, K[D.kh, D.kw, D.c]) +@linalg_structured_op +def pooling_nhwc_sum_poly( + I=TensorDef(T1, S.N, S.H, S.W, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + """Performs sum pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + O[D.n, D.oh, D.ow, D.c] += cast( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) + + @linalg_structured_op def fill_rng_2d( min=ScalarDef(F64), From 6ff4bd5a4563dd3bdd74015dd41858c5cc3e9430 Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Mon, 28 Jun 2021 07:30:02 +0000 Subject: [PATCH 063/915] [mlir][linalg] Remove the StructuredOp capture mechanism. After https://reviews.llvm.org/D104109, structured ops support scalar inputs. As a result, the capture mechanism meant to pass non-shaped parameters got redundant. The patch removes the capture semantics after the FillOp migrated to use scalar operands https://reviews.llvm.org/D104121. 
Differential Revision: https://reviews.llvm.org/D104785 --- mlir/include/mlir-c/Dialect/Linalg.h | 4 +--- mlir/lib/Bindings/Python/DialectLinalg.cpp | 11 +++-------- mlir/lib/CAPI/Dialect/Linalg.cpp | 9 ++------- mlir/python/mlir/dialects/_linalg_ops_ext.py | 2 +- 4 files changed, 7 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h index 6e20eec16..27f2f7bc8 100644 --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -18,11 +18,9 @@ extern "C" { #endif /// Apply the special region builder for the builtin named Linalg op. -/// The list of `capture` MlirValue is passed as-is to the region builder. /// Assert that `op` is a builtin named Linalg op. MLIR_CAPI_EXPORTED void -mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op, - intptr_t n, MlirValue const *mlirCaptures); +mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op); MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg); diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp index dfac96db7..a2a54249e 100644 --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -21,15 +21,10 @@ using namespace mlir::python; void mlir::python::populateDialectLinalgSubmodule(py::module m) { m.def( "fill_builtin_region", - [](PyDialectDescriptor &dialect, PyOperation &op, py::list captures) { - llvm::SmallVector mlirOperands; - mlirOperands.reserve(captures.size()); - for (auto v : captures) - mlirOperands.push_back(py::cast(v)->get()); - mlirLinalgFillBuiltinNamedOpRegion( - dialect.get(), op.get(), mlirOperands.size(), mlirOperands.data()); + [](PyDialectDescriptor &dialect, PyOperation &op) { + mlirLinalgFillBuiltinNamedOpRegion(dialect.get(), op.get()); }, - py::arg("dialect"), py::arg("op"), py::arg("captures") = py::list(), + py::arg("dialect"), py::arg("op"), "Fill the 
region for `op`, which is assumed to be a builtin named Linalg " "op."); } diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index be0d54488..902599f3b 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -16,13 +16,8 @@ using namespace mlir::linalg; /// Apply the special region builder for the builtin named Linalg op. /// Assert that `op` is a builtin named Linalg op. void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, - MlirOperation mlirOp, intptr_t n, - MlirValue const *mlirCaptures) { + MlirOperation mlirOp) { Operation *op = unwrap(mlirOp); - SmallVector captures; - captures.reserve(n); - for (unsigned idx = 0; idx < n; ++idx) - captures.push_back(unwrap(mlirCaptures[idx])); LinalgDialect::RegionBuilderFunType fun = static_cast(unwrap(linalgDialect)) @@ -41,7 +36,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, Region ®ion = op->getRegion(0); Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes); b.setInsertionPointToStart(body); - fun(b, *body, captures); + fun(b, *body); } MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect) diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py index c7ddfb962..bce4e08ae 100644 --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -33,7 +33,7 @@ def __init__(self, output: Value, value: Value, *, loc=None, ip=None): ip=ip) OpView.__init__(self, op) linalgDialect = Context.current.get_dialect_descriptor("linalg") - fill_builtin_region(linalgDialect, self.operation, []) + fill_builtin_region(linalgDialect, self.operation) # TODO: self.result is None. When len(results) == 1 we expect it to be # results[0] as per _linalg_ops_gen.py. 
This seems like an orthogonal bug # in the generator of _linalg_ops_gen.py where we have: From a448170451a82306bff92e780728c9e057a03261 Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Wed, 30 Jun 2021 08:59:22 +0000 Subject: [PATCH 064/915] [mlir][linalg][python] Explicit shape and dimension order in OpDSL. Extend the OpDSL syntax with an optional `domain` function to specify an explicit dimension order. The extension is needed to provide more control over the dimension order instead of deducing it implicitly depending on the formulation of the tensor comprehension. Additionally, the patch also ensures the symbols are ordered according to the operand definitions of the operation. Differential Revision: https://reviews.llvm.org/D105117 --- .../linalg/opdsl/lang/comprehension.py | 18 ++++--- .../mlir/dialects/linalg/opdsl/lang/config.py | 54 +++++++++++++++---- .../mlir/dialects/linalg/opdsl/lang/dsl.py | 8 +++ .../linalg/opdsl/ops/core_named_ops.py | 7 +++ 4 files changed, 70 insertions(+), 17 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index e89885e97..1f9230de3 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -32,13 +32,13 @@ def visit_tensor_exprs(self, callback): """Visits all tensor expression reachable by the expression.""" callback(self) - def _get_all_dim_defs(self) -> Set[DimDef]: - """Recursively gets all DimDef affine expressions that are referenced.""" + def collect_dim_uses(self, uses: Set["DimDef"]): + """Collects all DimDefs reachable through this expression.""" results = set() def visit_dim_def(dim_def): if isinstance(dim_def, DimDef): - results.add(dim_def) + uses.add(dim_def) def visit_affine_exprs(expr): if isinstance(expr, TensorUse): @@ -49,7 +49,6 @@ def visit_affine_exprs(expr): ind.visit_affine_exprs(visit_dim_def) 
self.visit_tensor_exprs(visit_affine_exprs) - return results def collect_tensor_uses(self, uses: Set["TensorUse"]): """Collects all TensorUses reachable through this expression.""" @@ -126,8 +125,10 @@ def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]: reduced into. Any indices referenced on the rhs and not in self are considered reduction dims and will be ordered as encountered on the rhs. """ - rhs_dims = rhs._get_all_dim_defs() - lhs_dims = self._get_all_dim_defs() + rhs_dims = set() + lhs_dims = set() + rhs.collect_dim_uses(rhs_dims) + self.collect_dim_uses(lhs_dims) return rhs_dims - lhs_dims def __repr__(self): @@ -202,7 +203,7 @@ def __init__(self, f"number of index_dims {len(index_dims)}") if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims): raise ValueError(f"TensorDef requires index dims of type DimDef but " - f"got {type(index_dims)}") + f"got {index_dims}") kind = OperandKind.OutputTensor if output else OperandKind.InputTensor self.operand_def = OperandDef( kind, type_var, size_exprs=shape, index_dims=index_dims) @@ -273,7 +274,7 @@ class AttributeDef: def __init__(self, *sizes: SymbolDef): if any(not isinstance(size, SymbolDef) for size in sizes): raise ValueError(f"AttributeDef requires sizes of type SymbolDef but got " - f"{type(sizes)}") + f"{sizes}") self.operand_def = OperandDef(OperandKind.Attribute, I64, size_exprs=sizes) @@ -516,6 +517,7 @@ def __init__(self, self.metadata = OpMetadataDef( name=name, cpp_class_name=cpp_class_name, doc=doc) self.registered_operands = dict() # type: Dict[str, OperandDef] + self.domain = list() # type: List[DimDef] self.comprehensions = list() # type: List[Comprehension] self._affine_state = AffineBuildState() diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py index 78e6f1d6a..f6d5248ea 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ 
b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -115,6 +115,7 @@ class LinalgStructuredOpConfig(YAMLObject): def __init__(self, comprehension: Comprehension, + domain: Sequence[DimDef], registered_operands: Sequence[OperandDef], context: Optional[_ir.Context] = None): self.context = context if context is not None else _ir.Context() @@ -123,10 +124,11 @@ def __init__(self, self.operands = dict() # type: Dict[OperandDef, OperandDefConfig] self.uses = dict() # type: Dict[TensorUse, TensorUseConfig] - # Compute the ordered set of writes and collect the tensor, capture, and - # index uses. + # Compute the ordered set of writes and collect the tensor, capture, dims, + # and index uses. collected_tensor_uses = set() collected_scalar_uses = set() + collected_dim_uses = set() collected_indices = set() for write_use, read_use in zip(comprehension.definitions, comprehension.values): @@ -136,8 +138,28 @@ def __init__(self, collected_tensor_uses.add(write_use) read_use.collect_tensor_uses(collected_tensor_uses) read_use.collect_scalar_uses(collected_scalar_uses) + read_use.collect_dim_uses(collected_dim_uses) + write_use.collect_dim_uses(collected_dim_uses) read_use.collect_indices(collected_indices) + # Set domain to the sorted list of uses if no domain annotation is given. + if not domain: + domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname) + + # Verify the domain dimensions match the used dimensions. + if (len(domain) != len(collected_dim_uses) or + any(dim not in collected_dim_uses for dim in domain)): + raise ValueError(f"Expected the annotated domain dimensions {domain} to " + f"match the set of dimension used by the tensor " + f"comprehension {collected_dim_uses}") + + # Instantiate the dimensions in the given order. + with self.context: + local_state = AffineBuildState( + global_state=self.affine_state, allow_new_symbols=False) + for dim in domain: + dim.build(state=local_state) + # Collect all attribute definitions. 
collected_attr_defs = list() for operand in registered_operands: @@ -148,18 +170,32 @@ def __init__(self, collected_index_defs = list() for operand in registered_operands: if operand.index_dims: + if any(dim not in collected_dim_uses for dim in operand.index_dims): + raise ValueError(f"Expected all index dims {operand.index_dims} of " + f"operand {operand.name} to have uses.") collected_index_defs.append(operand) - # Add all definitions before uses, so process twice. + # Collect the operand definitions of all tensor/scalar uses, attributes, and + # shape-only tensors. + all_operand_defs = list() for use in collected_tensor_uses: - self.add_operand(use.operand_def) + all_operand_defs.append(use.operand_def) for use in collected_scalar_uses: - self.add_operand(use.operand_def) + all_operand_defs.append(use.operand_def) for definition in collected_attr_defs: - self.add_operand(definition) + all_operand_defs.append(definition) + for definition in collected_index_defs: + all_operand_defs.append(definition) + + # Add all operands in registration order to ensure the symbols are + # registered in the order they appear. + all_operand_defs = sorted( + all_operand_defs, key=lambda operand_def: operand_def.registered_index) + for operand_def in all_operand_defs: + self.add_operand(operand_def) + + # Add all shape-only tensor index_dim annotations and all tensor uses. 
for definition in collected_index_defs: - if definition not in self.operands: - self.add_operand(definition) self.add_indexed_operand(definition) for use in collected_tensor_uses: self.add_tensor_use(use) @@ -396,7 +432,7 @@ def from_linalg_op_def( LinalgOpConfig( tc_op_def.metadata, structured_op=LinalgStructuredOpConfig( - tc_op_def.comprehensions[0], + tc_op_def.comprehensions[0], tc_op_def.domain, tc_op_def.registered_operands.values(), context)), ] diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py index 6dbda1bb7..1b42b5767 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -132,3 +132,11 @@ def linalg_structured_op(dsl_func=None, def implements(*interfaces: OpInterfaceDef): current_op_def().metadata.implements.extend(interfaces) + + +def domain(*dimensions: DimDef): + if current_op_def().domain: + raise ValueError(f"Expected only one set of domain dimensions per operator") + if any(not isinstance(dim, DimDef) for dim in dimensions): + raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}") + current_op_def().domain.extend(dimensions) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 253fca4b4..586710927 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -16,6 +16,7 @@ def matmul( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + domain(D.m, D.n, D.k) implements(ContractionOpInterface) C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) @@ -30,6 +31,7 @@ def batch_matmul( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. 
""" + domain(D.b, D.m, D.n, D.k) implements(ContractionOpInterface) C[D.b, D.m, D.n] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k, D.n]) @@ -44,6 +46,7 @@ def matvec( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + domain(D.m, D.n) implements(ContractionOpInterface) x[D.m] += cast(U, A[D.m, D.n]) * cast(U, y[D.n]) @@ -58,6 +61,7 @@ def vecmat( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + domain(D.n, D.m) implements(ContractionOpInterface) x[D.n] += cast(U, y[D.m]) * cast(U, A[D.m, D.n]) @@ -86,6 +90,7 @@ def depthwise_conv_2d_input_nhwc_filter_hwc_poly( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] += cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) * cast(U, K[D.kh, D.kw, D.c]) @@ -103,6 +108,7 @@ def pooling_nhwc_sum_poly( Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ + domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] += cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) @@ -123,6 +129,7 @@ def fill_rng_2d( element seed the random number generation. The min and max operands limit the range of the generated random numbers. 
""" + domain(D.m, D.n) multiplier = cast(I32, const(1103515245)) increment = cast(I32, const(12345)) rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment From 77e23ac311df0ac14d5efa05a162930ea6f8503d Mon Sep 17 00:00:00 2001 From: Ahmed Taei Date: Tue, 22 Jun 2021 12:50:10 -0700 Subject: [PATCH 065/915] Add linalg.batch_matvec named op Similarly to batch_mat vec outer most dim is a batching dim and this op does |b| matrix-vector-products : C[b, i] = sum_k(A[b, i, k] * B[b, k]) Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D104739 --- .../dialects/linalg/opdsl/ops/core_named_ops.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 586710927..561cd2e7d 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -66,6 +66,21 @@ def vecmat( x[D.n] += cast(U, y[D.m]) * cast(U, A[D.m, D.n]) +@linalg_structured_op +def batch_matvec( + A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K), + C=TensorDef(U, Batch, S.M, output=True)): + """Performs a batched matrix-vector multiplication. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + domain(D.b, D.m, D.k) + implements(ContractionOpInterface) + C[D.b, D.m] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k]) + + @linalg_structured_op def dot( A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)): From 0b531e36c4d185de5f93d84168d200e9e4b15d08 Mon Sep 17 00:00:00 2001 From: Ahmed Taei Date: Wed, 30 Jun 2021 16:03:19 -0700 Subject: [PATCH 066/915] Add linalg.mmt4d named op This op performs matrix-matrix-transpose multiplication of 4-d inputs as the following: ``` C[m1, n1, m0, n0] = sum_{k1, k0}(A[m1, k1, m0, k0] * B[n1, k1, n0, k0]) ``` Reviewed By: Benoit Differential Revision: https://reviews.llvm.org/D105244 --- .../linalg/opdsl/ops/core_named_ops.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 561cd2e7d..095d94956 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -21,6 +21,26 @@ def matmul( C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) +@linalg_structured_op +def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), + rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0), + accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, + output=True)): + """Performs a matrix-matrix-transpose multiplication of two 4D inputs. + + Differences from linalg.matmul: + * The right hand side is transposed, whence the 't' in 'mmt'. + * The input and output tensors have a 4D shape instead of a 2D shape. They + are interpreted as 2D matrices with one level of 2D tile subdivision, + whence the 2+2=4 dimensions. The inner tile dimensions are identified with + '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads + as: MxK tiles, each of shape M0xK0. 
+ """ + domain(D.m, D.m0, D.n, D.n0, D.k, D.k0) + implements(ContractionOpInterface) + accum[D.m, D.n, D.m0, D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) + + @linalg_structured_op def batch_matmul( A=TensorDef(T1, Batch, S.M, S.K), From 70ccc2024006e435893e7dc191fd8ad2227cecf3 Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Fri, 2 Jul 2021 06:45:34 +0000 Subject: [PATCH 067/915] [mlir][linalg][python] Add max operation in OpDSL Add the max operation to the OpDSL and introduce a max pooling operation to test the implementation. As MLIR has no builtin max operation, the max function is lowered to a compare and select pair. Differential Revision: https://reviews.llvm.org/D105203 --- .../mlir/dialects/linalg/opdsl/lang/emitter.py | 12 ++++++++++++ .../linalg/opdsl/ops/core_named_ops.py | 18 ++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index f6fb0cc7d..9489dec52 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -307,6 +307,18 @@ def _eval_mul(self, lhs: Value, rhs: Value) -> Value: return std.MulIOp(lhs.type, lhs, rhs).result raise NotImplementedError("Unsupported 'mul' operand: {lhs}") + def _eval_max(self, lhs: Value, rhs: Value) -> Value: + i1 = IntegerType.get_signless(1) + if _is_floating_point_type(lhs.type): + ogt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2) + cond = std.CmpFOp(i1, ogt_attr, lhs, rhs).result + return std.SelectOp(lhs.type, cond, lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + sgt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4) + cond = std.CmpIOp(i1, sgt_attr, lhs, rhs).result + return std.SelectOp(lhs.type, cond, lhs, rhs).result + raise NotImplementedError("Unsupported 'max' operand: {lhs}") + def 
_infer_structured_outs(op_config: LinalgStructuredOpConfig, in_arg_defs: Sequence[OperandDefConfig], diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 095d94956..04c950e0a 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -148,6 +148,24 @@ def pooling_nhwc_sum_poly( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) +@linalg_structured_op +def pooling_nhwc_max_poly( + I=TensorDef(T1, S.N, S.H, S.W, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + """Performs max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)( + cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, + D.c])) + + @linalg_structured_op def fill_rng_2d( min=ScalarDef(F64), From 693072f3ff6ccca861e66cb3ea10eb38dc8dbd23 Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Fri, 2 Jul 2021 16:08:22 +0000 Subject: [PATCH 068/915] [mlir][linalg][python] Add min operation in OpDSL. Add the min operation to OpDSL and introduce a min pooling operation to test the implementation. The patch is a sibling of the max operation patch https://reviews.llvm.org/D105203 and the min operation is again lowered to a compare and select pair. 
Differential Revision: https://reviews.llvm.org/D105345 --- .../linalg/opdsl/lang/comprehension.py | 2 ++ .../dialects/linalg/opdsl/lang/emitter.py | 26 +++++++++++++++---- .../linalg/opdsl/ops/core_named_ops.py | 18 +++++++++++++ 3 files changed, 41 insertions(+), 5 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 1f9230de3..66d7510b6 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -339,6 +339,7 @@ class PrimFn: log = PrimFnType("log") mul = PrimFnType("mul") max = PrimFnType("max") + min = PrimFnType("min") sub = PrimFnType("sub") @@ -364,6 +365,7 @@ class ReduceFn: add = PrimFn.add.reduce mul = PrimFn.mul.reduce max = PrimFn.max.reduce + min = PrimFn.min.reduce class PrimApply(TensorExpression): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 9489dec52..61d226058 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -308,17 +308,23 @@ def _eval_mul(self, lhs: Value, rhs: Value) -> Value: raise NotImplementedError("Unsupported 'mul' operand: {lhs}") def _eval_max(self, lhs: Value, rhs: Value) -> Value: - i1 = IntegerType.get_signless(1) if _is_floating_point_type(lhs.type): ogt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2) - cond = std.CmpFOp(i1, ogt_attr, lhs, rhs).result - return std.SelectOp(lhs.type, cond, lhs, rhs).result + return _emit_cmpf_and_select(lhs, rhs, ogt_attr) if _is_integer_type(lhs.type) or _is_index_type(lhs.type): sgt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4) - cond = std.CmpIOp(i1, sgt_attr, lhs, rhs).result - return std.SelectOp(lhs.type, cond, lhs, rhs).result + return _emit_cmpi_and_select(lhs, rhs, sgt_attr) raise NotImplementedError("Unsupported 'max' 
operand: {lhs}") + def _eval_min(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + olt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4) + return _emit_cmpf_and_select(lhs, rhs, olt_attr) + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + slt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2) + return _emit_cmpi_and_select(lhs, rhs, slt_attr) + raise NotImplementedError("Unsupported 'min' operand: {lhs}") + def _infer_structured_outs(op_config: LinalgStructuredOpConfig, in_arg_defs: Sequence[OperandDefConfig], @@ -397,3 +403,13 @@ def _get_floating_point_width(t: Type) -> int: if BF16Type.isinstance(t): return 16 raise NotImplementedError(f"Unhandled floating point type switch {t}") + + +def _emit_cmpf_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value: + cond = std.CmpFOp(IntegerType.get_signless(1), pred, lhs, rhs).result + return std.SelectOp(lhs.type, cond, lhs, rhs).result + + +def _emit_cmpi_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value: + cond = std.CmpIOp(IntegerType.get_signless(1), pred, lhs, rhs).result + return std.SelectOp(lhs.type, cond, lhs, rhs).result diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 04c950e0a..a37e1944c 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -166,6 +166,24 @@ def pooling_nhwc_max_poly( D.c])) +@linalg_structured_op +def pooling_nhwc_min_poly( + I=TensorDef(T1, S.N, S.H, S.W, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + """Performs min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)( + cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, + D.c])) + + @linalg_structured_op def fill_rng_2d( min=ScalarDef(F64), From 20f7d1168c0f122f459437a7405696877b6ec253 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Fri, 2 Jul 2021 10:03:46 -0700 Subject: [PATCH 069/915] Add C API files for the LLVM dialect For now only expose a builder for the LLVM pointer type. Reviewed By: jpienaar, ftynse Differential Revision: https://reviews.llvm.org/D105346 --- mlir/include/mlir-c/Dialect/LLVM.h | 30 ++++++++++++++++++++++++++++ mlir/lib/CAPI/Dialect/CMakeLists.txt | 9 +++++++++ mlir/lib/CAPI/Dialect/LLVM.cpp | 21 +++++++++++++++++++ 3 files changed, 60 insertions(+) create mode 100644 mlir/include/mlir-c/Dialect/LLVM.h create mode 100644 mlir/lib/CAPI/Dialect/LLVM.cpp diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h new file mode 100644 index 000000000..d3c5217ea --- /dev/null +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -0,0 +1,30 @@ +//===-- mlir-c/Dialect/LLVM.h - C API for LLVM --------------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_LLVM_H +#define MLIR_C_DIALECT_LLVM_H + +#include "mlir-c/IR.h" +#include "mlir-c/Registration.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(LLVM, llvm); + +/// Creates an llvm.ptr type. 
+MLIR_CAPI_EXPORTED MlirType mlirLLVMPointerTypeGet(MlirType pointee, + unsigned addressSpace); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_LLVM_H diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index 053fce30d..ab8ac73c7 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -27,6 +27,15 @@ add_mlir_public_c_api_library(MLIRCAPIGPU MLIRPass ) +add_mlir_public_c_api_library(MLIRCAPILLVM + LLVM.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRLLVMIR +) + add_mlir_public_c_api_library(MLIRCAPILinalg Linalg.cpp LinalgPasses.cpp diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp new file mode 100644 index 000000000..be0e6c5d0 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -0,0 +1,21 @@ +//===- LLVM.cpp - C Interface for LLVM dialect ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/LLVM.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMTypes.h" + +using namespace mlir; +using namespace mlir::LLVM; + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(LLVM, llvm, LLVMDialect) + +MlirType mlirLLVMPointerTypeGet(MlirType pointee, unsigned addressSpace) { + return wrap(LLVMPointerType::get(unwrap(pointee), addressSpace)); +} From 27c63392db46a9f0b7dedfadfc79d9c412f7535a Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Wed, 23 Jun 2021 10:17:46 +0530 Subject: [PATCH 070/915] [MLIR] Split out GPU ops library from Transforms Split out GPU ops library from GPU transforms. This allows libraries to depend on GPU Ops without needing/building its transforms. 
Differential Revision: https://reviews.llvm.org/D105472 --- mlir/lib/CAPI/Dialect/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index ab8ac73c7..801b0f77a 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -23,7 +23,7 @@ add_mlir_public_c_api_library(MLIRCAPIGPU LINK_LIBS PUBLIC MLIRCAPIIR - MLIRGPU + MLIRGPUTransforms MLIRPass ) From 21bbaa5dc1870aad7cff72e264c14c3ecd15c588 Mon Sep 17 00:00:00 2001 From: Bairen Yi Date: Wed, 7 Jul 2021 11:26:50 +0200 Subject: [PATCH 071/915] [mlir][CAPI] Export mlirValueEqual in C API Somehow it is not exported in C API. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D105422 --- mlir/include/mlir-c/IR.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index b0866385c..6924fa88d 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -557,7 +557,7 @@ mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData); static inline bool mlirValueIsNull(MlirValue value) { return !value.ptr; } /// Returns 1 if two values are equal, 0 otherwise. -bool mlirValueEqual(MlirValue value1, MlirValue value2); +MLIR_CAPI_EXPORTED bool mlirValueEqual(MlirValue value1, MlirValue value2); /// Returns 1 if the value is a block argument, 0 otherwise. MLIR_CAPI_EXPORTED bool mlirValueIsABlockArgument(MlirValue value); From c8e0ad9b0d0101ac1f55b6d8580141d35163219e Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Thu, 8 Jul 2021 08:48:23 +0000 Subject: [PATCH 072/915] [mlir][linalg][python] Add exp and log to the OpDSL. Introduce the exp and log function in OpDSL. Add the soft plus operator to test the emitted IR in Python and C++. 
Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D105420 --- .../mlir/dialects/linalg/opdsl/lang/emitter.py | 11 +++++++++++ .../dialects/linalg/opdsl/ops/core_named_ops.py | 13 +++++++++++++ 2 files changed, 24 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 61d226058..3810df9df 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -7,6 +7,7 @@ from mlir.ir import * from mlir.dialects import linalg from mlir.dialects import std +from mlir.dialects import math # TODO: resolve name collision for Linalg functionality that is injected inside # the _mlir.dialects.linalg directly via pybind. from _mlir.dialects.linalg import fill_builtin_region @@ -293,6 +294,16 @@ def _eval_add(self, lhs: Value, rhs: Value) -> Value: return std.AddIOp(lhs.type, lhs, rhs).result raise NotImplementedError("Unsupported 'add' operand: {lhs}") + def _eval_exp(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return math.ExpOp(x.type, x).result + raise NotImplementedError("Unsupported 'exp' operand: {x}") + + def _eval_log(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return math.LogOp(x.type, x).result + raise NotImplementedError("Unsupported 'log' operand: {x}") + def _eval_sub(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return std.SubFOp(lhs.type, lhs, rhs).result diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index a37e1944c..72793cbf9 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -209,3 +209,16 @@ def fill_rng_2d( offset = cast(F64, const(2147483647)) scaling = (max - min) * inv_range O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) 
* scaling + min) + + +@linalg_structured_op +def soft_plus_2d( + I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)): + """Implements the soft plus operator. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + domain(D.m, D.n) + O[D.m, D.n] = \ + PrimFn.log(cast(U, const(1.0)) + PrimFn.exp(cast(U, I[D.m, D.n]))) From 263c66d880e1df8a2126edc23f1fee6cf050b471 Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Mon, 12 Jul 2021 12:31:30 +0000 Subject: [PATCH 073/915] [mlir][linalg][python] Add auto-generated file warning (NFC). Annotate LinalgNamedStructuredOps.yaml with a comment stating the file is auto-generated and should not be edited manually. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D105809 --- mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py b/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py index bacc0c302..05c06e737 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py @@ -81,6 +81,7 @@ def main(args): # Print. if args.format == "yaml": + print("# Auto-generated file. Do not edit!") print(yaml_dump_all(configs)) elif args.format == "repr": for config in configs: From a7b47ad35fc13e719a89450fdd0d7d23183f5a14 Mon Sep 17 00:00:00 2001 From: Hanhan Wang Date: Mon, 12 Jul 2021 17:25:55 -0700 Subject: [PATCH 074/915] [mlir][Linalg] Add 3D pooling named ops to Linalg. 
Reviewed By: gysit, hanchung Differential Revision: https://reviews.llvm.org/D105329 --- .../linalg/opdsl/ops/core_named_ops.py | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 72793cbf9..1362e9f18 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -184,6 +184,62 @@ def pooling_nhwc_min_poly( D.c])) +@linalg_structured_op +def pooling_ndhwc_sum( + I=TensorDef(T1, S.N, S.D, S.H, S.W, S.C), + K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), + strides=AttributeDef(S.SD, S.SH, S.SW), + dilations=AttributeDef(S.DD, S.DH, S.DW)): + """Performs 3D sum pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) + O[D.n, D.od, D.oh, D.ow, D.c] += cast( + U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, D.c]) + + +@linalg_structured_op +def pooling_ndhwc_max( + I=TensorDef(T1, S.N, S.D, S.H, S.W, S.C), + K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), + strides=AttributeDef(S.SD, S.SH, S.SW), + dilations=AttributeDef(S.DD, S.DH, S.DW)): + """Performs 3D max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) + O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max(D.kd, D.kh, D.kw)( + cast( + U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, D.c])) + + +@linalg_structured_op +def pooling_ndhwc_min( + I=TensorDef(T1, S.N, S.D, S.H, S.W, S.C), + K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), + strides=AttributeDef(S.SD, S.SH, S.SW), + dilations=AttributeDef(S.DD, S.DH, S.DW)): + """Performs 3D min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) + O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min(D.kd, D.kh, D.kw)( + cast( + U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, D.c])) + + @linalg_structured_op def fill_rng_2d( min=ScalarDef(F64), From 1c767580e60da97564276f2b4ee2066b13b67bfb Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 13 Jul 2021 14:35:50 -0700 Subject: [PATCH 075/915] Add more types to the LLVM dialect C API This includes: - void type - array types - function types - literal (unnamed) struct types Reviewed By: jpienaar, ftynse Differential Revision: https://reviews.llvm.org/D105908 --- mlir/include/mlir-c/Dialect/LLVM.h | 17 +++++++++++++++++ mlir/lib/CAPI/Dialect/LLVM.cpp | 25 +++++++++++++++++++++++++ 2 files changed, 42 insertions(+) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index d3c5217ea..2cf73a359 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -23,6 +23,23 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(LLVM, llvm); MLIR_CAPI_EXPORTED MlirType mlirLLVMPointerTypeGet(MlirType pointee, unsigned addressSpace); +/// Creates an llmv.void type. 
+MLIR_CAPI_EXPORTED MlirType mlirLLVMVoidTypeGet(MlirContext ctx); + +/// Creates an llvm.array type. +MLIR_CAPI_EXPORTED MlirType mlirLLVMArrayTypeGet(MlirType elementType, + unsigned numElements); + +/// Creates an llvm.func type. +MLIR_CAPI_EXPORTED MlirType +mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes, + MlirType const *argumentTypes, bool isVarArg); + +/// Creates an LLVM literal (unnamed) struct type. +MLIR_CAPI_EXPORTED MlirType +mlirLLVMStructTypeLiteralGet(MlirContext ctx, intptr_t nFieldTypes, + MlirType const *fieldTypes, bool isPacked); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index be0e6c5d0..d023bf5d6 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -19,3 +19,28 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(LLVM, llvm, LLVMDialect) MlirType mlirLLVMPointerTypeGet(MlirType pointee, unsigned addressSpace) { return wrap(LLVMPointerType::get(unwrap(pointee), addressSpace)); } + +MlirType mlirLLVMVoidTypeGet(MlirContext ctx) { + return wrap(LLVMVoidType::get(unwrap(ctx))); +} + +MlirType mlirLLVMArrayTypeGet(MlirType elementType, unsigned numElements) { + return wrap(LLVMArrayType::get(unwrap(elementType), numElements)); +} + +MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes, + MlirType const *argumentTypes, bool isVarArg) { + SmallVector argumentStorage; + return wrap(LLVMFunctionType::get( + unwrap(resultType), + unwrapList(nArgumentTypes, argumentTypes, argumentStorage), isVarArg)); +} + +MlirType mlirLLVMStructTypeLiteralGet(MlirContext ctx, intptr_t nFieldTypes, + MlirType const *fieldTypes, + bool isPacked) { + SmallVector fieldStorage; + return wrap(LLVMStructType::getLiteral( + unwrap(ctx), unwrapList(nFieldTypes, fieldTypes, fieldStorage), + isPacked)); +} From ef9f251e73ae6d7e3e50eeba748c1c0891dc317e Mon Sep 17 00:00:00 2001 From: John Demme Date: Wed, 14 Jul 2021 20:19:27 -0700 Subject: [PATCH 
076/915] [MLIR] [Python] Add `owner` to PyValue and fix its parent reference Adds `owner` python call to `mlir.ir.Value`. Assuming that `PyValue.parentOperation` is intended to be the value's owner, this fixes the construction of it from `PyOpOperandList`. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D103853 --- mlir/lib/Bindings/Python/IRCore.cpp | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 3f08522b9..b5197d9aa 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1652,7 +1652,17 @@ class PyOpOperandList : public Sliceable { } PyValue getElement(intptr_t pos) { - return PyValue(operation, mlirOperationGetOperand(operation->get(), pos)); + MlirValue operand = mlirOperationGetOperand(operation->get(), pos); + MlirOperation owner; + if (mlirValueIsAOpResult(operand)) + owner = mlirOpResultGetOwner(operand); + else if (mlirValueIsABlockArgument(operand)) + owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(operand)); + else + assert(false && "Value must be an block arg or op result."); + PyOperationRef pyOwner = + PyOperation::forOperation(operation->getContext(), owner); + return PyValue(pyOwner, operand); } PyOpOperandList slice(intptr_t startIndex, intptr_t length, intptr_t step) { @@ -2429,6 +2439,15 @@ void mlir::python::populateIRCore(py::module &m) { .def( "dump", [](PyValue &self) { mlirValueDump(self.get()); }, kDumpDocstring) + .def_property_readonly( + "owner", + [](PyValue &self) { + assert(mlirOperationEqual(self.getParentOperation()->get(), + mlirOpResultGetOwner(self.get())) && + "expected the owner of the value in Python to match that in " + "the IR"); + return self.getParentOperation().getObject(); + }) .def("__eq__", [](PyValue &self, PyValue &other) { return self.get().ptr == other.get().ptr; From 8769f11920761a4f12abb44b91b74880560c1d5c Mon Sep 
17 00:00:00 2001 From: Valentin Churavy Date: Sun, 18 Jul 2021 12:05:46 +0200 Subject: [PATCH 077/915] [MLIR][CAPI] On MINGW don't link against libMLIR Cross-compiling MLIR with MINGW failed because adding libMLIR to the libraries to link against would lead to duplicated symbols. ``` [09:28:14] ninja: job failed: : && /opt/bin/i686-w64-mingw32-libgfortran4-cxx03/i686-w64-mingw32-g++ --sysroot=/opt/i686-w64-mingw32/i686-w64-mingw32/sys-root/ -remap -D__USING_SJLJ_EXCEPTIONS__ -D__CRT__NO_INLINE -fno-gnu-unique -Werror=date-time -Wall -Wextra -Wno-unused-parameter -Wwrite-strings -Wcast-qual -Wno-missing-field-initializers -pedantic -Wno-long-long -Wimplicit-fallthrough -Wno-maybe-uninitialized -Wno-noexcept-type -Wdelete-non-virtual-dtor -Wno-comment -O2 -DNDEBUG -shared -o bin/libMLIRPublicAPI.dll -Wl,--out-implib,lib/libMLIRPublicAPI.dll.a -Wl,--major-image-version,0,--minor-image-version,0 tools/mlir/lib/CAPI/IR/CMakeFiles/obj.MLIRCAPIIR.dir/AffineExpr.cpp.obj tools/mlir/lib/CAPI/IR/CMakeFiles/obj.MLIRCAPIIR.dir/AffineMap.cpp.obj tools/mlir/lib/CAPI/IR/CMakeFiles/obj.MLIRCAPIIR.dir/BuiltinAttributes.cpp.obj tools/mlir/lib/CAPI/IR/CMakeFiles/obj.MLIRCAPIIR.dir/BuiltinTypes.cpp.obj tools/mlir/lib/CAPI/IR/CMakeFiles/obj.MLIRCAPIIR.dir/Diagnostics.cpp.obj tools/mlir/lib/CAPI/IR/CMakeFiles/obj.MLIRCAPIIR.dir/IntegerSet.cpp.obj tools/mlir/lib/CAPI/IR/CMakeFiles/obj.MLIRCAPIIR.dir/IR.cpp.obj tools/mlir/lib/CAPI/IR/CMakeFiles/obj.MLIRCAPIIR.dir/Pass.cpp.obj tools/mlir/lib/CAPI/IR/CMakeFiles/obj.MLIRCAPIIR.dir/Support.cpp.obj tools/mlir/lib/CAPI/Registration/CMakeFiles/obj.MLIRCAPIRegistration.dir/Registration.cpp.obj tools/mlir/lib/CAPI/Dialect/CMakeFiles/obj.MLIRCAPILinalg.dir/Linalg.cpp.obj tools/mlir/lib/CAPI/Dialect/CMakeFiles/obj.MLIRCAPISCF.dir/SCF.cpp.obj tools/mlir/lib/CAPI/Dialect/CMakeFiles/obj.MLIRCAPIShape.dir/Shape.cpp.obj tools/mlir/lib/CAPI/Dialect/CMakeFiles/obj.MLIRCAPIStandard.dir/Standard.cpp.obj 
tools/mlir/lib/CAPI/Dialect/CMakeFiles/obj.MLIRCAPITensor.dir/Tensor.cpp.obj tools/mlir/lib/CAPI/Transforms/CMakeFiles/obj.MLIRCAPITransforms.dir/Passes.cpp.obj lib/libMLIR.dll.a lib/libMLIRIR.a lib/libMLIRParser.a lib/libMLIRSupport.a lib/libMLIRPass.a lib/libMLIRCAPIIR.a lib/libMLIRAffine.a lib/libMLIRAffineEDSC.a lib/libMLIRAffineTransforms.a lib/libMLIRAffineUtils.a lib/libMLIRArmNeon.a lib/libMLIRArmSVE.a lib/libMLIRAsync.a lib/libMLIRAsyncTransforms.a lib/libMLIRAVX512.a lib/libMLIRComplex.a lib/libMLIRGPU.a lib/libMLIRLinalgAnalysis.a lib/libMLIRLinalgEDSC.a lib/libMLIRLinalg.a lib/libMLIRLinalgTransforms.a lib/libMLIRLinalgUtils.a lib/libMLIRLLVMIRTransforms.a lib/libMLIRLLVMIR.a lib/libMLIRLLVMAVX512.a lib/libMLIRLLVMArmNeon.a lib/libMLIRLLVMArmSVE.a lib/libMLIRNVVMIR.a lib/libMLIRROCDLIR.a lib/libMLIROpenACC.a lib/libMLIROpenMP.a lib/libMLIRPDL.a lib/libMLIRPDLInterp.a lib/libMLIRQuant.a lib/libMLIRSCF.a lib/libMLIRSCFTransforms.a lib/libMLIRSDBM.a lib/libMLIRShape.a lib/libMLIRShapeOpsTransforms.a lib/libMLIRSPIRV.a lib/libMLIRSPIRVModuleCombiner.a lib/libMLIRSPIRVConversion.a lib/libMLIRSPIRVTransforms.a lib/libMLIRSPIRVUtils.a lib/libMLIRStandard.a lib/libMLIRStandardOpsTransforms.a lib/libMLIRTensor.a lib/libMLIRTensorTransforms.a lib/libMLIRTosa.a lib/libMLIRTosaTransforms.a lib/libMLIRVector.a lib/libMLIRCAPIIR.a lib/libMLIRLinalg.a lib/libMLIRCAPIIR.a lib/libMLIRSCF.a lib/libMLIRCAPIIR.a lib/libMLIRShape.a lib/libMLIRCAPIIR.a lib/libMLIRStandard.a lib/libMLIRCAPIIR.a lib/libMLIRTensor.a lib/libMLIRTransforms.a lib/libMLIRAsync.a lib/libMLIRAffineUtils.a lib/libMLIRLinalgAnalysis.a lib/libMLIRLinalgEDSC.a lib/libMLIRVectorToSCF.a lib/libMLIRVectorToLLVM.a lib/libMLIRArmNeonToLLVM.a lib/libMLIRArmNeon.a lib/libMLIRLLVMArmNeon.a lib/libMLIRAVX512ToLLVM.a lib/libMLIRAVX512.a lib/libMLIRLLVMAVX512.a lib/libMLIRArmSVEToLLVM.a lib/libMLIRArmSVE.a lib/libMLIRLLVMArmSVE.a lib/libMLIRStandardToLLVM.a lib/libMLIRTargetLLVMIRModuleTranslation.a 
lib/libMLIRLLVMIRTransforms.a lib/libMLIRLLVMIR.a lib/libMLIROpenMP.a lib/libMLIRTranslation.a lib/libMLIRSPIRVConversion.a lib/libMLIRSPIRV.a lib/libMLIRParser.a lib/libMLIRTransforms.a lib/libMLIRVector.a lib/libMLIRAffineEDSC.a lib/libMLIRLinalg.a lib/libMLIRCopyOpInterface.a lib/libMLIRTosa.a lib/libMLIRQuant.a lib/libMLIRTransformUtils.a lib/libMLIRLoopAnalysis.a lib/libMLIRPresburger.a lib/libMLIRRewrite.a lib/libMLIRPDLToPDLInterp.a lib/libMLIRPass.a lib/libMLIRAnalysis.a lib/libMLIRAffine.a lib/libMLIRSCF.a lib/libMLIRLoopLikeInterface.a lib/libMLIRPDLInterp.a lib/libMLIRPDL.a lib/libMLIRInferTypeOpInterface.a lib/libMLIRStandard.a lib/libMLIRTensor.a lib/libMLIREDSC.a lib/libMLIRCastInterfaces.a lib/libMLIRVectorInterfaces.a lib/libMLIRSideEffectInterfaces.a lib/libMLIRDialect.a lib/libMLIRViewLikeInterface.a lib/libMLIRCallInterfaces.a lib/libMLIRControlFlowInterfaces.a lib/libMLIRIR.a lib/libMLIRSupport.a lib/libLLVM.dll.a -lkernel32 -luser32 -lgdi32 -lwinspool -lshell32 -lole32 -loleaut32 -luuid -lcomdlg32 -ladvapi32 && : [09:28:14] lib/libMLIRAffine.a(AffineOps.cpp.obj):AffineOps.cpp:(.text+0x1d600): multiple definition of `mlir::AffineDialect::initialize()' [09:28:14] lib/libMLIR.dll.a(d008729.o):(.text+0x0): first defined here [09:28:14] lib/libMLIRArmSVE.a(ArmSVEDialect.cpp.obj):ArmSVEDialect.cpp:(.text+0x5be0): multiple definition of `mlir::arm_sve::ArmSVEDialect::initialize()' [09:28:14] lib/libMLIR.dll.a(d039020.o):(.text+0x0): first defined here [09:28:14] lib/libMLIRAsync.a(Async.cpp.obj):Async.cpp:(.text+0xc0d0): multiple definition of `mlir::async::AsyncDialect::initialize()' [09:28:14] lib/libMLIR.dll.a(d023173.o):(.text+0x0): first defined here ... 
``` Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D106169 --- mlir/lib/CAPI/CMakeLists.txt | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt index db77cc1f6..cd119554f 100644 --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -25,13 +25,19 @@ foreach(lib ${public_api_libs}) list(APPEND _DEPS $) endforeach() +if(MINGW) + set(MLIR_LINK_MLIR_DYLIB 0) +else() + set(MLIR_LINK_MLIR_DYLIB ${LLVM_BUILD_LLVM_DYLIB}) +endif() + add_mlir_library(MLIRPublicAPI SHARED ${_OBJECTS} EXCLUDE_FROM_LIBMLIR LINK_LIBS # Dependency on the implementation shared library. - $<$:MLIR> + $<$:MLIR> ${_DEPS} ) From fc72d62afaa49dfaa8cfc9ae000346e80604f152 Mon Sep 17 00:00:00 2001 From: Hanhan Wang Date: Mon, 19 Jul 2021 09:23:55 -0700 Subject: [PATCH 078/915] [mlir][Linalg] Migrate 2D pooling ops from tc definition to yaml definition. This deletes all the pooling ops in LinalgNamedStructuredOpsSpec.tc. All the uses are replaced with the yaml pooling ops. 
Reviewed By: gysit, rsuderman Differential Revision: https://reviews.llvm.org/D106181 --- .../python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 1362e9f18..f1a2fd3ef 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -132,7 +132,7 @@ def depthwise_conv_2d_input_nhwc_filter_hwc_poly( @linalg_structured_op -def pooling_nhwc_sum_poly( +def pooling_nhwc_sum( I=TensorDef(T1, S.N, S.H, S.W, S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), @@ -149,7 +149,7 @@ def pooling_nhwc_sum_poly( @linalg_structured_op -def pooling_nhwc_max_poly( +def pooling_nhwc_max( I=TensorDef(T1, S.N, S.H, S.W, S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), @@ -167,7 +167,7 @@ def pooling_nhwc_max_poly( @linalg_structured_op -def pooling_nhwc_min_poly( +def pooling_nhwc_min( I=TensorDef(T1, S.N, S.H, S.W, S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), From c49ffef0d65bf2fa7f4b358866994e131027873e Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 20 Jul 2021 06:56:05 -0700 Subject: [PATCH 079/915] Exclude pybind11 2.7.0 from MLIR python requirements. Appears to have a broken CMake installation. 
Reported bug: https://github.com/pybind/pybind11/issues/3136 --- mlir/python/requirements.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index 51b35c22e..f76dcf676 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,3 +1,4 @@ numpy -pybind11>=2.6.0 +# Version 2.7.0 excluded: https://github.com/pybind/pybind11/issues/3136 +pybind11>=2.6.0,!=2.7.0 PyYAML From fbb1b5590a514a320279ae7f1ce0807cdf7da97b Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 28 Jun 2021 14:36:47 -0700 Subject: [PATCH 080/915] [mlir][tosa] Add quantized lowering for matmul and fully_connected Added the named op variants for quantized matmul and quantized batch matmul with the necessary lowerings/tests from tosa's matmul/fully connected ops. Current version does not use the contraction op interface as its verifiers are not compatible with scalar operations. Differential Revision: https://reviews.llvm.org/D105063 --- .../linalg/opdsl/ops/core_named_ops.py | 34 ++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index f1a2fd3ef..3ea171f78 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -20,6 +20,22 @@ def matmul( implements(ContractionOpInterface) C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) +@linalg_structured_op +def quantized_matmul( + A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + AZp=ScalarDef(I32), + BZp=ScalarDef(I32), + C=TensorDef(U, S.M, S.N, output=True)): + """Performs a matrix multiplication of two 2D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
The quantized variant + includes zero-point adjustments for the left and right operands of the + matmul. + """ + domain(D.m, D.n, D.k) + C[D.m, D.n] += (cast(U, A[D.m, D.k]) - cast(U, AZp)) * (cast(U, B[D.k, D.n]) - cast(U, BZp)) @linalg_structured_op def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), @@ -40,7 +56,6 @@ def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), implements(ContractionOpInterface) accum[D.m, D.n, D.m0, D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) - @linalg_structured_op def batch_matmul( A=TensorDef(T1, Batch, S.M, S.K), @@ -55,6 +70,23 @@ def batch_matmul( implements(ContractionOpInterface) C[D.b, D.m, D.n] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k, D.n]) +@linalg_structured_op +def quantized_batch_matmul( + A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K, S.N), + AZp=ScalarDef(I32), + BZp=ScalarDef(I32), + C=TensorDef(U, Batch, S.M, S.N, output=True)): + """Performs a batched matrix multiplication of two 3D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. The quantized variant + includes zero-point adjustments for the left and right operands of the + matmul. + """ + domain(D.b, D.m, D.n, D.k) + C[D.b, D.m, D.n] += (cast(U, A[D.b, D.m, D.k]) - cast(U, AZp)) * (cast(U, B[D.b, D.k, D.n]) - cast(U, BZp)) + @linalg_structured_op def matvec( From 09adde1542557ca33366455557cdb0e8885a0705 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 20 Jul 2021 15:07:04 -0700 Subject: [PATCH 081/915] [mlir][tosa] Added tosa to linalg lowering to unstrided transposed conv The unstrided transposed conv can be represented as a regular convolution. Lower to this variant to handle the basic case. This includes transitioning from the TC defined convolution operation and a yaml defined one. 
Reviewed By: NatashaKnk Differential Revision: https://reviews.llvm.org/D106389 --- .../linalg/opdsl/ops/core_named_ops.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 3ea171f78..ebe9822b1 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -145,6 +145,25 @@ def dot( C[None] += cast(U, A[D.m]) * cast(U, B[D.m]) +@linalg_structured_op +def conv_2d_input_nhwc_filter_ohwi_poly( + I=TensorDef(T1, S.N, S.IH, S.IW, S.IC), + K=TensorDef(T2, S.OC, S.KH, S.KW, S.IC), + O=TensorDef(U, S.N, S.OH, S.OW, S.OC, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + """Performs a 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.n, D.oh, D.ow, D.kh, D.kw, D.oc, D.ic) + O[D.n, D.oh, D.ow, D.oc] += cast( + U, I[D.n, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.ic]) * cast(U, K[D.oc, D.kh, D.kw, D.ic]) + @linalg_structured_op def depthwise_conv_2d_input_nhwc_filter_hwc_poly( I=TensorDef(T1, S.N, S.IH, S.IW, S.C), From 879fe159f3b8c9e29602f5d4152f0eddc103acf1 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 20 Jul 2021 15:53:15 -0700 Subject: [PATCH 082/915] Remove libMLIRPublicAPI DSO. libMLIRPublicAPI.so came into existence early when the Python and C-API were being co-developed because the Python extensions need a single DSO which exports the C-API to link against. It really should never have been exported as a mondo library in the first place, which has caused no end of problems in different linking modes, etc (i.e. the CAPI tests depended on it). This patch does a mechanical move that: * Makes the C-API tests link directly to their respective libraries. 
* Creates a libMLIRPythonCAPI as part of the Python bindings which assemble to exact DSO that they need. This has the effect that the C-API is no longer monolithic and can be subset and used piecemeal in a modular fashion, which is necessary for downstreams to only pay for what they use. There are additional, more fundamental changes planned for how the Python API is assembled which should make it more out of tree friendly, but this minimal first step is necessary to break the fragile dependency between the C-API and Python API. Downstream actions required: * If using the C-API and linking against MLIRPublicAPI, you must instead link against its constituent components. As a reference, the Python API dependencies are in lib/Bindings/Python/CMakeLists.txt and approximate the full set of dependencies available. * If you have a Python API project that was previously linking against MLIRPublicAPI (i.e. to add its own C-API DSO), you will want to `s/MLIRPublicAPI/MLIRPythonCAPI/` and all should be as it was. There are larger changes coming in this area but this part is incremental. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D106369 --- mlir/include/mlir-c/Support.h | 10 ++- mlir/lib/Bindings/Python/CMakeLists.txt | 71 +++++++++++++++++++ .../Python/Conversions/CMakeLists.txt | 2 + .../Bindings/Python/Transforms/CMakeLists.txt | 4 +- mlir/lib/CAPI/CMakeLists.txt | 43 ----------- mlir/python/mlir/_cext_loader.py | 2 +- 6 files changed, 86 insertions(+), 46 deletions(-) diff --git a/mlir/include/mlir-c/Support.h b/mlir/include/mlir-c/Support.h index 340f8ec8b..315f6c456 100644 --- a/mlir/include/mlir-c/Support.h +++ b/mlir/include/mlir-c/Support.h @@ -22,9 +22,17 @@ //===----------------------------------------------------------------------===// // Visibility annotations. // Use MLIR_CAPI_EXPORTED for exported functions. 
+// +// On Windows, if MLIR_CAPI_ENABLE_WINDOWS_DLL_DECLSPEC is defined, then +// __declspec(dllexport) and __declspec(dllimport) will be generated. This +// can only be enabled if actually building DLLs. It is generally, mutually +// exclusive with the use of other mechanisms for managing imports/exports +// (i.e. CMake's WINDOWS_EXPORT_ALL_SYMBOLS feature). //===----------------------------------------------------------------------===// -#if defined(MLIR_CAPI_DISABLE_VISIBILITY_ANNOTATIONS) +#if (defined(_WIN32) || defined(__CYGWIN__)) && \ + !defined(MLIR_CAPI_ENABLE_WINDOWS_DLL_DECLSPEC) +// Visibility annotations disabled. #define MLIR_CAPI_EXPORTED #elif defined(_WIN32) || defined(__CYGWIN__) // Windows visibility declarations. diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt index 173cf48c0..8a112b7f4 100644 --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/CMakeLists.txt @@ -1,6 +1,59 @@ include(AddMLIRPython) add_custom_target(MLIRBindingsPythonExtension) +################################################################################ +# All python extensions must link through one DSO which exports the CAPI, and +# this must have a globally unique name amongst all embeddors of the python +# library since it will effectively have global scope. +# +# The presence of this aggregate library is part of the long term plan, but its +# use needs to be made more flexible. 
+################################################################################ + +set(public_api_libs + MLIRCAPIConversion + MLIRCAPIDebug + MLIRCEXECUTIONENGINE + MLIRCAPIIR + MLIRCAPIRegistration + MLIRCAPITransforms + + # Dialects + MLIRCAPIAsync + MLIRCAPIGPU + MLIRCAPILinalg + MLIRCAPILLVM + MLIRCAPIShape + MLIRCAPISparseTensor + MLIRCAPIStandard + MLIRCAPISCF + MLIRCAPITensor +) + +foreach(lib ${public_api_libs}) + if(XCODE) + # Xcode doesn't support object libraries, so we have to trick it into + # linking the static libraries instead. + list(APPEND _DEPS "-force_load" ${lib}) + else() + list(APPEND _OBJECTS $) + endif() + # Accumulate transitive deps of each exported lib into _DEPS. + list(APPEND _DEPS $) +endforeach() + +add_mlir_library(MLIRPythonCAPI + PARTIAL_SOURCES_INTENDED + SHARED + ${_OBJECTS} + EXCLUDE_FROM_LIBMLIR + LINK_LIBS + ${_DEPS} +) +if(MSVC) + set_property(TARGET MLIRPythonCAPI PROPERTY WINDOWS_EXPORT_ALL_SYMBOLS ON) +endif() + ################################################################################ # Build core python extension ################################################################################ @@ -19,6 +72,9 @@ add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir PybindUtils.cpp Pass.cpp ExecutionEngine.cpp + LINK_LIBS PRIVATE + LLVMSupport + MLIRPythonCAPI ) add_dependencies(MLIRBindingsPythonExtension MLIRCoreBindingsPythonExtension) @@ -30,6 +86,9 @@ add_mlir_python_extension(MLIRAllPassesRegistrationBindingsPythonExtension _mlir python SOURCES AllPassesRegistration.cpp + LINK_LIBS PRIVATE + LLVMSupport + MLIRPythonCAPI ) add_dependencies(MLIRBindingsPythonExtension MLIRAllPassesRegistrationBindingsPythonExtension) @@ -38,6 +97,9 @@ add_mlir_python_extension(MLIRAsyncPassesBindingsPythonExtension _mlirAsyncPasse python SOURCES AsyncPasses.cpp + LINK_LIBS PRIVATE + LLVMSupport + MLIRPythonCAPI ) add_dependencies(MLIRBindingsPythonExtension MLIRAsyncPassesBindingsPythonExtension) @@ -46,6 +108,9 @@ 
add_mlir_python_extension(MLIRSparseTensorPassesBindingsPythonExtension _mlirSpa python SOURCES SparseTensorPasses.cpp + LINK_LIBS PRIVATE + LLVMSupport + MLIRPythonCAPI ) add_dependencies(MLIRBindingsPythonExtension MLIRSparseTensorPassesBindingsPythonExtension) @@ -54,6 +119,9 @@ add_mlir_python_extension(MLIRGPUPassesBindingsPythonExtension _mlirGPUPasses python SOURCES GPUPasses.cpp + LINK_LIBS PRIVATE + LLVMSupport + MLIRPythonCAPI ) add_dependencies(MLIRBindingsPythonExtension MLIRGPUPassesBindingsPythonExtension) @@ -62,5 +130,8 @@ add_mlir_python_extension(MLIRLinalgPassesBindingsPythonExtension _mlirLinalgPas python SOURCES LinalgPasses.cpp + LINK_LIBS PRIVATE + LLVMSupport + MLIRPythonCAPI ) add_dependencies(MLIRBindingsPythonExtension MLIRLinalgPassesBindingsPythonExtension) diff --git a/mlir/lib/Bindings/Python/Conversions/CMakeLists.txt b/mlir/lib/Bindings/Python/Conversions/CMakeLists.txt index ad2aeefca..e39707d0c 100644 --- a/mlir/lib/Bindings/Python/Conversions/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/Conversions/CMakeLists.txt @@ -7,4 +7,6 @@ add_mlir_python_extension(MLIRConversionsBindingsPythonExtension _mlirConversion python SOURCES Conversions.cpp + LINK_LIBS PRIVATE + MLIRPythonCAPI ) diff --git a/mlir/lib/Bindings/Python/Transforms/CMakeLists.txt b/mlir/lib/Bindings/Python/Transforms/CMakeLists.txt index 8b53f03d4..b33d1503b 100644 --- a/mlir/lib/Bindings/Python/Transforms/CMakeLists.txt +++ b/mlir/lib/Bindings/Python/Transforms/CMakeLists.txt @@ -7,4 +7,6 @@ add_mlir_python_extension(MLIRTransformsBindingsPythonExtension _mlirTransforms python SOURCES Transforms.cpp -) \ No newline at end of file + LINK_LIBS PRIVATE + MLIRPythonCAPI +) diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt index cd119554f..eed3f38d1 100644 --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -5,46 +5,3 @@ add_subdirectory(ExecutionEngine) add_subdirectory(IR) add_subdirectory(Registration) 
add_subdirectory(Transforms) - - -################################################################################ -# libMLIRPublicAPI shared library/DLL. -################################################################################ - -get_property(public_api_libs GLOBAL PROPERTY MLIR_PUBLIC_C_API_LIBS) - -foreach(lib ${public_api_libs}) - if(XCODE) - # Xcode doesn't support object libraries, so we have to trick it into - # linking the static libraries instead. - list(APPEND _DEPS "-force_load" ${lib}) - else() - list(APPEND _OBJECTS $) - endif() - # Accumulate transitive deps of each exported lib into _DEPS. - list(APPEND _DEPS $) -endforeach() - -if(MINGW) - set(MLIR_LINK_MLIR_DYLIB 0) -else() - set(MLIR_LINK_MLIR_DYLIB ${LLVM_BUILD_LLVM_DYLIB}) -endif() - -add_mlir_library(MLIRPublicAPI - SHARED - ${_OBJECTS} - EXCLUDE_FROM_LIBMLIR - LINK_LIBS - # Dependency on the implementation shared library. - $<$:MLIR> - ${_DEPS} -) - -target_link_options( - MLIRPublicAPI - PRIVATE - # On Linux, disable re-export of any static linked libraries that - # came through. - $<$:LINKER:--exclude-libs,ALL> -) diff --git a/mlir/python/mlir/_cext_loader.py b/mlir/python/mlir/_cext_loader.py index 3a9cde380..c691fccb4 100644 --- a/mlir/python/mlir/_cext_loader.py +++ b/mlir/python/mlir/_cext_loader.py @@ -24,7 +24,7 @@ def _load_extension(name): _load_extension = _mlir_libs.load_extension _preload_dependency = _mlir_libs.preload_dependency -_preload_dependency("MLIRPublicAPI") +_preload_dependency("MLIRPythonCAPI") # Expose the corresponding C-Extension module with a well-known name at this # top-level module. This allows relative imports like the following to From 2144d073427875e44776eadd31409cc39da7c400 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 20 Jul 2021 14:41:37 -0700 Subject: [PATCH 083/915] [mlir][tosa] Quantized Conv2DOp lowering to linalg added. Includes a version of a quantized conv2D operations with a lowering from TOSA to linalg with corresponding test. 
We keep the quantized and non-quantized variants as separate named ops to avoid the additional operations for non-quantized convolutions. Differential Revision: https://reviews.llvm.org/D106407 --- .../linalg/opdsl/ops/core_named_ops.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index ebe9822b1..cbb2c0e31 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -164,6 +164,30 @@ def conv_2d_input_nhwc_filter_ohwi_poly( D.ow * S.SW + D.kw * S.DW, D.ic]) * cast(U, K[D.oc, D.kh, D.kw, D.ic]) +@linalg_structured_op +def conv_2d_input_nhwc_filter_ohwi_poly_q( + I=TensorDef(T1, S.N, S.IH, S.IW, S.IC), + K=TensorDef(T2, S.OC, S.KH, S.KW, S.IC), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.OC, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + """Performs a 2-D quantized convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Includes zero point + adjustment for quantization. + """ + domain(D.n, D.oh, D.ow, D.kh, D.kw, D.oc, D.ic) + O[D.n, D.oh, D.ow, D.oc] += ((cast( + U, I[D.n, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.ic]) - cast(U, IZp)) * + (cast(U, K[D.oc, D.kh, D.kw, D.ic]) - cast(U, KZp))) + + @linalg_structured_op def depthwise_conv_2d_input_nhwc_filter_hwc_poly( I=TensorDef(T1, S.N, S.IH, S.IW, S.C), From d19cc8ac55255e8ad6c45ede5a39bc8b19fdeccf Mon Sep 17 00:00:00 2001 From: Yi Zhang Date: Fri, 23 Jul 2021 16:15:21 +0000 Subject: [PATCH 084/915] [mlir][linalg] Add pooling_nchw_max, conv_2d_nchw as yaml ops. - Add pooling_nchw_max. - Move conv_2d_nchw to yaml ops and add strides and dilation attributes.
Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D106658 --- .../linalg/opdsl/ops/core_named_ops.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index cbb2c0e31..3aa5aadc7 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -205,6 +205,23 @@ def depthwise_conv_2d_input_nhwc_filter_hwc_poly( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) * cast(U, K[D.kh, D.kw, D.c]) +@linalg_structured_op +def conv_2d_nchw( + I=TensorDef(T1, S.N, S.C, S.IH, S.IW), + K=TensorDef(T2, S.F, S.C, S.KH, S.KW), + O=TensorDef(U, S.N, S.F, S.OH, S.OW, S.C, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + """Performs 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.f, D.oh, D.ow] += cast( + U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, + ]) * cast(U, K[D.f, D.c, D.kh, D.kw]) + @linalg_structured_op def pooling_nhwc_sum( @@ -240,6 +257,22 @@ def pooling_nhwc_max( cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) +@linalg_structured_op +def pooling_nchw_max( + I=TensorDef(T1, S.N, S.C, S.H, S.W), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + """Performs max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) + O[D.n, D.c, D.oh, D.ow] = ReduceFn.max(D.kh, D.kw)( + cast(U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, + ])) @linalg_structured_op def pooling_nhwc_min( From 7fb17fe58feb59b4b50eac385d08c17953507367 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Thu, 22 Jul 2021 19:57:41 +0000 Subject: [PATCH 085/915] Re-engineer MLIR python build support. * Implements all of the discussed features: - Links against common CAPI libraries that are self contained. - Stops using the 'python/' directory at the root for everything, opening the namespace up for multiple projects to embed the MLIR python API. - Separates declaration of sources (py and C++) needed to build the extension from building, allowing external projects to build custom assemblies from core parts of the API. - Makes the core python API relocatable (i.e. it could be embedded as something like 'npcomp.ir', 'npcomp.dialects', etc). Still a bit more to do to make it truly isolated but the main structural reset is done. - When building statically, installed python packages are completely self contained, suitable for direct setup and upload to PyPi, et al. - Lets external projects assemble their own CAPI common runtime library that all extensions use. No more possibilities for TypeID issues. - Begins modularizing the API so that external projects that just include a piece pay only for what they use. * I also rolled in a re-organization of the native libraries that matches how I was packaging these out of tree and is a better layering (i.e. all libraries go into a nested _mlir_libs package). There is some further cleanup that I resisted since it would have required source changes that I'd rather do in a followup once everything stabilizes. * Note that I made a somewhat odd choice in choosing to recompile all extensions for each project they are included into (as opposed to compiling once and just linking). 
While not leveraged yet, this will let us set definitions controlling the namespacing of the extensions so that they can be made to not conflict across projects (with preprocessor definitions). * This will be a relatively substantial breaking change for downstreams. I will handle the npcomp migration and will coordinate with the circt folks before landing. We should stage this and make sure it isn't causing problems before landing. * Fixed a couple of absolute imports that were causing issues. Differential Revision: https://reviews.llvm.org/D106520 --- mlir/lib/Bindings/Python/CMakeLists.txt | 137 ------- .../Python/Conversions/CMakeLists.txt | 12 - .../Bindings/Python/Transforms/CMakeLists.txt | 12 - mlir/python/CMakeLists.txt | 333 ++++++++++++++++-- mlir/python/mlir/_cext_loader.py | 27 +- mlir/python/mlir/_mlir_libs/__init__.py | 21 ++ mlir/python/mlir/dialects/CMakeLists.txt | 86 ----- mlir/python/mlir/dialects/PythonTest.td | 33 ++ mlir/python/mlir/dialects/_builtin_ops_ext.py | 9 +- mlir/python/mlir/dialects/_linalg_ops_ext.py | 16 +- .../dialects/linalg/opdsl/lang/emitter.py | 3 +- 11 files changed, 382 insertions(+), 307 deletions(-) delete mode 100644 mlir/lib/Bindings/Python/CMakeLists.txt delete mode 100644 mlir/lib/Bindings/Python/Conversions/CMakeLists.txt delete mode 100644 mlir/lib/Bindings/Python/Transforms/CMakeLists.txt create mode 100644 mlir/python/mlir/_mlir_libs/__init__.py delete mode 100644 mlir/python/mlir/dialects/CMakeLists.txt create mode 100644 mlir/python/mlir/dialects/PythonTest.td diff --git a/mlir/lib/Bindings/Python/CMakeLists.txt b/mlir/lib/Bindings/Python/CMakeLists.txt deleted file mode 100644 index 8a112b7f4..000000000 --- a/mlir/lib/Bindings/Python/CMakeLists.txt +++ /dev/null @@ -1,137 +0,0 @@ -include(AddMLIRPython) -add_custom_target(MLIRBindingsPythonExtension) - -################################################################################ -# All python extensions must link through one DSO which exports the CAPI, and 
-# this must have a globally unique name amongst all embeddors of the python -# library since it will effectively have global scope. -# -# The presence of this aggregate library is part of the long term plan, but its -# use needs to be made more flexible. -################################################################################ - -set(public_api_libs - MLIRCAPIConversion - MLIRCAPIDebug - MLIRCEXECUTIONENGINE - MLIRCAPIIR - MLIRCAPIRegistration - MLIRCAPITransforms - - # Dialects - MLIRCAPIAsync - MLIRCAPIGPU - MLIRCAPILinalg - MLIRCAPILLVM - MLIRCAPIShape - MLIRCAPISparseTensor - MLIRCAPIStandard - MLIRCAPISCF - MLIRCAPITensor -) - -foreach(lib ${public_api_libs}) - if(XCODE) - # Xcode doesn't support object libraries, so we have to trick it into - # linking the static libraries instead. - list(APPEND _DEPS "-force_load" ${lib}) - else() - list(APPEND _OBJECTS $) - endif() - # Accumulate transitive deps of each exported lib into _DEPS. - list(APPEND _DEPS $) -endforeach() - -add_mlir_library(MLIRPythonCAPI - PARTIAL_SOURCES_INTENDED - SHARED - ${_OBJECTS} - EXCLUDE_FROM_LIBMLIR - LINK_LIBS - ${_DEPS} -) -if(MSVC) - set_property(TARGET MLIRPythonCAPI PROPERTY WINDOWS_EXPORT_ALL_SYMBOLS ON) -endif() - -################################################################################ -# Build core python extension -################################################################################ -add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir - INSTALL_DIR - python - SOURCES - DialectLinalg.cpp - DialectSparseTensor.cpp - MainModule.cpp - IRAffine.cpp - IRAttributes.cpp - IRCore.cpp - IRModule.cpp - IRTypes.cpp - PybindUtils.cpp - Pass.cpp - ExecutionEngine.cpp - LINK_LIBS PRIVATE - LLVMSupport - MLIRPythonCAPI -) -add_dependencies(MLIRBindingsPythonExtension MLIRCoreBindingsPythonExtension) - -add_subdirectory(Transforms) -add_subdirectory(Conversions) - -add_mlir_python_extension(MLIRAllPassesRegistrationBindingsPythonExtension 
_mlirAllPassesRegistration - INSTALL_DIR - python - SOURCES - AllPassesRegistration.cpp - LINK_LIBS PRIVATE - LLVMSupport - MLIRPythonCAPI -) -add_dependencies(MLIRBindingsPythonExtension MLIRAllPassesRegistrationBindingsPythonExtension) - -add_mlir_python_extension(MLIRAsyncPassesBindingsPythonExtension _mlirAsyncPasses - INSTALL_DIR - python - SOURCES - AsyncPasses.cpp - LINK_LIBS PRIVATE - LLVMSupport - MLIRPythonCAPI -) -add_dependencies(MLIRBindingsPythonExtension MLIRAsyncPassesBindingsPythonExtension) - -add_mlir_python_extension(MLIRSparseTensorPassesBindingsPythonExtension _mlirSparseTensorPasses - INSTALL_DIR - python - SOURCES - SparseTensorPasses.cpp - LINK_LIBS PRIVATE - LLVMSupport - MLIRPythonCAPI -) -add_dependencies(MLIRBindingsPythonExtension MLIRSparseTensorPassesBindingsPythonExtension) - -add_mlir_python_extension(MLIRGPUPassesBindingsPythonExtension _mlirGPUPasses - INSTALL_DIR - python - SOURCES - GPUPasses.cpp - LINK_LIBS PRIVATE - LLVMSupport - MLIRPythonCAPI -) -add_dependencies(MLIRBindingsPythonExtension MLIRGPUPassesBindingsPythonExtension) - -add_mlir_python_extension(MLIRLinalgPassesBindingsPythonExtension _mlirLinalgPasses - INSTALL_DIR - python - SOURCES - LinalgPasses.cpp - LINK_LIBS PRIVATE - LLVMSupport - MLIRPythonCAPI -) -add_dependencies(MLIRBindingsPythonExtension MLIRLinalgPassesBindingsPythonExtension) diff --git a/mlir/lib/Bindings/Python/Conversions/CMakeLists.txt b/mlir/lib/Bindings/Python/Conversions/CMakeLists.txt deleted file mode 100644 index e39707d0c..000000000 --- a/mlir/lib/Bindings/Python/Conversions/CMakeLists.txt +++ /dev/null @@ -1,12 +0,0 @@ -################################################################################ -# Build python extension -################################################################################ - -add_mlir_python_extension(MLIRConversionsBindingsPythonExtension _mlirConversions - INSTALL_DIR - python - SOURCES - Conversions.cpp - LINK_LIBS PRIVATE - MLIRPythonCAPI -) diff 
--git a/mlir/lib/Bindings/Python/Transforms/CMakeLists.txt b/mlir/lib/Bindings/Python/Transforms/CMakeLists.txt deleted file mode 100644 index b33d1503b..000000000 --- a/mlir/lib/Bindings/Python/Transforms/CMakeLists.txt +++ /dev/null @@ -1,12 +0,0 @@ -################################################################################ -# Build python extension -################################################################################ - -add_mlir_python_extension(MLIRTransformsBindingsPythonExtension _mlirTransforms - INSTALL_DIR - python - SOURCES - Transforms.cpp - LINK_LIBS PRIVATE - MLIRPythonCAPI -) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 1b9705a03..15b181389 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -1,49 +1,310 @@ +include(AddMLIRPython) + ################################################################################ -# Copy python source tree. +# Structural groupings. ################################################################################ -file(GLOB_RECURSE PY_SRC_FILES - RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" - "${CMAKE_CURRENT_SOURCE_DIR}/mlir/*.py") +declare_mlir_python_sources(MLIRPythonSources) +declare_mlir_python_sources(MLIRPythonSources.Dialects + ADD_TO_PARENT MLIRPythonSources) + +declare_mlir_python_sources(MLIRPythonTestSources) +declare_mlir_python_sources(MLIRPythonTestSources.Dialects + ADD_TO_PARENT MLIRPythonTestSources) -add_custom_target(MLIRBindingsPythonSources ALL - DEPENDS - ${PY_SRC_FILES} +################################################################################ +# Pure python sources and generated code +################################################################################ + +declare_mlir_python_sources(MLIRPythonSources.Core + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + ADD_TO_PARENT MLIRPythonSources + SOURCES + _cext_loader.py + _dlloader.py + _mlir_libs/__init__.py + ir.py + passmanager.py + dialects/_ods_common.py ) 
-foreach(PY_SRC_FILE ${PY_SRC_FILES}) - set(PY_DEST_FILE "${PROJECT_BINARY_DIR}/python/${PY_SRC_FILE}") - get_filename_component(PY_DEST_DIR "${PY_DEST_FILE}" DIRECTORY) - file(MAKE_DIRECTORY "${PY_DEST_DIR}") - add_custom_command( - TARGET MLIRBindingsPythonSources PRE_BUILD - COMMENT "Copying python source ${PY_SRC_FILE} -> ${PY_DEST_FILE}" - DEPENDS "${PY_SRC_FILE}" - BYPRODUCTS "${PY_DEST_FILE}" - COMMAND "${CMAKE_COMMAND}" -E create_symlink - "${CMAKE_CURRENT_SOURCE_DIR}/${PY_SRC_FILE}" "${PY_DEST_FILE}" - ) -endforeach() +declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + ADD_TO_PARENT MLIRPythonSources + SOURCES + execution_engine.py + SOURCES_GLOB + runtime/*.py +) + +declare_mlir_python_sources(MLIRPythonSources.Passes + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + ADD_TO_PARENT MLIRPythonSources + SOURCES_GLOB + all_passes_registration/*.py + conversions/*.py + transforms/*.py +) + +################################################################################ +# Dialect bindings +################################################################################ + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/AsyncOps.td + SOURCES_GLOB dialects/async_dialect/*.py + DIALECT_NAME async_dialect) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/BuiltinOps.td + SOURCES + dialects/builtin.py + dialects/_builtin_ops_ext.py + DIALECT_NAME builtin) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/GPUOps.td + SOURCES_GLOB dialects/gpu/*.py + DIALECT_NAME gpu) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE 
dialects/LinalgOps.td + SOURCES + dialects/_linalg_ops_ext.py + SOURCES_GLOB + dialects/linalg/*.py + DIALECT_NAME linalg + DEPENDS LinalgOdsGen) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/MathOps.td + SOURCES dialects/math.py + DIALECT_NAME math) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/MemRefOps.td + SOURCES dialects/memref.py + DIALECT_NAME memref) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonTestSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/PythonTest.td + SOURCES dialects/python_test.py + DIALECT_NAME python_test) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/ShapeOps.td + SOURCES dialects/shape.py + DIALECT_NAME shape) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + SOURCES dialects/sparse_tensor.py + DIALECT_NAME sparse_tensor) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/StandardOps.td + SOURCES dialects/std.py + DIALECT_NAME std) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/TensorOps.td + SOURCES dialects/tensor.py + DIALECT_NAME tensor) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/TosaOps.td + SOURCES dialects/tosa.py + DIALECT_NAME tosa) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/VectorOps.td + 
SOURCES dialects/vector.py + DIALECT_NAME vector) + +################################################################################ +# Python extensions. +# The sources for these are all in lib/Bindings/Python, but since they have to +# be rebuilt for each package and integrate with the source setup here, we +# just reference them here instead of having ordered, cross package target +# dependencies. +################################################################################ + +set(PYTHON_SOURCE_DIR "${MLIR_SOURCE_DIR}/lib/Bindings/Python") +declare_mlir_python_extension(MLIRPythonExtension.Core + MODULE_NAME _mlir + ADD_TO_PARENT MLIRPythonSources.Core + SOURCES + ${PYTHON_SOURCE_DIR}/DialectLinalg.cpp # TODO: Break this out. + ${PYTHON_SOURCE_DIR}/DialectSparseTensor.cpp # TODO: Break this out. + ${PYTHON_SOURCE_DIR}/MainModule.cpp + ${PYTHON_SOURCE_DIR}/IRAffine.cpp + ${PYTHON_SOURCE_DIR}/IRAttributes.cpp + ${PYTHON_SOURCE_DIR}/IRCore.cpp + ${PYTHON_SOURCE_DIR}/IRModule.cpp + ${PYTHON_SOURCE_DIR}/IRTypes.cpp + ${PYTHON_SOURCE_DIR}/PybindUtils.cpp + ${PYTHON_SOURCE_DIR}/Pass.cpp + ${PYTHON_SOURCE_DIR}/ExecutionEngine.cpp # TODO: Break this out. + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIDebug + MLIRCAPIIR + MLIRCAPIRegistration # TODO: See about dis-aggregating + + # Dialects + MLIRCAPILinalg # TODO: Remove when above is removed. + MLIRCAPISparseTensor # TODO: Remove when above is removed. + MLIRCAPIStandard + + # Execution engine (remove once disaggregated). 
+ MLIRCEXECUTIONENGINE +) + +declare_mlir_python_extension(MLIRPythonExtension.AllPassesRegistration + MODULE_NAME _mlirAllPassesRegistration + SOURCES + ${PYTHON_SOURCE_DIR}/AllPassesRegistration.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIConversion + MLIRCAPITransforms +) + +declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses + MODULE_NAME _mlirAsyncPasses + ADD_TO_PARENT MLIRPythonSources.Dialects.async_dialect + SOURCES + ${PYTHON_SOURCE_DIR}/AsyncPasses.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIAsync +) + +declare_mlir_python_extension(MLIRPythonExtension.Conversions + MODULE_NAME _mlirConversions + ADD_TO_PARENT MLIRPythonSources.Passes + SOURCES + ${PYTHON_SOURCE_DIR}/Conversions/Conversions.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIConversion +) + +declare_mlir_python_extension(MLIRPythonExtension.GPUDialectPasses + MODULE_NAME _mlirGPUPasses + ADD_TO_PARENT MLIRPythonSources.Dialects.gpu + SOURCES + ${PYTHON_SOURCE_DIR}/GPUPasses.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIGPU +) + +declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses + MODULE_NAME _mlirLinalgPasses + ADD_TO_PARENT MLIRPythonSources.Dialects.linalg + SOURCES + ${PYTHON_SOURCE_DIR}/LinalgPasses.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPILinalg +) -# Note that we copy from the source tree just like for headers because -# it will not be polluted with py_cache runtime artifacts (from testing and -# such). 
-install( - DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/mlir - DESTINATION python - COMPONENT MLIRBindingsPythonSources - FILES_MATCHING PATTERN "*.py" +declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses + MODULE_NAME _mlirSparseTensorPasses + ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor + SOURCES + ${PYTHON_SOURCE_DIR}/SparseTensorPasses.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPISparseTensor ) -if (NOT LLVM_ENABLE_IDE) - add_llvm_install_targets( - install-MLIRBindingsPythonSources - DEPENDS MLIRBindingsPythonSources - COMPONENT MLIRBindingsPythonSources) -endif() +declare_mlir_python_extension(MLIRPythonExtension.Transforms + MODULE_NAME _mlirTransforms + ADD_TO_PARENT MLIRPythonSources.Passes + SOURCES + ${PYTHON_SOURCE_DIR}/Transforms/Transforms.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPITransforms +) ################################################################################ -# Generated sources. +# Common CAPI dependency DSO. +# All python extensions must link through one DSO which exports the CAPI, and +# this must have a globally unique name amongst all embeddors of the python +# library since it will effectively have global scope. +# +# The presence of this aggregate library is part of the long term plan, but its +# use needs to be made more flexible. +# +# TODO: Upgrade to the aggregate utility in https://reviews.llvm.org/D106419 +# once ready. ################################################################################ -add_subdirectory(mlir/dialects) +add_mlir_python_common_capi_library(MLIRPythonCAPI + INSTALL_COMPONENT MLIRPythonModules + INSTALL_DESTINATION python_packages/mlir_core/mlir/_mlir_libs + OUTPUT_DIRECTORY "${MLIR_BINARY_DIR}/python_packages/mlir_core/mlir/_mlir_libs" + RELATIVE_INSTALL_ROOT "../../../.." 
+ DECLARED_SOURCES + MLIRPythonSources + MLIRPythonExtension.AllPassesRegistration +) + +################################################################################ +# The fully assembled package of modules. +# This must come last. +################################################################################ + +add_mlir_python_modules(MLIRPythonModules + ROOT_PREFIX "${MLIR_BINARY_DIR}/python_packages/mlir_core/mlir" + INSTALL_PREFIX "python_packages/mlir_core/mlir" + DECLARED_SOURCES + MLIRPythonSources + MLIRPythonExtension.AllPassesRegistration + COMMON_CAPI_LINK_LIBS + MLIRPythonCAPI + ) + + +add_mlir_python_modules(MLIRPythonTestModules + ROOT_PREFIX "${MLIR_BINARY_DIR}/python_packages/mlir_test/mlir" + INSTALL_PREFIX "python_packages/mlir_test/mlir" + DECLARED_SOURCES + MLIRPythonTestSources + ) diff --git a/mlir/python/mlir/_cext_loader.py b/mlir/python/mlir/_cext_loader.py index c691fccb4..5f2de7f00 100644 --- a/mlir/python/mlir/_cext_loader.py +++ b/mlir/python/mlir/_cext_loader.py @@ -3,28 +3,27 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception """Common module for looking up and manipulating C-Extensions.""" -# Packaged installs have a top-level _mlir_libs package with symbols: -# load_extension(name): Loads a named extension module -# preload_dependency(public_name): Loads a shared-library/DLL into the -# namespace. TODO: Remove this in favor of a more robust mechanism. -# Conditionally switch based on whether we are in a package context. +# The normal layout is to have a nested _mlir_libs package that contains +# all native libraries and extensions. If that exists, use it, but also fallback +# to old behavior where extensions were at the top level as loose libraries. +# TODO: Remove the fallback once downstreams adapt. 
try: - import _mlir_libs + from ._mlir_libs import * + # TODO: Remove these aliases once everything migrates + _preload_dependency = preload_dependency + _load_extension = load_extension except ModuleNotFoundError: # Assume that we are in-tree. # The _dlloader takes care of platform specific setup before we try to # load a shared library. - from ._dlloader import preload_dependency as _preload_dependency + # TODO: Remove _dlloader once all consolidated on the _mlir_libs approach. + from ._dlloader import preload_dependency - def _load_extension(name): + def load_extension(name): import importlib return importlib.import_module(name) # i.e. '_mlir' at the top level -else: - # Packaged distribution. - _load_extension = _mlir_libs.load_extension - _preload_dependency = _mlir_libs.preload_dependency -_preload_dependency("MLIRPythonCAPI") +preload_dependency("MLIRPythonCAPI") # Expose the corresponding C-Extension module with a well-known name at this # top-level module. This allows relative imports like the following to @@ -32,7 +31,7 @@ def _load_extension(name): # from .._cext_loader import _cext # This reduces coupling, allowing embedding of the python sources into another # project that can just vary based on this top-level loader module. -_cext = _load_extension("_mlir") +_cext = load_extension("_mlir") def _reexport_cext(cext_module_name, target_module_name): diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py new file mode 100644 index 000000000..54ed5e8ef --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -0,0 +1,21 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import importlib +import os + +__all__ = [ + "load_extension", + "preload_dependency", +] + +_this_dir = os.path.dirname(__file__) + +def load_extension(name): + return importlib.import_module(f".{name}", __package__) + + +def preload_dependency(public_name): + # TODO: Implement this hook to pre-load DLLs with ctypes on Windows. + pass diff --git a/mlir/python/mlir/dialects/CMakeLists.txt b/mlir/python/mlir/dialects/CMakeLists.txt deleted file mode 100644 index 3c0434475..000000000 --- a/mlir/python/mlir/dialects/CMakeLists.txt +++ /dev/null @@ -1,86 +0,0 @@ -include(AddMLIRPython) - -################################################################################ -# Generate dialect-specific bindings. -################################################################################ - -add_mlir_dialect_python_bindings(MLIRBindingsPythonAsyncOps - TD_FILE AsyncOps.td - DIALECT_NAME async_dialect) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonAsyncOps) - -add_mlir_dialect_python_bindings(MLIRBindingsPythonBuiltinOps - TD_FILE BuiltinOps.td - DIALECT_NAME builtin) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonBuiltinOps) - -add_mlir_dialect_python_bindings(MLIRBindingsPythonGPUOps - TD_FILE GPUOps.td - DIALECT_NAME gpu) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonGPUOps) - -add_mlir_dialect_python_bindings(MLIRBindingsPythonLinalgOps - TD_FILE LinalgOps.td - DIALECT_NAME linalg - DEPENDS LinalgOdsGen) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonLinalgOps) - -add_mlir_dialect_python_bindings(MLIRBindingsPythonMathOps - TD_FILE MathOps.td - DIALECT_NAME math) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonMathOps) - -add_mlir_dialect_python_bindings(MLIRBindingsPythonMemRefOps - TD_FILE MemRefOps.td - DIALECT_NAME memref) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonMemRefOps) - 
-add_mlir_dialect_python_bindings(MLIRBindingsPythonShapeOps - TD_FILE ShapeOps.td - DIALECT_NAME shape) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonShapeOps) - -add_mlir_dialect_python_bindings(MLIRBindingsPythonStandardOps - TD_FILE StandardOps.td - DIALECT_NAME std) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonStandardOps) - -add_mlir_dialect_python_bindings(MLIRBindingsPythonTensorOps - TD_FILE TensorOps.td - DIALECT_NAME tensor) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonTensorOps) - -add_mlir_dialect_python_bindings(MLIRBindingsPythonTosaOps - TD_FILE TosaOps.td - DIALECT_NAME tosa) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonTosaOps) - -add_mlir_dialect_python_bindings(MLIRBindingsPythonVectorOps - TD_FILE VectorOps.td - DIALECT_NAME vector) -add_dependencies(MLIRBindingsPythonSources MLIRBindingsPythonVectorOps) - -################################################################################ -# Installation. -################################################################################ - -# Dialect sources are generated. Install separately. -# Note that __pycache__ directories may have been left by tests and other -# executions. And __init__.py is handled as a regular source file. -# TODO: Eliminate this glob install, instead adding INSTALL_COMPONENT to -# add_mlir_dialect_python_bindings and installing the precise file there. 
-install( - DIRECTORY ${PROJECT_BINARY_DIR}/python/mlir/dialects - DESTINATION python/mlir - COMPONENT MLIRBindingsPythonDialects - FILES_MATCHING PATTERN "_*_gen.py" - PATTERN "__pycache__" EXCLUDE - PATTERN "__init__.py" EXCLUDE -) - -if (NOT LLVM_ENABLE_IDE) - add_llvm_install_targets( - install-MLIRBindingsPythonDialects - DEPENDS MLIRBindingsPythonSources - COMPONENT MLIRBindingsPythonDialects) -endif() diff --git a/mlir/python/mlir/dialects/PythonTest.td b/mlir/python/mlir/dialects/PythonTest.td new file mode 100644 index 000000000..d3d49395a --- /dev/null +++ b/mlir/python/mlir/dialects/PythonTest.td @@ -0,0 +1,33 @@ +//===-- python_test_ops.td - Python test Op definitions ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_TEST_OPS +#define PYTHON_TEST_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/IR/OpBase.td" + +def Python_Test_Dialect : Dialect { + let name = "python_test"; + let cppNamespace = "PythonTest"; +} +class TestOp traits = []> + : Op; + +def AttributedOp : TestOp<"attributed_op"> { + let arguments = (ins I32Attr:$mandatory_i32, + OptionalAttr:$optional_i32, + UnitAttr:$unit); +} + +def PropertyOp : TestOp<"property_op"> { + let arguments = (ins I32Attr:$property, + I32:$idx); +} + +#endif // PYTHON_TEST_OPS diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py index 6598efe3e..99783d833 100644 --- a/mlir/python/mlir/dialects/_builtin_ops_ext.py +++ b/mlir/python/mlir/dialects/_builtin_ops_ext.py @@ -2,11 +2,14 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Optional, Sequence +try: + from typing import Optional, Sequence -import inspect + import inspect -from ..ir import * + from ..ir import * +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e class ModuleOp: diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py index bce4e08ae..656992cac 100644 --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -2,12 +2,16 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Optional, Sequence, Union -from ..ir import * -from ._ods_common import get_default_loc_context -# TODO: resolve name collision for Linalg functionality that is injected inside -# the _mlir.dialects.linalg directly via pybind. -from _mlir.dialects.linalg import fill_builtin_region +try: + from typing import Optional, Sequence, Union + from ..ir import * + from ._ods_common import get_default_loc_context + # TODO: resolve name collision for Linalg functionality that is injected inside + # the _mlir.dialects.linalg directly via pybind. + from .._cext_loader import _cext + fill_builtin_region = _cext.dialects.linalg.fill_builtin_region +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e def isa(cls: Type, ty: Type): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 3810df9df..4568298d1 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -10,7 +10,8 @@ from mlir.dialects import math # TODO: resolve name collision for Linalg functionality that is injected inside # the _mlir.dialects.linalg directly via pybind. 
-from _mlir.dialects.linalg import fill_builtin_region +from ....._cext_loader import _cext +fill_builtin_region = _cext.dialects.linalg.fill_builtin_region from .scalar_expr import * from .config import * From b3920a466f2120a556ad3eb88012c3eca40daa6f Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 28 Jul 2021 20:32:47 +0000 Subject: [PATCH 086/915] [mlir] Set the namespace of the BuiltinDialect to 'builtin' Historically the builtin dialect has had an empty namespace. This has unfortunately created a very awkward situation, where many utilities either have to special case the empty namespace, or just don't work at all right now. This revision adds a namespace to the builtin dialect, and starts to cleanup some of the utilities to no longer handle empty namespaces. For now, the assembly form of builtin operations does not require the `builtin.` prefix. (This should likely be re-evaluated though) Differential Revision: https://reviews.llvm.org/D105149 --- mlir/lib/Bindings/Python/IRCore.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index b5197d9aa..1f8155000 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -643,11 +643,8 @@ void PyThreadContextEntry::popLocation(PyLocation &location) { MlirDialect PyDialects::getDialectForKey(const std::string &key, bool attrError) { - // If the "std" dialect was asked for, substitute the empty namespace :( - static const std::string emptyKey; - const std::string *canonKey = key == "std" ? &emptyKey : &key; - MlirDialect dialect = mlirContextGetOrLoadDialect( - getContext()->get(), {canonKey->data(), canonKey->size()}); + MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), + {key.data(), key.size()}); if (mlirDialectIsNull(dialect)) { throw SetPyError(attrError ? 
PyExc_AttributeError : PyExc_IndexError, Twine("Dialect '") + key + "' not found"); From c7969037cfe3ab631a837633040283764e872ac4 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 28 Jul 2021 20:00:02 +0000 Subject: [PATCH 087/915] Break apart the MLIR ExecutionEngine from core python module. * For python projects that don't need JIT/ExecutionEngine, cuts the number of files to compile roughly in half (with similar reduction in end binary size). Differential Revision: https://reviews.llvm.org/D106992 --- mlir/lib/Bindings/Python/ExecutionEngine.h | 22 ------------------- ...onEngine.cpp => ExecutionEngineModule.cpp} | 14 ++++++------ mlir/lib/Bindings/Python/MainModule.cpp | 6 ----- mlir/python/CMakeLists.txt | 15 +++++++++---- mlir/python/mlir/execution_engine.py | 9 ++++++-- 5 files changed, 25 insertions(+), 41 deletions(-) delete mode 100644 mlir/lib/Bindings/Python/ExecutionEngine.h rename mlir/lib/Bindings/Python/{ExecutionEngine.cpp => ExecutionEngineModule.cpp} (92%) diff --git a/mlir/lib/Bindings/Python/ExecutionEngine.h b/mlir/lib/Bindings/Python/ExecutionEngine.h deleted file mode 100644 index cc61648b5..000000000 --- a/mlir/lib/Bindings/Python/ExecutionEngine.h +++ /dev/null @@ -1,22 +0,0 @@ -//===- ExecutionEngine.h - ExecutionEngine submodule of pybind module -----===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_BINDINGS_PYTHON_EXECUTIONENGINE_H -#define MLIR_BINDINGS_PYTHON_EXECUTIONENGINE_H - -#include "PybindUtils.h" - -namespace mlir { -namespace python { - -void populateExecutionEngineSubmodule(pybind11::module &m); - -} // namespace python -} // namespace mlir - -#endif // MLIR_BINDINGS_PYTHON_EXECUTIONENGINE_H diff --git a/mlir/lib/Bindings/Python/ExecutionEngine.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp similarity index 92% rename from mlir/lib/Bindings/Python/ExecutionEngine.cpp rename to mlir/lib/Bindings/Python/ExecutionEngineModule.cpp index 089c29507..510e3f8dd 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngine.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -1,4 +1,4 @@ -//===- ExecutionEngine.cpp - Python MLIR ExecutionEngine Bindings ---------===// +//===- ExecutionEngineModule.cpp - Python module for execution engine -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,11 +6,9 @@ // //===----------------------------------------------------------------------===// -#include "ExecutionEngine.h" - -#include "IRModule.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/ExecutionEngine.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" namespace py = pybind11; using namespace mlir; @@ -54,18 +52,20 @@ class PyExecutionEngine { } // anonymous namespace /// Create the `mlir.execution_engine` module here. 
-void mlir::python::populateExecutionEngineSubmodule(py::module &m) { +PYBIND11_MODULE(_mlirExecutionEngine, m) { + m.doc() = "MLIR Execution Engine"; + //---------------------------------------------------------------------------- // Mapping of the top-level PassManager //---------------------------------------------------------------------------- py::class_(m, "ExecutionEngine") - .def(py::init<>([](PyModule &module, int optLevel, + .def(py::init<>([](MlirModule module, int optLevel, const std::vector &sharedLibPaths) { llvm::SmallVector libPaths; for (const std::string &path : sharedLibPaths) libPaths.push_back({path.c_str(), path.length()}); MlirExecutionEngine executionEngine = mlirExecutionEngineCreate( - module.get(), optLevel, libPaths.size(), libPaths.data()); + module, optLevel, libPaths.size(), libPaths.data()); if (mlirExecutionEngineIsNull(executionEngine)) throw std::runtime_error( "Failure while creating the ExecutionEngine."); diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 6e861c2f2..073ac9037 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -11,7 +11,6 @@ #include "PybindUtils.h" #include "Dialects.h" -#include "ExecutionEngine.h" #include "Globals.h" #include "IRModule.h" #include "Pass.h" @@ -93,11 +92,6 @@ PYBIND11_MODULE(_mlir, m) { m.def_submodule("passmanager", "MLIR Pass Management Bindings"); populatePassManagerSubmodule(passModule); - // Define and populate ExecutionEngine submodule. - auto executionEngineModule = - m.def_submodule("execution_engine", "MLIR JIT Execution Engine"); - populateExecutionEngineSubmodule(executionEngineModule); - // Define and populate dialect submodules. 
auto dialectsModule = m.def_submodule("dialects"); auto linalgModule = dialectsModule.def_submodule("linalg"); diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 15b181389..f5b261e8a 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -169,7 +169,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Core ${PYTHON_SOURCE_DIR}/IRTypes.cpp ${PYTHON_SOURCE_DIR}/PybindUtils.cpp ${PYTHON_SOURCE_DIR}/Pass.cpp - ${PYTHON_SOURCE_DIR}/ExecutionEngine.cpp # TODO: Break this out. PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -181,9 +180,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Core MLIRCAPILinalg # TODO: Remove when above is removed. MLIRCAPISparseTensor # TODO: Remove when above is removed. MLIRCAPIStandard - - # Execution engine (remove once disaggregated). - MLIRCEXECUTIONENGINE ) declare_mlir_python_extension(MLIRPythonExtension.AllPassesRegistration @@ -219,6 +215,17 @@ declare_mlir_python_extension(MLIRPythonExtension.Conversions MLIRCAPIConversion ) +declare_mlir_python_extension(MLIRPythonExtension.ExecutionEngine + MODULE_NAME _mlirExecutionEngine + ADD_TO_PARENT MLIRPythonSources.ExecutionEngine + SOURCES + ${PYTHON_SOURCE_DIR}/ExecutionEngineModule.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCEXECUTIONENGINE +) + declare_mlir_python_extension(MLIRPythonExtension.GPUDialectPasses MODULE_NAME _mlirGPUPasses ADD_TO_PARENT MLIRPythonSources.Dialects.gpu diff --git a/mlir/python/mlir/execution_engine.py b/mlir/python/mlir/execution_engine.py index 39d9501d9..f3bcd0e0d 100644 --- a/mlir/python/mlir/execution_engine.py +++ b/mlir/python/mlir/execution_engine.py @@ -3,10 +3,15 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Simply a wrapper around the extension module of the same name. 
-from ._cext_loader import _cext +from ._cext_loader import load_extension +_execution_engine = load_extension("_mlirExecutionEngine") import ctypes -class ExecutionEngine(_cext.execution_engine.ExecutionEngine): +__all__ = [ + "ExecutionEngine", +] + +class ExecutionEngine(_execution_engine.ExecutionEngine): def lookup(self, name): """Lookup a function emitted with the `llvm.emit_c_interface` From 8c5f19f831189ade0345a6d1b6649e9507935422 Mon Sep 17 00:00:00 2001 From: "Ahmed S. Taei" Date: Wed, 28 Jul 2021 00:35:22 +0000 Subject: [PATCH 088/915] Rorder mmt4d iteration domain Move tile iterators to outer most dim Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D107003 --- mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 3aa5aadc7..e38bc64d4 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -52,7 +52,7 @@ def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads as: MxK tiles, each of shape M0xK0. """ - domain(D.m, D.m0, D.n, D.n0, D.k, D.k0) + domain(D.m, D.n, D.m0, D.n0, D.k, D.k0) implements(ContractionOpInterface) accum[D.m, D.n, D.m0, D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) From 387ee967a957ab0e91bc621f0d256480145cf717 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Thu, 29 Jul 2021 19:06:22 +0000 Subject: [PATCH 089/915] [MLIR][python] Export CAPI headers. * Adds source targets (not included in the full set that downstreams use by default) to bundle mlir-c/ headers into the mlir/_mlir_libs/include directory. * Adds a minimal entry point to get include and library directories. 
* Used by npcomp to export a full CAPI (which is then used by the Torch extension to link npcomp). Reviewed By: mikeurbach Differential Revision: https://reviews.llvm.org/D107090 --- mlir/python/CMakeLists.txt | 7 +++++++ mlir/python/mlir/_mlir_libs/__init__.py | 20 ++++++++++++++++++++ 2 files changed, 27 insertions(+) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index f5b261e8a..9f66aa9c2 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -46,6 +46,12 @@ declare_mlir_python_sources(MLIRPythonSources.Passes transforms/*.py ) +declare_mlir_python_sources(MLIRPythonCAPIHeaderSources + ROOT_DIR "${MLIR_SOURCE_DIR}/include" + SOURCES_GLOB "mlir-c/*.h" + DEST_PREFIX "_mlir_libs/include" +) + ################################################################################ # Dialect bindings ################################################################################ @@ -304,6 +310,7 @@ add_mlir_python_modules(MLIRPythonModules DECLARED_SOURCES MLIRPythonSources MLIRPythonExtension.AllPassesRegistration + MLIRPythonCAPIHeaderSources COMMON_CAPI_LINK_LIBS MLIRPythonCAPI ) diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index 54ed5e8ef..55139b2a8 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -2,6 +2,8 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from typing import Sequence + import importlib import os @@ -19,3 +21,21 @@ def load_extension(name): def preload_dependency(public_name): # TODO: Implement this hook to pre-load DLLs with ctypes on Windows. pass + + +def get_lib_dirs() -> Sequence[str]: + """Gets the lib directory for linking to shared libraries. + + On some platforms, the package may need to be built specially to export + development libraries. 
+ """ + return [_this_dir] + + +def get_include_dirs() -> Sequence[str]: + """Gets the include directory for compiling against exported C libraries. + + Depending on how the package was build, development C libraries may or may + not be present. + """ + return [os.path.join(_this_dir, "include")] From cc86ecf7979f3078bae3692402fe1000c8fd15e6 Mon Sep 17 00:00:00 2001 From: Ahmed Taei Date: Mon, 2 Aug 2021 13:26:03 -0700 Subject: [PATCH 090/915] Reorder mmt4d r.h.s operand layout Switch r.h.s operand layout (n1, k1, n0, k0) -> (n1, k1, k0, n0) which is more consistant with scalar-vector products vectorization and elementates operand transpose. Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D107307 --- mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index e38bc64d4..11fcf9033 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -39,7 +39,7 @@ def quantized_matmul( @linalg_structured_op def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), - rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0), + rhs=TensorDef(TV.RhsType, S.N, S.K, S.K0, S.N0), accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True)): """Performs a matrix-matrix-transpose multiplication of two 4D inputs. 
@@ -54,7 +54,7 @@ def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), """ domain(D.m, D.n, D.m0, D.n0, D.k, D.k0) implements(ContractionOpInterface) - accum[D.m, D.n, D.m0, D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) + accum[D.m, D.n, D.m0, D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(TV.AccumType, rhs[D.n, D.k, D.k0, D.n0]) @linalg_structured_op def batch_matmul( From d05794c52504d46069217e92f423be24bead38d4 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 6 Aug 2021 04:10:03 +0000 Subject: [PATCH 091/915] [mlir][python] Make a number of imports relative. Avoiding absolute imports allows the code to be relocatable (which is used for out of tree integrations). Differential Revision: https://reviews.llvm.org/D107617 --- mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py | 2 +- .../mlir/dialects/linalg/opdsl/lang/comprehension.py | 3 +-- mlir/python/mlir/dialects/linalg/opdsl/lang/config.py | 3 +-- mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py | 2 +- mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py | 8 ++++---- 5 files changed, 8 insertions(+), 10 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py index 6db3bcfcc..9c1bb3342 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py @@ -53,7 +53,7 @@ from typing import Callable, Dict, Optional, Tuple, Union -from mlir import ir as _ir +from ..... 
import ir as _ir __all__ = [ "AffineBuildState", diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 66d7510b6..f7bfa81c0 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -11,8 +11,7 @@ from typing import Any, Dict, List, Optional, Sequence, Set, Tuple from enum import Enum -from mlir import ir as _ir - +from ..... import ir as _ir from .affine import * from .scalar_expr import * from .types import * diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py index f6d5248ea..fec41decb 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -13,8 +13,7 @@ from typing import Dict, Optional -from mlir import ir as _ir - +from ..... import ir as _ir from .comprehension import * from .yaml_helper import * diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py index 1b42b5767..047bde245 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -9,7 +9,7 @@ import inspect import threading -from mlir import ir +from ..... import ir from .comprehension import * from .config import * from .emitter import * diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 4568298d1..ea2da7151 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -4,10 +4,10 @@ from typing import Dict, Sequence -from mlir.ir import * -from mlir.dialects import linalg -from mlir.dialects import std -from mlir.dialects import math +from .....ir import * +from .... import linalg +from .... import std +from .... 
import math # TODO: resolve name collision for Linalg functionality that is injected inside # the _mlir.dialects.linalg directly via pybind. from ....._cext_loader import _cext From d71b691febad3b92d2e233bed9df6c70e79a1918 Mon Sep 17 00:00:00 2001 From: natashaknk Date: Tue, 10 Aug 2021 14:29:15 -0700 Subject: [PATCH 092/915] [mlir][tosa] Add quantized and unquantized versions for tosa.depthwise_conv2d lowering Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D107855 --- .../linalg/opdsl/ops/core_named_ops.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 11fcf9033..fc92c196a 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -223,6 +223,43 @@ def conv_2d_nchw( ]) * cast(U, K[D.f, D.c, D.kh, D.kw]) +def depthwise_conv2D_nchw( #TODO: Fix name + I=TensorDef(T1, S.N, S.IH, S.IW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), + O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + """Performs depth-wise 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.n, D.oh, D.ow, D.kh, D.kw, D.ic, D.cm) + O[D.n, D.oh, D.ow, D.ic, D.cm] += cast( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, + D.ic]) * cast(U, K[D.kh, D.kw, D.ic, D.cm]) + + +def depthwise_conv2D_nchw_q( #TODO: Fix name + I=TensorDef(T1, S.N, S.IH, S.IW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + """Performs depth-wise 2-D convolution. 
+ + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.n, D.oh, D.ow, D.kh, D.kw, D.ic, D.cm) + O[D.n, D.oh, D.ow, D.ic, D.cm] += ( + (cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, + D.ic]) - cast(U, IZp)) * + (cast(U, K[D.kh, D.kw, D.ic, D.cm]) - cast(U, KZp))) + + @linalg_structured_op def pooling_nhwc_sum( I=TensorDef(T1, S.N, S.H, S.W, S.C), From c832f1f51920630bafd868c91e3b9b676537c00a Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 11 Aug 2021 11:05:08 -0700 Subject: [PATCH 093/915] [mlir][tosa] Migrate tosa to more efficient linalg.conv Existing linalg.conv2d is not well optimized for performance. Changed to a version that is more aligned for optimziation. Include the corresponding transposes to use this optimized version. This also splits the conv and depthwise conv into separate implementations to avoid overly complex lowerings. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D107504 --- .../linalg/opdsl/ops/core_named_ops.py | 79 +++++++++---------- 1 file changed, 36 insertions(+), 43 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index fc92c196a..b9faeeb83 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -144,49 +144,39 @@ def dot( implements(ContractionOpInterface) C[None] += cast(U, A[D.m]) * cast(U, B[D.m]) - @linalg_structured_op -def conv_2d_input_nhwc_filter_ohwi_poly( - I=TensorDef(T1, S.N, S.IH, S.IW, S.IC), - K=TensorDef(T2, S.OC, S.KH, S.KW, S.IC), - O=TensorDef(U, S.N, S.OH, S.OW, S.OC, output=True), +def conv_2d_nchw( + I=TensorDef(T1, S.N, S.C, S.IH, S.IW), + K=TensorDef(T2, S.F, S.C, S.KH, S.KW), + O=TensorDef(U, S.N, S.F, S.OH, S.OW, S.C, output=True), strides=AttributeDef(S.SH, S.SW), 
dilations=AttributeDef(S.DH, S.DW)): - """Performs a 2-D convolution. + """Performs 2-D convolution. Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ - domain(D.n, D.oh, D.ow, D.kh, D.kw, D.oc, D.ic) - O[D.n, D.oh, D.ow, D.oc] += cast( - U, I[D.n, - D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, - D.ic]) * cast(U, K[D.oc, D.kh, D.kw, D.ic]) + domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.f, D.oh, D.ow] += cast( + U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, + ]) * cast(U, K[D.f, D.c, D.kh, D.kw]) @linalg_structured_op -def conv_2d_input_nhwc_filter_ohwi_poly_q( - I=TensorDef(T1, S.N, S.IH, S.IW, S.IC), - K=TensorDef(T2, S.OC, S.KH, S.KW, S.IC), - IZp=ScalarDef(I32), - KZp=ScalarDef(I32), - O=TensorDef(U, S.N, S.OH, S.OW, S.OC, output=True), +def conv_2d_nhwc_hwcf( + I=TensorDef(T1, S.N, S.IH, S.IW, S.C), + K=TensorDef(T2, S.KH, S.KW, S.C, S.F), + O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), strides=AttributeDef(S.SH, S.SW), dilations=AttributeDef(S.DH, S.DW)): - """Performs a 2-D quantized convolution. + """Performs 2-D convolution. Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. Includes zero point - adjustment for quantization. + them to the same data type as the accumulator/output. 
""" - domain(D.n, D.oh, D.ow, D.kh, D.kw, D.oc, D.ic) - O[D.n, D.oh, D.ow, D.oc] += ((cast( - U, I[D.n, - D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, - D.ic]) - cast(U, IZp)) * - (cast(U, K[D.oc, D.kh, D.kw, D.ic]) - cast(U, KZp))) - + domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.f] += cast( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c + ]) * cast(U, K[D.kh, D.kw, D.c, D.f]) @linalg_structured_op def depthwise_conv_2d_input_nhwc_filter_hwc_poly( @@ -206,24 +196,27 @@ def depthwise_conv_2d_input_nhwc_filter_hwc_poly( D.c]) * cast(U, K[D.kh, D.kw, D.c]) @linalg_structured_op -def conv_2d_nchw( - I=TensorDef(T1, S.N, S.C, S.IH, S.IW), - K=TensorDef(T2, S.F, S.C, S.KH, S.KW), - O=TensorDef(U, S.N, S.F, S.OH, S.OW, S.C, output=True), +def conv_2d_nhwc_hwcf_q( + I=TensorDef(T1, S.N, S.IH, S.IW, S.C), + K=TensorDef(T2, S.KH, S.KW, S.C, S.F), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), strides=AttributeDef(S.SH, S.SW), dilations=AttributeDef(S.DH, S.DW)): - """Performs 2-D convolution. + """Performs 2-D convolution with zero point offsets. Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. + them to the same data type as the accumulator/output. This includes the zero + point offsets common to quantized operations. 
""" - domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.f, D.oh, D.ow] += cast( - U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - ]) * cast(U, K[D.f, D.c, D.kh, D.kw]) + domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.f] += (cast( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c + ]) - cast(U, IZp)) * (cast(U, K[D.kh, D.kw, D.c, D.f]) - cast(U, KZp)) - -def depthwise_conv2D_nchw( #TODO: Fix name +@linalg_structured_op +def depthwise_conv2D_nchw( I=TensorDef(T1, S.N, S.IH, S.IW, S.IC), K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), @@ -239,8 +232,8 @@ def depthwise_conv2D_nchw( #TODO: Fix name U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) * cast(U, K[D.kh, D.kw, D.ic, D.cm]) - -def depthwise_conv2D_nchw_q( #TODO: Fix name +@linalg_structured_op +def depthwise_conv2D_nchw_q( I=TensorDef(T1, S.N, S.IH, S.IW, S.IC), K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), IZp=ScalarDef(I32), From c1a329437ec2b6ff8df2f37c0174400228dc3706 Mon Sep 17 00:00:00 2001 From: natashaknk Date: Thu, 12 Aug 2021 15:37:34 -0700 Subject: [PATCH 094/915] [mlir][tosa] Fix depthwise_conv2D strides/dilation and name Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D107997 --- mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index b9faeeb83..b2bd4ce57 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -216,7 +216,7 @@ def conv_2d_nhwc_hwcf_q( ]) - cast(U, IZp)) * (cast(U, K[D.kh, D.kw, D.c, D.f]) - cast(U, KZp)) @linalg_structured_op -def depthwise_conv2D_nchw( +def depthwise_conv2D_nhwc( I=TensorDef(T1, S.N, S.IH, S.IW, S.IC), 
K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), @@ -233,7 +233,7 @@ def depthwise_conv2D_nchw( D.ic]) * cast(U, K[D.kh, D.kw, D.ic, D.cm]) @linalg_structured_op -def depthwise_conv2D_nchw_q( +def depthwise_conv2D_nhwc_q( I=TensorDef(T1, S.N, S.IH, S.IW, S.IC), K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), IZp=ScalarDef(I32), From ecbad4f7b8cf3b11e42a8982224e1f354d94e7b7 Mon Sep 17 00:00:00 2001 From: Robert Suderman Date: Mon, 16 Aug 2021 11:46:58 -0700 Subject: [PATCH 095/915] [mlir][linalg] Clear unused linalg tc operations These operations are not lowered to from any source dialect and are only used for redundant tests. Removing these named ops, along with their associated tests, will make migration to YAML operations much more convenient. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D107993 --- .../linalg/opdsl/ops/core_named_ops.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index b2bd4ce57..21ca35bf1 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -177,24 +177,6 @@ def conv_2d_nhwc_hwcf( O[D.n, D.oh, D.ow, D.f] += cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c ]) * cast(U, K[D.kh, D.kw, D.c, D.f]) - -@linalg_structured_op -def depthwise_conv_2d_input_nhwc_filter_hwc_poly( - I=TensorDef(T1, S.N, S.IH, S.IW, S.C), - K=TensorDef(T2, S.KH, S.KW, S.C), - O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=AttributeDef(S.SH, S.SW), - dilations=AttributeDef(S.DH, S.DW)): - """Performs depth-wise 2-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. 
- """ - domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.c] += cast( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.c]) * cast(U, K[D.kh, D.kw, D.c]) - @linalg_structured_op def conv_2d_nhwc_hwcf_q( I=TensorDef(T1, S.N, S.IH, S.IW, S.C), From e667b52d81b2ee02fcdfbbde1763101053395f94 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 16 Aug 2021 13:47:00 -0700 Subject: [PATCH 096/915] [mlir][tosa] Fixed depthwise conv parallel/reduction indices order Reduction axis should come after all parallel axis to work with vectorization. Reviewed By: NatashaKnk Differential Revision: https://reviews.llvm.org/D108005 --- mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 21ca35bf1..0590e6721 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -209,7 +209,7 @@ def depthwise_conv2D_nhwc( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ - domain(D.n, D.oh, D.ow, D.kh, D.kw, D.ic, D.cm) + domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) O[D.n, D.oh, D.ow, D.ic, D.cm] += cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) * cast(U, K[D.kh, D.kw, D.ic, D.cm]) @@ -228,7 +228,7 @@ def depthwise_conv2D_nhwc_q( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. 
""" - domain(D.n, D.oh, D.ow, D.kh, D.kw, D.ic, D.cm) + domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) O[D.n, D.oh, D.ow, D.ic, D.cm] += ( (cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) - cast(U, IZp)) * From 543326c650b628c006d2deb58d0d7464fcb3c851 Mon Sep 17 00:00:00 2001 From: John Demme Date: Mon, 16 Aug 2021 22:37:14 -0700 Subject: [PATCH 097/915] [MLIR] [Python] Allow 'operation.parent' to return 'None' This is more Pythonic and better matches the C++ and C APIs. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D108183 --- mlir/lib/Bindings/Python/IRCore.cpp | 16 ++++++++++------ mlir/lib/Bindings/Python/IRModule.h | 3 ++- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 1f8155000..3e927ceec 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -868,22 +868,23 @@ py::object PyOperationBase::getAsm(bool binary, return fileObject.attr("getvalue")(); } -PyOperationRef PyOperation::getParentOperation() { +llvm::Optional PyOperation::getParentOperation() { checkValid(); if (!isAttached()) throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); MlirOperation operation = mlirOperationGetParentOperation(get()); if (mlirOperationIsNull(operation)) - throw SetPyError(PyExc_ValueError, "Operation has no parent."); + return {}; return PyOperation::forOperation(getContext(), operation); } PyBlock PyOperation::getBlock() { checkValid(); - PyOperationRef parentOperation = getParentOperation(); + llvm::Optional parentOperation = getParentOperation(); MlirBlock block = mlirOperationGetBlock(get()); assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); - return PyBlock{std::move(parentOperation), block}; + assert(parentOperation && "Operation has no parent"); + return PyBlock{std::move(*parentOperation), block}; } py::object PyOperation::getCapsule() 
{ @@ -2121,8 +2122,11 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("loc") = py::none(), py::arg("ip") = py::none(), kOperationCreateDocstring) .def_property_readonly("parent", - [](PyOperation &self) { - return self.getParentOperation().getObject(); + [](PyOperation &self) -> py::object { + auto parent = self.getParentOperation(); + if (parent) + return parent->getObject(); + return py::none(); }) .def("erase", &PyOperation::erase) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 79c480e94..9d217c872 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -18,6 +18,7 @@ #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Optional.h" namespace mlir { namespace python { @@ -452,7 +453,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// Gets the parent operation or raises an exception if the operation has /// no parent. - PyOperationRef getParentOperation(); + llvm::Optional getParentOperation(); /// Gets a capsule wrapping the void* within the MlirOperation. pybind11::object getCapsule(); From a85f30513f33e3b03b5a61b4c0d46e0f16514bcb Mon Sep 17 00:00:00 2001 From: John Demme Date: Thu, 19 Aug 2021 00:02:09 -0700 Subject: [PATCH 098/915] [MLIR] [Python] Add `owner` to `mlir.ir.Block` Provides a way for python users to access the owning Operation from a Block. --- mlir/lib/Bindings/Python/IRCore.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 3e927ceec..d6305e7f4 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2200,6 +2200,12 @@ void mlir::python::populateIRCore(py::module &m) { // Mapping of PyBlock. 
//---------------------------------------------------------------------------- py::class_(m, "Block") + .def_property_readonly( + "owner", + [](PyBlock &self) { + return self.getParentOperation()->createOpView(); + }, + "Returns the owning operation of this block.") .def_property_readonly( "arguments", [](PyBlock &self) { From 253cc3a013f367d6d66130d996d44b6aa7f35c79 Mon Sep 17 00:00:00 2001 From: Denys Shabalin Date: Fri, 20 Aug 2021 12:35:09 +0200 Subject: [PATCH 099/915] [mlir][linalg] Fix __repr__ implementation in const from opdsl Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D108369 --- mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index f7bfa81c0..f54d2a585 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -405,7 +405,7 @@ def to_scalar_expression(self) -> ScalarExpression: return ScalarConst(self.value).expr() def __repr__(self): - return f"const({self.type_var}, {self.value})" + return f"const({self.value})" class index(TensorExpression): From c581866e686f8c10f0e00d81a87f1ec2997bfb2d Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 12 Aug 2021 16:20:56 -0700 Subject: [PATCH 100/915] [mlir][linalg] Finish refactor of TC ops to YAML Multiple operations were still defined as TC ops that had equivalent versions as YAML operations. Reducing to a single compilation path guarantees that frontends can lower to their equivalent operations without missing the optimized fastpath. Some operations are maintained purely for testing purposes (mainly conv{1,2,3}D as they are included as sole tests in the vectorizaiton transforms. 
Differential Revision: https://reviews.llvm.org/D108169 --- .../linalg/opdsl/ops/core_named_ops.py | 120 ++++++++++++++++-- 1 file changed, 109 insertions(+), 11 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 0590e6721..29b1397d1 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -145,21 +145,63 @@ def dot( C[None] += cast(U, A[D.m]) * cast(U, B[D.m]) @linalg_structured_op -def conv_2d_nchw( - I=TensorDef(T1, S.N, S.C, S.IH, S.IW), - K=TensorDef(T2, S.F, S.C, S.KH, S.KW), - O=TensorDef(U, S.N, S.F, S.OH, S.OW, S.C, output=True), - strides=AttributeDef(S.SH, S.SW), - dilations=AttributeDef(S.DH, S.DW)): - """Performs 2-D convolution. +def conv_1d( + I=TensorDef(T1, S.IW), + K=TensorDef(T2, S.KW), + O=TensorDef(U, S.OW, output=True)): + """Performs 1-D convolution with no channels. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.ow, D.kw) + O[D.ow] += cast( + U, I[D.ow + D.kw]) * cast(U, K[D.kw]) + +@linalg_structured_op +def conv_2d( + I=TensorDef(T1, S.IH, S.IW), + K=TensorDef(T2, S.KH, S.KW), + O=TensorDef(U, S.OH, S.OW, output=True)): + """Performs 2-D convolution with no channels. Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. 
""" - domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.f, D.oh, D.ow] += cast( - U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - ]) * cast(U, K[D.f, D.c, D.kh, D.kw]) + domain(D.oh, D.ow, D.kh, D.kw) + O[D.oh, D.ow] += cast( + U, I[D.oh + D.kh, D.ow + D.kw]) * cast(U, K[D.kh, D.kw]) + +@linalg_structured_op +def conv_3d( + I=TensorDef(T1, S.ID, S.IH, S.IW), + K=TensorDef(T2, S.KD, S.KH, S.KW), + O=TensorDef(U, S.OD, S.OH, S.OW, output=True)): + """Performs 3-D convolution with no channels. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw) + O[D.od, D.oh, D.ow] += cast( + U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]) * cast(U, K[D.kd, D.kh, D.kw]) + +@linalg_structured_op +def conv_1d_nwc_wcf( + I=TensorDef(T1, S.N, S.IW, S.C), + K=TensorDef(T2, S.KW, S.C, S.F), + O=TensorDef(U, S.N, S.OW, S.F, output=True), + strides=AttributeDef(S.SW), + dilations=AttributeDef(S.DW)): + """Performs 1-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + domain(D.n, D.ow, D.f, D.kw, D.c) + O[D.n, D.ow, D.f] += cast( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c + ]) * cast(U, K[D.kw, D.c, D.f]) @linalg_structured_op def conv_2d_nhwc_hwcf( @@ -177,6 +219,7 @@ def conv_2d_nhwc_hwcf( O[D.n, D.oh, D.ow, D.f] += cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c ]) * cast(U, K[D.kh, D.kw, D.c, D.f]) + @linalg_structured_op def conv_2d_nhwc_hwcf_q( I=TensorDef(T1, S.N, S.IH, S.IW, S.C), @@ -197,6 +240,61 @@ def conv_2d_nhwc_hwcf_q( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c ]) - cast(U, IZp)) * (cast(U, K[D.kh, D.kw, D.c, D.f]) - cast(U, KZp)) +@linalg_structured_op +def conv_3d_ndhwc_dhwcf( + I=TensorDef(T1, S.N, S.ID, S.IH, S.IW, S.C), + K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True), + strides=AttributeDef(S.SD, S.SH, S.SW), + dilations=AttributeDef(S.DD, S.DH, S.DW)): + """Performs 3-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) + O[D.n, D.od, D.oh, D.ow, D.f] += cast( + U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c + ]) * cast(U, K[D.kd, D.kh, D.kw, D.c, D.f]) + +@linalg_structured_op +def depthwise_conv2D_nhw( + I=TensorDef(T1, S.N, S.IH, S.IW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC), + O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + """Performs depth-wise 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most dpethwise convolutions. 
+ """ + domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.ic] += cast( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, + D.ic]) * cast(U, K[D.kh, D.kw, D.ic]) + +@linalg_structured_op +def depthwise_conv2D_nhw_q( + I=TensorDef(T1, S.N, S.IH, S.IW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + """Performs depth-wise 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.ic] += ( + (cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, + D.ic]) - cast(U, IZp)) * + (cast(U, K[D.kh, D.kw, D.ic]) - cast(U, KZp))) + @linalg_structured_op def depthwise_conv2D_nhwc( I=TensorDef(T1, S.N, S.IH, S.IW, S.IC), From 9c19afaf5f9c241a5fc8a3c4323567aab457598a Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sun, 22 Aug 2021 13:43:55 -0700 Subject: [PATCH 101/915] [mlir][python] Makes C++ extension code relocatable by way of a macro. * Resolves a TODO by making this configurable by downstreams. * This seems to be the last thing allowing full use of the Python bindings as a library within another project (i.e. be embedding them). 
Differential Revision: https://reviews.llvm.org/D108523 --- mlir/include/mlir-c/Bindings/Python/Interop.h | 49 ++++++++++++++----- .../mlir/Bindings/Python/PybindAdaptors.h | 30 +++++------- 2 files changed, 50 insertions(+), 29 deletions(-) diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h index 882f73d84..7fcfd028b 100644 --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -30,19 +30,44 @@ #include "mlir-c/IntegerSet.h" #include "mlir-c/Pass.h" -#define MLIR_PYTHON_CAPSULE_AFFINE_EXPR "mlir.ir.AffineExpr._CAPIPtr" -#define MLIR_PYTHON_CAPSULE_AFFINE_MAP "mlir.ir.AffineMap._CAPIPtr" -#define MLIR_PYTHON_CAPSULE_ATTRIBUTE "mlir.ir.Attribute._CAPIPtr" -#define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr" +// The 'mlir' Python package is relocatable and supports co-existing in multiple +// projects. Each project must define its outer package prefix with this define +// in order to provide proper isolation and local name resolution. +// The default is for the upstream "import mlir" package layout. +// Note that this prefix is internally stringified, allowing it to be passed +// unquoted on the compiler command line without shell quote escaping issues. +#ifndef MLIR_PYTHON_PACKAGE_PREFIX +#define MLIR_PYTHON_PACKAGE_PREFIX mlir. +#endif + +// Makes a fully-qualified name relative to the MLIR python package. 
+#define MLIR_PYTHON_STRINGIZE(s) #s +#define MLIR_PYTHON_STRINGIZE_ARG(arg) MLIR_PYTHON_STRINGIZE(arg) +#define MAKE_MLIR_PYTHON_QUALNAME(local) \ + MLIR_PYTHON_STRINGIZE_ARG(MLIR_PYTHON_PACKAGE_PREFIX) local + +#define MLIR_PYTHON_CAPSULE_AFFINE_EXPR \ + MAKE_MLIR_PYTHON_QUALNAME("ir.AffineExpr._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_AFFINE_MAP \ + MAKE_MLIR_PYTHON_QUALNAME("ir.AffineMap._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_ATTRIBUTE \ + MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_CONTEXT \ + MAKE_MLIR_PYTHON_QUALNAME("ir.Context._CAPIPtr") #define MLIR_PYTHON_CAPSULE_EXECUTION_ENGINE \ - "mlir.execution_engine.ExecutionEngine._CAPIPtr" -#define MLIR_PYTHON_CAPSULE_INTEGER_SET "mlir.ir.IntegerSet._CAPIPtr" -#define MLIR_PYTHON_CAPSULE_LOCATION "mlir.ir.Location._CAPIPtr" -#define MLIR_PYTHON_CAPSULE_MODULE "mlir.ir.Module._CAPIPtr" -#define MLIR_PYTHON_CAPSULE_OPERATION "mlir.ir.Operation._CAPIPtr" -#define MLIR_PYTHON_CAPSULE_TYPE "mlir.ir.Type._CAPIPtr" -#define MLIR_PYTHON_CAPSULE_PASS_MANAGER "mlir.passmanager.PassManager._CAPIPtr" -#define MLIR_PYTHON_CAPSULE_VALUE "mlir.ir.Value._CAPIPtr" + MAKE_MLIR_PYTHON_QUALNAME("execution_engine.ExecutionEngine._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_INTEGER_SET \ + MAKE_MLIR_PYTHON_QUALNAME("ir.IntegerSet._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_LOCATION \ + MAKE_MLIR_PYTHON_QUALNAME("ir.Location._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_MODULE \ + MAKE_MLIR_PYTHON_QUALNAME("ir.Module._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_OPERATION \ + MAKE_MLIR_PYTHON_QUALNAME("ir.Operation._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_TYPE MAKE_MLIR_PYTHON_QUALNAME("ir.Type._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_PASS_MANAGER \ + MAKE_MLIR_PYTHON_QUALNAME("passmanager.PassManager._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_VALUE MAKE_MLIR_PYTHON_QUALNAME("ir.Value._CAPIPtr") /** Attribute on MLIR Python objects that expose their C-API pointer. 
* This will be a type-specific capsule created as per one of the helpers diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index db8769d3c..61b982193 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -30,10 +30,6 @@ namespace py = pybind11; -// TODO: Move this to Interop.h and make it externally configurable/use it -// consistently to locate the "import mlir" top-level. -#define MLIR_PYTHON_PACKAGE_PREFIX "mlir." - // Raw CAPI type casters need to be declared before use, so always include them // first. namespace pybind11 { @@ -76,7 +72,7 @@ struct type_caster { static handle cast(MlirAffineMap v, return_value_policy, handle) { py::object capsule = py::reinterpret_steal(mlirPythonAffineMapToCapsule(v)); - return py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir") + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr("AffineMap") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .release(); @@ -98,7 +94,7 @@ struct type_caster { static handle cast(MlirAttribute v, return_value_policy, handle) { py::object capsule = py::reinterpret_steal(mlirPythonAttributeToCapsule(v)); - return py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir") + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr("Attribute") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .release(); @@ -115,7 +111,7 @@ struct type_caster { // TODO: This raises an error of "No current context" currently. // Update the implementation to pretty-print the helpful error that the // core implementations print in this case. 
- src = py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir") + src = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr("Context") .attr("current"); } @@ -144,7 +140,7 @@ struct type_caster { static handle cast(MlirLocation v, return_value_policy, handle) { py::object capsule = py::reinterpret_steal(mlirPythonLocationToCapsule(v)); - return py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir") + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr("Location") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .release(); @@ -166,7 +162,7 @@ struct type_caster { static handle cast(MlirModule v, return_value_policy, handle) { py::object capsule = py::reinterpret_steal(mlirPythonModuleToCapsule(v)); - return py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir") + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr("Module") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .release(); @@ -190,7 +186,7 @@ struct type_caster { return py::none(); py::object capsule = py::reinterpret_steal(mlirPythonOperationToCapsule(v)); - return py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir") + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr("Operation") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .release(); @@ -226,7 +222,7 @@ struct type_caster { static handle cast(MlirType t, return_value_policy, handle) { py::object capsule = py::reinterpret_steal(mlirPythonTypeToCapsule(t)); - return py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir") + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr("Type") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) .release(); @@ -266,7 +262,7 @@ class pure_subclass { } template - pure_subclass &def(const char *name, Func &&f, const Extra &...extra) { + pure_subclass &def(const char *name, Func &&f, const Extra &... 
extra) { py::cpp_function cf( std::forward(f), py::name(name), py::is_method(py::none()), py::sibling(py::getattr(thisClass, name, py::none())), extra...); @@ -276,7 +272,7 @@ class pure_subclass { template pure_subclass &def_property_readonly(const char *name, Func &&f, - const Extra &...extra) { + const Extra &... extra) { py::cpp_function cf( std::forward(f), py::name(name), py::is_method(py::none()), py::sibling(py::getattr(thisClass, name, py::none())), extra...); @@ -288,7 +284,7 @@ class pure_subclass { template pure_subclass &def_staticmethod(const char *name, Func &&f, - const Extra &...extra) { + const Extra &... extra) { static_assert(!std::is_member_function_pointer::value, "def_staticmethod(...) called with a non-static member " "function pointer"); @@ -301,7 +297,7 @@ class pure_subclass { template pure_subclass &def_classmethod(const char *name, Func &&f, - const Extra &...extra) { + const Extra &... extra) { static_assert(!std::is_member_function_pointer::value, "def_classmethod(...) called with a non-static member " "function pointer"); @@ -329,7 +325,7 @@ class mlir_attribute_subclass : public pure_subclass { IsAFunctionTy isaFunction) : mlir_attribute_subclass( scope, attrClassName, isaFunction, - py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir") + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr("Attribute")) {} /// Subclasses with a provided mlir.ir.Attribute super-class. This must @@ -381,7 +377,7 @@ class mlir_type_subclass : public pure_subclass { IsAFunctionTy isaFunction) : mlir_type_subclass( scope, typeClassName, isaFunction, - py::module::import(MLIR_PYTHON_PACKAGE_PREFIX "ir").attr("Type")) {} + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type")) {} /// Subclasses with a provided mlir.ir.Type super-class. 
This must /// be used if the subclass is being defined in the same extension module From 7a9caaa3f34899e61dd7722243ed11f5e3f26c1a Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sun, 22 Aug 2021 16:54:10 -0700 Subject: [PATCH 102/915] [mlir][linalg] Add script to update the LinalgNamedStructuredOps.yaml. nfc Also adds banners to the files with update instructions. Differential Revision: https://reviews.llvm.org/D108529 --- mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py b/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py index 05c06e737..bacc0c302 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py @@ -81,7 +81,6 @@ def main(args): # Print. if args.format == "yaml": - print("# Auto-generated file. Do not edit!") print(yaml_dump_all(configs)) elif args.format == "repr": for config in configs: From ac8657573c4a0bc7a5907d2e2e2e585188379b19 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sun, 22 Aug 2021 15:11:42 -0700 Subject: [PATCH 103/915] [mlir] Add op for NCHW conv2d. * This is the native data layout for PyTorch and npcomp was using the prior version before cleanup. Differential Revision: https://reviews.llvm.org/D108527 --- .../linalg/opdsl/ops/core_named_ops.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 29b1397d1..38db29442 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -212,6 +212,10 @@ def conv_2d_nhwc_hwcf( dilations=AttributeDef(S.DH, S.DW)): """Performs 2-D convolution. + Layout: + * Input: NHWC. + * Kernel: HWCF. 
+ Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ @@ -231,6 +235,10 @@ def conv_2d_nhwc_hwcf_q( dilations=AttributeDef(S.DH, S.DW)): """Performs 2-D convolution with zero point offsets. + Layout: + * Input: NHWC. + * Kernel: HWCF. + Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. This includes the zero point offsets common to quantized operations. @@ -240,6 +248,27 @@ def conv_2d_nhwc_hwcf_q( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c ]) - cast(U, IZp)) * (cast(U, K[D.kh, D.kw, D.c, D.f]) - cast(U, KZp)) +@linalg_structured_op +def conv_2d_nchw_fchw( + I=TensorDef(T1, S.N, S.C, S.IH, S.IW), + K=TensorDef(T2, S.F, S.C, S.KH, S.KW), + O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + """Performs 2-D convolution. + + Layout: + * Input: NCHW. + * Kernel: FCHW. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.f, D.oh, D.ow] += cast( + U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW + ]) * cast(U, K[D.f, D.c, D.kh, D.kw]) + @linalg_structured_op def conv_3d_ndhwc_dhwcf( I=TensorDef(T1, S.N, S.ID, S.IH, S.IW, S.C), From e923439bab86e2af9ecb589c3bc00af6d4fe9c49 Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Sun, 29 Aug 2021 14:22:24 -0700 Subject: [PATCH 104/915] [SymbolRefAttr] Revise SymbolRefAttr to hold a StringAttr. SymbolRefAttr is fundamentally a base string plus a sequence of nested references. Instead of storing the string data as a copies StringRef, store it as an already-uniqued StringAttr. 
This makes a lot of things simpler and more efficient because: 1) references to the symbol are already stored as StringAttr's: there is no need to copy the string data into MLIRContext multiple times. 2) This allows pointer comparisons instead of string comparisons (or redundant uniquing) within SymbolTable.cpp. 3) This allows SymbolTable to hold a DenseMap instead of a StringMap (which again copies the string data and slows lookup). This is a moderately invasive patch, so I kept a lot of compatibility APIs around. It would be nice to explore changing getName() to return a StringAttr for example (right now you have to use getNameAttr()), and eliminate things like the StringRef version of getSymbol. Differential Revision: https://reviews.llvm.org/D108899 --- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 3ec1b73c0..4ae54d4ca 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -212,15 +212,16 @@ MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol, refs.reserve(numReferences); for (intptr_t i = 0; i < numReferences; ++i) refs.push_back(unwrap(references[i]).cast()); - return wrap(SymbolRefAttr::get(unwrap(ctx), unwrap(symbol), refs)); + auto symbolAttr = StringAttr::get(unwrap(ctx), unwrap(symbol)); + return wrap(SymbolRefAttr::get(symbolAttr, refs)); } MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getRootReference()); + return wrap(unwrap(attr).cast().getRootReference().getValue()); } MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getLeafReference()); + return wrap(unwrap(attr).cast().getLeafReference().getValue()); } intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) { From ee3bd9c01a80125da94219edb8f23d1df4b84541 Mon 
Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sat, 28 Aug 2021 20:15:51 -0700 Subject: [PATCH 105/915] [mlir][python] Extend C/Python API to be usable for CFG construction. * It is pretty clear that no one has tried this yet since it was both incomplete and broken. * Fixes a symbol hiding issues keeping even the generic builder from constructing an operation with successors. * Adds ODS support for successors. * Adds CAPI `mlirBlockGetParentRegion`, `mlirRegionEqual` + tests (and missing test for `mlirBlockGetParentOperation`). * Adds Python property: `Block.region`. * Adds Python methods: `Block.create_before` and `Block.create_after`. * Adds Python property: `InsertionPoint.block`. * Adds new blocks.py test to verify a plausible CFG construction case. Differential Revision: https://reviews.llvm.org/D108898 --- mlir/include/mlir-c/IR.h | 7 +++++ mlir/lib/Bindings/Python/IRCore.cpp | 47 +++++++++++++++++++++++++++-- mlir/lib/CAPI/IR/IR.cpp | 8 +++++ 3 files changed, 60 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 6924fa88d..ebc3ada60 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -447,6 +447,10 @@ MLIR_CAPI_EXPORTED void mlirRegionDestroy(MlirRegion region); /// Checks whether a region is null. static inline bool mlirRegionIsNull(MlirRegion region) { return !region.ptr; } +/// Checks whether two region handles point to the same region. This does not +/// perform deep comparison. +MLIR_CAPI_EXPORTED bool mlirRegionEqual(MlirRegion region, MlirRegion other); + /// Gets the first block in the region. MLIR_CAPI_EXPORTED MlirBlock mlirRegionGetFirstBlock(MlirRegion region); @@ -496,6 +500,9 @@ MLIR_CAPI_EXPORTED bool mlirBlockEqual(MlirBlock block, MlirBlock other); /// Returns the closest surrounding operation that contains this block. MLIR_CAPI_EXPORTED MlirOperation mlirBlockGetParentOperation(MlirBlock); +/// Returns the region that contains this block. 
+MLIR_CAPI_EXPORTED MlirRegion mlirBlockGetParentRegion(MlirBlock block); + /// Returns the block immediately following the given block in its parent /// region. MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetNextInRegion(MlirBlock block); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index d6305e7f4..7add4eb7b 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -969,7 +969,6 @@ py::object PyOperation::create( } // Unpack/validate successors. if (successors) { - llvm::SmallVector mlirSuccessors; mlirSuccessors.reserve(successors->size()); for (auto *successor : *successors) { // TODO: Verify successor originate from the same context. @@ -2206,6 +2205,13 @@ void mlir::python::populateIRCore(py::module &m) { return self.getParentOperation()->createOpView(); }, "Returns the owning operation of this block.") + .def_property_readonly( + "region", + [](PyBlock &self) { + MlirRegion region = mlirBlockGetParentRegion(self.get()); + return PyRegion(self.getParentOperation(), region); + }, + "Returns the owning region of this block.") .def_property_readonly( "arguments", [](PyBlock &self) { @@ -2218,6 +2224,40 @@ void mlir::python::populateIRCore(py::module &m) { return PyOperationList(self.getParentOperation(), self.get()); }, "Returns a forward-optimized sequence of operations.") + .def( + "create_before", + [](PyBlock &self, py::args pyArgTypes) { + self.checkValid(); + llvm::SmallVector argTypes; + argTypes.reserve(pyArgTypes.size()); + for (auto &pyArg : pyArgTypes) { + argTypes.push_back(pyArg.cast()); + } + + MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); + MlirRegion region = mlirBlockGetParentRegion(self.get()); + mlirRegionInsertOwnedBlockBefore(region, self.get(), block); + return PyBlock(self.getParentOperation(), block); + }, + "Creates and returns a new Block before this block " + "(with given argument types).") + .def( + "create_after", + [](PyBlock &self, 
py::args pyArgTypes) { + self.checkValid(); + llvm::SmallVector argTypes; + argTypes.reserve(pyArgTypes.size()); + for (auto &pyArg : pyArgTypes) { + argTypes.push_back(pyArg.cast()); + } + + MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); + MlirRegion region = mlirBlockGetParentRegion(self.get()); + mlirRegionInsertOwnedBlockAfter(region, self.get(), block); + return PyBlock(self.getParentOperation(), block); + }, + "Creates and returns a new Block after this block " + "(with given argument types).") .def( "__iter__", [](PyBlock &self) { @@ -2270,7 +2310,10 @@ void mlir::python::populateIRCore(py::module &m) { .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, py::arg("block"), "Inserts before the block terminator.") .def("insert", &PyInsertionPoint::insert, py::arg("operation"), - "Inserts an operation."); + "Inserts an operation.") + .def_property_readonly( + "block", [](PyInsertionPoint &self) { return self.getBlock(); }, + "Returns the block that this InsertionPoint points to."); //---------------------------------------------------------------------------- // Mapping of PyAttribute. 
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 2721efde3..68037f0af 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -427,6 +427,10 @@ bool mlirOperationVerify(MlirOperation op) { MlirRegion mlirRegionCreate() { return wrap(new Region); } +bool mlirRegionEqual(MlirRegion region, MlirRegion other) { + return unwrap(region) == unwrap(other); +} + MlirBlock mlirRegionGetFirstBlock(MlirRegion region) { Region *cppRegion = unwrap(region); if (cppRegion->empty()) @@ -492,6 +496,10 @@ MlirOperation mlirBlockGetParentOperation(MlirBlock block) { return wrap(unwrap(block)->getParentOp()); } +MlirRegion mlirBlockGetParentRegion(MlirBlock block) { + return wrap(unwrap(block)->getParent()); +} + MlirBlock mlirBlockGetNextInRegion(MlirBlock block) { return wrap(unwrap(block)->getNextNode()); } From b706b8b795a6374f23a41fb23fc7d4abe835e8fe Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 23 Aug 2021 20:01:07 -0700 Subject: [PATCH 106/915] [mlir][python] Apply py::module_local() to all classes. * This allows multiple MLIR-API embedding downstreams to co-exist in the same process. * I believe this is the last thing needed to enable isolated embedding. 
Differential Revision: https://reviews.llvm.org/D108605 --- .../Bindings/Python/ExecutionEngineModule.cpp | 2 +- mlir/lib/Bindings/Python/IRAffine.cpp | 11 ++-- mlir/lib/Bindings/Python/IRAttributes.cpp | 3 +- mlir/lib/Bindings/Python/IRCore.cpp | 52 +++++++++---------- mlir/lib/Bindings/Python/MainModule.cpp | 2 +- mlir/lib/Bindings/Python/Pass.cpp | 2 +- 6 files changed, 37 insertions(+), 35 deletions(-) diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp index 510e3f8dd..765ff2826 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -58,7 +58,7 @@ PYBIND11_MODULE(_mlirExecutionEngine, m) { //---------------------------------------------------------------------------- // Mapping of the top-level PassManager //---------------------------------------------------------------------------- - py::class_(m, "ExecutionEngine") + py::class_(m, "ExecutionEngine", py::module_local()) .def(py::init<>([](MlirModule module, int optLevel, const std::vector &sharedLibPaths) { llvm::SmallVector libPaths; diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index 0a2a5666a..5314badba 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -97,7 +97,7 @@ class PyConcreteAffineExpr : public BaseTy { } static void bind(py::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName); + auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); cls.def(py::init()); DerivedTy::bindDerived(cls); } @@ -367,7 +367,8 @@ class PyIntegerSetConstraint { bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } static void bind(py::module &m) { - py::class_(m, "IntegerSetConstraint") + py::class_(m, "IntegerSetConstraint", + py::module_local()) .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr) .def_property_readonly("is_eq", 
&PyIntegerSetConstraint::isEq); } @@ -427,7 +428,7 @@ void mlir::python::populateIRAffine(py::module &m) { //---------------------------------------------------------------------------- // Mapping of PyAffineExpr and derived classes. //---------------------------------------------------------------------------- - py::class_(m, "AffineExpr") + py::class_(m, "AffineExpr", py::module_local()) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineExpr::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) @@ -515,7 +516,7 @@ void mlir::python::populateIRAffine(py::module &m) { //---------------------------------------------------------------------------- // Mapping of PyAffineMap. //---------------------------------------------------------------------------- - py::class_(m, "AffineMap") + py::class_(m, "AffineMap", py::module_local()) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineMap::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) @@ -686,7 +687,7 @@ void mlir::python::populateIRAffine(py::module &m) { //---------------------------------------------------------------------------- // Mapping of PyIntegerSet. 
//---------------------------------------------------------------------------- - py::class_(m, "IntegerSet") + py::class_(m, "IntegerSet", py::module_local()) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyIntegerSet::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 0af762d93..bb4b5f4f0 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -65,7 +65,8 @@ class PyArrayAttribute : public PyConcreteAttribute { } static void bind(py::module &m) { - py::class_(m, "ArrayAttributeIterator") + py::class_(m, "ArrayAttributeIterator", + py::module_local()) .def("__iter__", &PyArrayAttributeIterator::dunderIter) .def("__next__", &PyArrayAttributeIterator::dunderNext); } diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 7add4eb7b..8672b772e 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -163,7 +163,7 @@ struct PyGlobalDebugFlag { static void bind(py::module &m) { // Debug flags. 
- py::class_(m, "_GlobalDebug") + py::class_(m, "_GlobalDebug", py::module_local()) .def_property_static("flag", &PyGlobalDebugFlag::get, &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); } @@ -192,7 +192,7 @@ class PyRegionIterator { } static void bind(py::module &m) { - py::class_(m, "RegionIterator") + py::class_(m, "RegionIterator", py::module_local()) .def("__iter__", &PyRegionIterator::dunderIter) .def("__next__", &PyRegionIterator::dunderNext); } @@ -224,7 +224,7 @@ class PyRegionList { } static void bind(py::module &m) { - py::class_(m, "RegionSequence") + py::class_(m, "RegionSequence", py::module_local()) .def("__len__", &PyRegionList::dunderLen) .def("__getitem__", &PyRegionList::dunderGetItem); } @@ -252,7 +252,7 @@ class PyBlockIterator { } static void bind(py::module &m) { - py::class_(m, "BlockIterator") + py::class_(m, "BlockIterator", py::module_local()) .def("__iter__", &PyBlockIterator::dunderIter) .def("__next__", &PyBlockIterator::dunderNext); } @@ -317,7 +317,7 @@ class PyBlockList { } static void bind(py::module &m) { - py::class_(m, "BlockList") + py::class_(m, "BlockList", py::module_local()) .def("__getitem__", &PyBlockList::dunderGetItem) .def("__iter__", &PyBlockList::dunderIter) .def("__len__", &PyBlockList::dunderLen) @@ -349,7 +349,7 @@ class PyOperationIterator { } static void bind(py::module &m) { - py::class_(m, "OperationIterator") + py::class_(m, "OperationIterator", py::module_local()) .def("__iter__", &PyOperationIterator::dunderIter) .def("__next__", &PyOperationIterator::dunderNext); } @@ -405,7 +405,7 @@ class PyOperationList { } static void bind(py::module &m) { - py::class_(m, "OperationList") + py::class_(m, "OperationList", py::module_local()) .def("__getitem__", &PyOperationList::dunderGetItem) .def("__iter__", &PyOperationList::dunderIter) .def("__len__", &PyOperationList::dunderLen); @@ -1539,7 +1539,7 @@ class PyConcreteValue : public PyValue { /// Binds the Python module objects to functions of this class. 
static void bind(py::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName); + auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); cls.def(py::init(), py::keep_alive<0, 1>()); DerivedTy::bindDerived(cls); } @@ -1617,7 +1617,7 @@ class PyBlockArgumentList { /// Defines a Python class in the bindings. static void bind(py::module &m) { - py::class_(m, "BlockArgumentList") + py::class_(m, "BlockArgumentList", py::module_local()) .def("__len__", &PyBlockArgumentList::dunderLen) .def("__getitem__", &PyBlockArgumentList::dunderGetItem); } @@ -1764,7 +1764,7 @@ class PyOpAttributeMap { } static void bind(py::module &m) { - py::class_(m, "OpAttributeMap") + py::class_(m, "OpAttributeMap", py::module_local()) .def("__contains__", &PyOpAttributeMap::dunderContains) .def("__len__", &PyOpAttributeMap::dunderLen) .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) @@ -1787,7 +1787,7 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of MlirContext. 
//---------------------------------------------------------------------------- - py::class_(m, "Context") + py::class_(m, "Context", py::module_local()) .def(py::init<>(&PyMlirContext::createNewContextForInit)) .def_static("_get_live_count", &PyMlirContext::getLiveCount) .def("_get_context_again", @@ -1851,7 +1851,7 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of PyDialectDescriptor //---------------------------------------------------------------------------- - py::class_(m, "DialectDescriptor") + py::class_(m, "DialectDescriptor", py::module_local()) .def_property_readonly("namespace", [](PyDialectDescriptor &self) { MlirStringRef ns = @@ -1869,7 +1869,7 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of PyDialects //---------------------------------------------------------------------------- - py::class_(m, "Dialects") + py::class_(m, "Dialects", py::module_local()) .def("__getitem__", [=](PyDialects &self, std::string keyName) { MlirDialect dialect = @@ -1889,7 +1889,7 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of PyDialect //---------------------------------------------------------------------------- - py::class_(m, "Dialect") + py::class_(m, "Dialect", py::module_local()) .def(py::init(), "descriptor") .def_property_readonly( "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) @@ -1904,7 +1904,7 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of Location //---------------------------------------------------------------------------- - py::class_(m, "Location") + py::class_(m, "Location", py::module_local()) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, 
&PyLocation::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) .def("__enter__", &PyLocation::contextEnter) @@ -1956,7 +1956,7 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of Module //---------------------------------------------------------------------------- - py::class_(m, "Module") + py::class_(m, "Module", py::module_local()) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) .def_static( @@ -2025,7 +2025,7 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of Operation. //---------------------------------------------------------------------------- - py::class_(m, "_OperationBase") + py::class_(m, "_OperationBase", py::module_local()) .def("__eq__", [](PyOperationBase &self, PyOperationBase &other) { return &self.getOperation() == &other.getOperation(); @@ -2112,7 +2112,7 @@ void mlir::python::populateIRCore(py::module &m) { "Verify the operation and return true if it passes, false if it " "fails."); - py::class_(m, "Operation") + py::class_(m, "Operation", py::module_local()) .def_static("create", &PyOperation::create, py::arg("name"), py::arg("results") = py::none(), py::arg("operands") = py::none(), @@ -2149,7 +2149,7 @@ void mlir::python::populateIRCore(py::module &m) { .def_property_readonly("opview", &PyOperation::createOpView); auto opViewClass = - py::class_(m, "OpView") + py::class_(m, "OpView", py::module_local()) .def(py::init()) .def_property_readonly("operation", &PyOpView::getOperationObject) .def_property_readonly( @@ -2174,7 +2174,7 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of PyRegion. 
//---------------------------------------------------------------------------- - py::class_(m, "Region") + py::class_(m, "Region", py::module_local()) .def_property_readonly( "blocks", [](PyRegion &self) { @@ -2198,7 +2198,7 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of PyBlock. //---------------------------------------------------------------------------- - py::class_(m, "Block") + py::class_(m, "Block", py::module_local()) .def_property_readonly( "owner", [](PyBlock &self) { @@ -2288,7 +2288,7 @@ void mlir::python::populateIRCore(py::module &m) { // Mapping of PyInsertionPoint. //---------------------------------------------------------------------------- - py::class_(m, "InsertionPoint") + py::class_(m, "InsertionPoint", py::module_local()) .def(py::init(), py::arg("block"), "Inserts after the last operation but still inside the block.") .def("__enter__", &PyInsertionPoint::contextEnter) @@ -2318,7 +2318,7 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of PyAttribute. //---------------------------------------------------------------------------- - py::class_(m, "Attribute") + py::class_(m, "Attribute", py::module_local()) // Delegate to the PyAttribute copy constructor, which will also lifetime // extend the backing context which owns the MlirAttribute. 
.def(py::init(), py::arg("cast_from_type"), @@ -2389,7 +2389,7 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of PyNamedAttribute //---------------------------------------------------------------------------- - py::class_(m, "NamedAttribute") + py::class_(m, "NamedAttribute", py::module_local()) .def("__repr__", [](PyNamedAttribute &self) { PyPrintAccumulator printAccum; @@ -2425,7 +2425,7 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of PyType. //---------------------------------------------------------------------------- - py::class_(m, "Type") + py::class_(m, "Type", py::module_local()) // Delegate to the PyType copy constructor, which will also lifetime // extend the backing context which owns the MlirType. .def(py::init(), py::arg("cast_from_type"), @@ -2479,7 +2479,7 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of Value. 
//---------------------------------------------------------------------------- - py::class_(m, "Value") + py::class_(m, "Value", py::module_local()) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) .def_property_readonly( diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 073ac9037..cbade532e 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -26,7 +26,7 @@ using namespace mlir::python; PYBIND11_MODULE(_mlir, m) { m.doc() = "MLIR Python Native Extension"; - py::class_(m, "_Globals") + py::class_(m, "_Globals", py::module_local()) .def_property("dialect_search_modules", &PyGlobals::getDialectSearchPrefixes, &PyGlobals::setDialectSearchPrefixes) diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index f2433573b..6aa1c651c 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -55,7 +55,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { //---------------------------------------------------------------------------- // Mapping of the top-level PassManager //---------------------------------------------------------------------------- - py::class_(m, "PassManager") + py::class_(m, "PassManager", py::module_local()) .def(py::init<>([](DefaultingPyMlirContext context) { MlirPassManager passManager = mlirPassManagerCreate(context->get()); From c0d7a47b7ff9b0b10f5aceace0772cee561e73e0 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 1 Sep 2021 16:16:35 -0700 Subject: [PATCH 107/915] [mlir][capi] Add NameLoc Add method to get NameLoc. Treat null child location as unknown to avoid needing to create UnknownLoc in C API where child loc is not needed. 
Differential Revision: https://reviews.llvm.org/D108678 --- mlir/include/mlir-c/IR.h | 7 +++++++ mlir/lib/CAPI/IR/IR.cpp | 10 ++++++++++ 2 files changed, 17 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index ebc3ada60..d875e3807 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -162,6 +162,13 @@ MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet( MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller); +/// Creates a name location owned by the given context. Providing null location +/// for childLoc is allowed and if childLoc is null location, then the behavior +/// is the same as having unknown child location. +MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context, + MlirStringRef name, + MlirLocation childLoc); + /// Creates a location with unknown position owned by the given context. MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 68037f0af..bbadc351d 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/Location.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Types.h" #include "mlir/IR/Verifier.h" @@ -131,6 +132,15 @@ MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) { return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller)))); } +MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, + MlirLocation childLoc) { + if (mlirLocationIsNull(childLoc)) + return wrap( + Location(NameLoc::get(Identifier::get(unwrap(name), unwrap(context))))); + return wrap(Location(NameLoc::get( + Identifier::get(unwrap(name), unwrap(context)), unwrap(childLoc)))); +} + MlirLocation mlirLocationUnknownGet(MlirContext context) { return 
wrap(Location(UnknownLoc::get(unwrap(context)))); } From 59394e71d1ab826b14952c75fe9a19c271d06a61 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 3 Sep 2021 00:37:00 +0000 Subject: [PATCH 108/915] [mlir][python] Simplify python extension loading. * Now that packaging has stabilized, removes old mechanisms for loading extensions, preferring direct importing. * Removes _cext_loader.py, _dlloader.py as unnecessary. * Fixes the path where the CAPI dll is written on Windows. This enables that path of least resistance loading behavior to work with no further drama (see: https://bugs.python.org/issue36085). * With this patch, `ninja check-mlir` on Windows with Python bindings works for me, modulo some failures that are actually due to a couple of pre-existing Windows bugs. I think this is the first time the Windows Python bindings have worked upstream. * Downstream changes needed: * If downstreams are using the now removed `load_extension`, `reexport_cext`, etc, then those should be replaced with normal import statements as done in this patch. 
Reviewed By: jdd, aartbik Differential Revision: https://reviews.llvm.org/D108489 --- mlir/lib/Bindings/Python/IRModule.cpp | 5 ++ mlir/python/CMakeLists.txt | 2 - mlir/python/mlir/_cext_loader.py | 57 ------------------ mlir/python/mlir/_dlloader.py | 59 ------------------- mlir/python/mlir/_mlir_libs/__init__.py | 14 ----- .../mlir/all_passes_registration/__init__.py | 5 +- mlir/python/mlir/conversions/__init__.py | 3 +- mlir/python/mlir/dialects/_linalg_ops_ext.py | 36 +++++------ mlir/python/mlir/dialects/_ods_common.py | 5 +- .../dialects/async_dialect/passes/__init__.py | 3 +- .../mlir/dialects/gpu/passes/__init__.py | 3 +- .../dialects/linalg/opdsl/lang/emitter.py | 10 ++-- .../mlir/dialects/linalg/passes/__init__.py | 3 +- mlir/python/mlir/dialects/sparse_tensor.py | 10 +--- mlir/python/mlir/execution_engine.py | 3 +- mlir/python/mlir/ir.py | 7 +-- mlir/python/mlir/passmanager.py | 5 +- mlir/python/mlir/transforms/__init__.py | 3 +- 18 files changed, 39 insertions(+), 194 deletions(-) delete mode 100644 mlir/python/mlir/_cext_loader.py delete mode 100644 mlir/python/mlir/_dlloader.py diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index 08ce06da8..9f853eb92 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -12,6 +12,8 @@ #include +#include "mlir-c/Bindings/Python/Interop.h" + namespace py = pybind11; using namespace mlir; using namespace mlir::python; @@ -25,6 +27,9 @@ PyGlobals *PyGlobals::instance = nullptr; PyGlobals::PyGlobals() { assert(!instance && "PyGlobals already constructed"); instance = this; + // The default search path include {mlir.}dialects, where {mlir.} is the + // package prefix configured at compile time. 
+ dialectSearchPrefixes.push_back(MAKE_MLIR_PYTHON_QUALNAME("dialects")); } PyGlobals::~PyGlobals() { instance = nullptr; } diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 9f66aa9c2..506d8ead2 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -20,8 +20,6 @@ declare_mlir_python_sources(MLIRPythonSources.Core ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" ADD_TO_PARENT MLIRPythonSources SOURCES - _cext_loader.py - _dlloader.py _mlir_libs/__init__.py ir.py passmanager.py diff --git a/mlir/python/mlir/_cext_loader.py b/mlir/python/mlir/_cext_loader.py deleted file mode 100644 index 5f2de7f00..000000000 --- a/mlir/python/mlir/_cext_loader.py +++ /dev/null @@ -1,57 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""Common module for looking up and manipulating C-Extensions.""" - -# The normal layout is to have a nested _mlir_libs package that contains -# all native libraries and extensions. If that exists, use it, but also fallback -# to old behavior where extensions were at the top level as loose libraries. -# TODO: Remove the fallback once downstreams adapt. -try: - from ._mlir_libs import * - # TODO: Remove these aliases once everything migrates - _preload_dependency = preload_dependency - _load_extension = load_extension -except ModuleNotFoundError: - # Assume that we are in-tree. - # The _dlloader takes care of platform specific setup before we try to - # load a shared library. - # TODO: Remove _dlloader once all consolidated on the _mlir_libs approach. - from ._dlloader import preload_dependency - - def load_extension(name): - import importlib - return importlib.import_module(name) # i.e. '_mlir' at the top level - -preload_dependency("MLIRPythonCAPI") - -# Expose the corresponding C-Extension module with a well-known name at this -# top-level module. 
This allows relative imports like the following to -# function: -# from .._cext_loader import _cext -# This reduces coupling, allowing embedding of the python sources into another -# project that can just vary based on this top-level loader module. -_cext = load_extension("_mlir") - - -def _reexport_cext(cext_module_name, target_module_name): - """Re-exports a named sub-module of the C-Extension into another module. - - Typically: - from ._cext_loader import _reexport_cext - _reexport_cext("ir", __name__) - del _reexport_cext - """ - import sys - target_module = sys.modules[target_module_name] - submodule_names = cext_module_name.split(".") - source_module = _cext - for submodule_name in submodule_names: - source_module = getattr(source_module, submodule_name) - for attr_name in dir(source_module): - if not attr_name.startswith("__"): - setattr(target_module, attr_name, getattr(source_module, attr_name)) - - -# Add our 'dialects' parent module to the search path for implementations. -_cext.globals.append_dialect_search_prefix("mlir.dialects") diff --git a/mlir/python/mlir/_dlloader.py b/mlir/python/mlir/_dlloader.py deleted file mode 100644 index 454a7b7f1..000000000 --- a/mlir/python/mlir/_dlloader.py +++ /dev/null @@ -1,59 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -import os -import platform - -_is_windows = platform.system() == "Windows" -_this_directory = os.path.dirname(__file__) - -# The standard LLVM build/install tree for Windows is laid out as: -# bin/ -# MLIRPublicAPI.dll -# python/ -# _mlir.*.pyd (dll extension) -# mlir/ -# _dlloader.py (this file) -# First check the python/ directory level for DLLs co-located with the pyd -# file, and then fall back to searching the bin/ directory. -# TODO: This should be configurable at some point. 
-_dll_search_path = [ - os.path.join(_this_directory, ".."), - os.path.join(_this_directory, "..", "..", "bin"), -] - -# Stash loaded DLLs to keep them alive. -_loaded_dlls = [] - -def preload_dependency(public_name): - """Preloads a dylib by its soname or DLL name. - - On Windows and Linux, doing this prior to loading a dependency will populate - the library in the flat namespace so that a subsequent library that depend - on it will resolve to this preloaded version. - - On OSX, resolution is completely path based so this facility no-ops. On - Linux, as long as RPATHs are setup properly, resolution is path based but - this facility can still act as an escape hatch for relocatable distributions. - """ - if _is_windows: - _preload_dependency_windows(public_name) - - -def _preload_dependency_windows(public_name): - dll_basename = public_name + ".dll" - found_path = None - for search_dir in _dll_search_path: - candidate_path = os.path.join(search_dir, dll_basename) - if os.path.exists(candidate_path): - found_path = candidate_path - break - - if found_path is None: - raise RuntimeError( - f"Unable to find dependency DLL {dll_basename} in search " - f"path {_dll_search_path}") - - import ctypes - _loaded_dlls.append(ctypes.CDLL(found_path)) diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index 55139b2a8..4e2e5f453 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -4,24 +4,10 @@ from typing import Sequence -import importlib import os -__all__ = [ - "load_extension", - "preload_dependency", -] - _this_dir = os.path.dirname(__file__) -def load_extension(name): - return importlib.import_module(f".{name}", __package__) - - -def preload_dependency(public_name): - # TODO: Implement this hook to pre-load DLLs with ctypes on Windows. - pass - def get_lib_dirs() -> Sequence[str]: """Gets the lib directory for linking to shared libraries. 
diff --git a/mlir/python/mlir/all_passes_registration/__init__.py b/mlir/python/mlir/all_passes_registration/__init__.py index cf3367cfe..aca557ab9 100644 --- a/mlir/python/mlir/all_passes_registration/__init__.py +++ b/mlir/python/mlir/all_passes_registration/__init__.py @@ -2,7 +2,4 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from .._cext_loader import _load_extension - -_cextAllPasses = _load_extension("_mlirAllPassesRegistration") -del _load_extension +from .._mlir_libs import _mlirAllPassesRegistration as _cextAllPasses diff --git a/mlir/python/mlir/conversions/__init__.py b/mlir/python/mlir/conversions/__init__.py index 0989449a4..a6a9eb821 100644 --- a/mlir/python/mlir/conversions/__init__.py +++ b/mlir/python/mlir/conversions/__init__.py @@ -4,5 +4,4 @@ # Expose the corresponding C-Extension module with a well-known name at this # level. -from .._cext_loader import _load_extension -_cextConversions = _load_extension("_mlirConversions") +from .._mlir_libs import _mlirConversions as _cextConversions diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py index 656992cac..536096749 100644 --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -6,10 +6,7 @@ from typing import Optional, Sequence, Union from ..ir import * from ._ods_common import get_default_loc_context - # TODO: resolve name collision for Linalg functionality that is injected inside - # the _mlir.dialects.linalg directly via pybind. 
- from .._cext_loader import _cext - fill_builtin_region = _cext.dialects.linalg.fill_builtin_region + from .._mlir_libs._mlir.dialects.linalg import fill_builtin_region except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e @@ -29,12 +26,11 @@ def __init__(self, output: Value, value: Value, *, loc=None, ip=None): results = [] if isa(RankedTensorType, output.type): results = [output.type] - op = self.build_generic( - results=results, - operands=[value, output], - attributes=None, - loc=loc, - ip=ip) + op = self.build_generic(results=results, + operands=[value, output], + attributes=None, + loc=loc, + ip=ip) OpView.__init__(self, op) linalgDialect = Context.current.get_dialect_descriptor("linalg") fill_builtin_region(linalgDialect, self.operation) @@ -78,12 +74,11 @@ def __init__(self, attributes["static_sizes"] = ArrayAttr.get( [IntegerAttr.get(i64_type, s) for s in static_size_ints], context=context) - op = self.build_generic( - results=[result_type], - operands=operands, - attributes=attributes, - loc=loc, - ip=ip) + op = self.build_generic(results=[result_type], + operands=operands, + attributes=attributes, + loc=loc, + ip=ip) OpView.__init__(self, op) @@ -92,11 +87,10 @@ class StructuredOpMixin: def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None): super().__init__( - self.build_generic( - results=list(results), - operands=[list(inputs), list(outputs)], - loc=loc, - ip=ip)) + self.build_generic(results=list(results), + operands=[list(inputs), list(outputs)], + loc=loc, + ip=ip)) def select_opview_mixin(parent_opview_cls): diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index d03044088..2fbf3545f 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -2,8 +2,9 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Re-export the parent _cext so that every level of the API can get it locally. -from .._cext_loader import _cext +# Provide a convenient name for sub-packages to resolve the main C-extension +# with a relative import. +from .._mlir_libs import _mlir as _cext __all__ = [ "equally_sized_accessor", diff --git a/mlir/python/mlir/dialects/async_dialect/passes/__init__.py b/mlir/python/mlir/dialects/async_dialect/passes/__init__.py index 88a7b539c..851d56148 100644 --- a/mlir/python/mlir/dialects/async_dialect/passes/__init__.py +++ b/mlir/python/mlir/dialects/async_dialect/passes/__init__.py @@ -2,5 +2,4 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from ...._cext_loader import _load_extension -_cextAsyncPasses = _load_extension("_mlirAsyncPasses") +from ...._mlir_libs import _mlirAsyncPasses as _cextAsyncPasses diff --git a/mlir/python/mlir/dialects/gpu/passes/__init__.py b/mlir/python/mlir/dialects/gpu/passes/__init__.py index dd28e91a4..9b1ef076a 100644 --- a/mlir/python/mlir/dialects/gpu/passes/__init__.py +++ b/mlir/python/mlir/dialects/gpu/passes/__init__.py @@ -2,5 +2,4 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from ...._cext_loader import _load_extension -_cextGPUPasses = _load_extension("_mlirGPUPasses") +from ...._mlir_libs import _mlirGPUPasses as _cextGPUPasses diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index ea2da7151..b151a9ba9 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -5,13 +5,11 @@ from typing import Dict, Sequence from .....ir import * +from ....._mlir_libs._mlir.dialects.linalg import fill_builtin_region + from .... import linalg from .... import std from .... 
import math -# TODO: resolve name collision for Linalg functionality that is injected inside -# the _mlir.dialects.linalg directly via pybind. -from ....._cext_loader import _cext -fill_builtin_region = _cext.dialects.linalg.fill_builtin_region from .scalar_expr import * from .config import * @@ -216,8 +214,8 @@ def expression(self, expr: ScalarExpression) -> Value: value_attr = Attribute.parse(expr.scalar_const.value) return std.ConstantOp(value_attr.type, value_attr).result elif expr.scalar_index: - dim_attr = IntegerAttr.get( - IntegerType.get_signless(64), expr.scalar_index.dim) + dim_attr = IntegerAttr.get(IntegerType.get_signless(64), + expr.scalar_index.dim) return linalg.IndexOp(IndexType.get(), dim_attr).result elif expr.scalar_apply: try: diff --git a/mlir/python/mlir/dialects/linalg/passes/__init__.py b/mlir/python/mlir/dialects/linalg/passes/__init__.py index 6555ad69a..0920e8ef4 100644 --- a/mlir/python/mlir/dialects/linalg/passes/__init__.py +++ b/mlir/python/mlir/dialects/linalg/passes/__init__.py @@ -2,5 +2,4 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from ...._cext_loader import _load_extension -_cextLinalgPasses = _load_extension("_mlirLinalgPasses") +from ...._mlir_libs import _mlirLinalgPasses as _cextLinalgPasses diff --git a/mlir/python/mlir/dialects/sparse_tensor.py b/mlir/python/mlir/dialects/sparse_tensor.py index 59fd86021..4a89ef8ae 100644 --- a/mlir/python/mlir/dialects/sparse_tensor.py +++ b/mlir/python/mlir/dialects/sparse_tensor.py @@ -2,11 +2,5 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from .._cext_loader import _reexport_cext -from .._cext_loader import _load_extension - -_reexport_cext("dialects.sparse_tensor", __name__) -_cextSparseTensorPasses = _load_extension("_mlirSparseTensorPasses") - -del _reexport_cext -del _load_extension +from .._mlir_libs._mlir.dialects.sparse_tensor import * +from .._mlir_libs import _mlirSparseTensorPasses as _cextSparseTensorPasses diff --git a/mlir/python/mlir/execution_engine.py b/mlir/python/mlir/execution_engine.py index f3bcd0e0d..1c516ae5a 100644 --- a/mlir/python/mlir/execution_engine.py +++ b/mlir/python/mlir/execution_engine.py @@ -3,8 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Simply a wrapper around the extension module of the same name. -from ._cext_loader import load_extension -_execution_engine = load_extension("_mlirExecutionEngine") +from ._mlir_libs import _mlirExecutionEngine as _execution_engine import ctypes __all__ = [ diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 2b420511d..99e88ff74 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -2,8 +2,5 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Simply a wrapper around the extension module of the same name. -from ._cext_loader import _reexport_cext -_reexport_cext("ir", __name__) -del _reexport_cext - +from ._mlir_libs._mlir.ir import * +from ._mlir_libs._mlir.ir import _GlobalDebug diff --git a/mlir/python/mlir/passmanager.py b/mlir/python/mlir/passmanager.py index 6b267b76e..22e86b879 100644 --- a/mlir/python/mlir/passmanager.py +++ b/mlir/python/mlir/passmanager.py @@ -2,7 +2,4 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Simply a wrapper around the extension module of the same name. 
-from ._cext_loader import _reexport_cext -_reexport_cext("passmanager", __name__) -del _reexport_cext +from ._mlir_libs._mlir.passmanager import * diff --git a/mlir/python/mlir/transforms/__init__.py b/mlir/python/mlir/transforms/__init__.py index 2149933d0..71ea17d7f 100644 --- a/mlir/python/mlir/transforms/__init__.py +++ b/mlir/python/mlir/transforms/__init__.py @@ -4,5 +4,4 @@ # Expose the corresponding C-Extension module with a well-known name at this # level. -from .._cext_loader import _load_extension -_cextTransforms = _load_extension("_mlirTransforms") +from .._mlir_libs import _mlirTransforms as _cextTransforms From 3fc9f260325d40b4d6de0e590a6078fc2ecbd5e4 Mon Sep 17 00:00:00 2001 From: Benoit Jacob Date: Mon, 13 Sep 2021 12:08:54 -0700 Subject: [PATCH 109/915] Reorder mmt4d shapes: * Revert https://reviews.llvm.org/D107307 so that both LHS and RHS have the same layout with K0 as the innermost dimension. * Continuing from https://reviews.llvm.org/D107003, move also 'K' to the outer side, so that now the inter-tile dimensions as all outer, and the intra-tile dimensions are all inner. Reviewed By: asaadaldien Differential Revision: https://reviews.llvm.org/D109692 --- .../python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 38db29442..fc37a2e8f 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -39,7 +39,7 @@ def quantized_matmul( @linalg_structured_op def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), - rhs=TensorDef(TV.RhsType, S.N, S.K, S.K0, S.N0), + rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0), accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True)): """Performs a matrix-matrix-transpose multiplication of two 4D inputs. 
@@ -52,9 +52,9 @@ def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads as: MxK tiles, each of shape M0xK0. """ - domain(D.m, D.n, D.m0, D.n0, D.k, D.k0) + domain(D.m, D.n, D.k, D.m0, D.n0, D.k0) implements(ContractionOpInterface) - accum[D.m, D.n, D.m0, D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(TV.AccumType, rhs[D.n, D.k, D.k0, D.n0]) + accum[D.m, D.n, D.m0, D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) @linalg_structured_op def batch_matmul( From c055b941e25ffa82e390d14ed680ed47a3dec4cb Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Tue, 14 Sep 2021 21:55:54 +0000 Subject: [PATCH 110/915] [mlir] Apply py::module_local() to a few more classes. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D109776 --- mlir/lib/Bindings/Python/DialectSparseTensor.cpp | 2 +- mlir/lib/Bindings/Python/IRModule.h | 5 +++-- mlir/lib/Bindings/Python/PybindUtils.h | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index faf240e1a..6afd0815d 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -20,7 +20,7 @@ void mlir::python::populateDialectSparseTensorSubmodule( py::module m, const py::module &irModule) { auto attributeClass = irModule.attr("Attribute"); - py::enum_(m, "DimLevelType") + py::enum_(m, "DimLevelType", py::module_local()) .value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 9d217c872..702870487 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -678,7 +678,8 @@ 
class PyConcreteAttribute : public BaseTy { } static void bind(pybind11::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol()); + auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(), + pybind11::module_local()); cls.def(pybind11::init(), pybind11::keep_alive<0, 1>()); DerivedTy::bindDerived(cls); } @@ -741,7 +742,7 @@ class PyConcreteType : public BaseTy { } static void bind(pybind11::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName); + auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local()); cls.def(pybind11::init(), pybind11::keep_alive<0, 1>()); cls.def_static("isinstance", [](PyType &otherType) -> bool { return DerivedTy::isaFunction(otherType); diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h index 7a9b8ecb9..9fecc7ce3 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -262,7 +262,8 @@ class Sliceable { /// Binds the indexing and length methods in the Python class. static void bind(pybind11::module &m) { - auto clazz = pybind11::class_(m, Derived::pyClassName) + auto clazz = pybind11::class_(m, Derived::pyClassName, + pybind11::module_local()) .def("__len__", &Sliceable::dunderLen) .def("__getitem__", &Sliceable::dunderGetItem) .def("__getitem__", &Sliceable::dunderGetItemSlice); From 30da9c9080097790b7fe6d1f818de7de24bc985f Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Thu, 16 Sep 2021 06:01:38 +0000 Subject: [PATCH 111/915] [mlir][OpDSL] Update op definitions to make shapes more concise (NFC). Express the input shape definitions of convolution and pooling operations in terms of the output shapes, filter shapes, strides, and dilations. 
Reviewed By: shabalin, rsuderman, stellaraccident Differential Revision: https://reviews.llvm.org/D109815 --- .../linalg/opdsl/ops/core_named_ops.py | 42 ++++++++++--------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index fc37a2e8f..7a804b7a9 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -146,7 +146,7 @@ def dot( @linalg_structured_op def conv_1d( - I=TensorDef(T1, S.IW), + I=TensorDef(T1, S.OW + S.KW), K=TensorDef(T2, S.KW), O=TensorDef(U, S.OW, output=True)): """Performs 1-D convolution with no channels. @@ -160,7 +160,7 @@ def conv_1d( @linalg_structured_op def conv_2d( - I=TensorDef(T1, S.IH, S.IW), + I=TensorDef(T1, S.OH + S.KH, S.OW + S.KW), K=TensorDef(T2, S.KH, S.KW), O=TensorDef(U, S.OH, S.OW, output=True)): """Performs 2-D convolution with no channels. @@ -174,7 +174,7 @@ def conv_2d( @linalg_structured_op def conv_3d( - I=TensorDef(T1, S.ID, S.IH, S.IW), + I=TensorDef(T1, S.OD + S.KD, S.OH + S.KH, S.OW + S.KW), K=TensorDef(T2, S.KD, S.KH, S.KW), O=TensorDef(U, S.OD, S.OH, S.OW, output=True)): """Performs 3-D convolution with no channels. 
@@ -188,7 +188,7 @@ def conv_3d( @linalg_structured_op def conv_1d_nwc_wcf( - I=TensorDef(T1, S.N, S.IW, S.C), + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KW, S.C, S.F), O=TensorDef(U, S.N, S.OW, S.F, output=True), strides=AttributeDef(S.SW), @@ -205,7 +205,7 @@ def conv_1d_nwc_wcf( @linalg_structured_op def conv_2d_nhwc_hwcf( - I=TensorDef(T1, S.N, S.IH, S.IW, S.C), + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KH, S.KW, S.C, S.F), O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), strides=AttributeDef(S.SH, S.SW), @@ -226,7 +226,7 @@ def conv_2d_nhwc_hwcf( @linalg_structured_op def conv_2d_nhwc_hwcf_q( - I=TensorDef(T1, S.N, S.IH, S.IW, S.C), + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KH, S.KW, S.C, S.F), IZp=ScalarDef(I32), KZp=ScalarDef(I32), @@ -250,7 +250,7 @@ def conv_2d_nhwc_hwcf_q( @linalg_structured_op def conv_2d_nchw_fchw( - I=TensorDef(T1, S.N, S.C, S.IH, S.IW), + I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), K=TensorDef(T2, S.F, S.C, S.KH, S.KW), O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True), strides=AttributeDef(S.SH, S.SW), @@ -271,7 +271,8 @@ def conv_2d_nchw_fchw( @linalg_structured_op def conv_3d_ndhwc_dhwcf( - I=TensorDef(T1, S.N, S.ID, S.IH, S.IW, S.C), + I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F), O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True), strides=AttributeDef(S.SD, S.SH, S.SW), @@ -288,7 +289,7 @@ def conv_3d_ndhwc_dhwcf( @linalg_structured_op def depthwise_conv2D_nhw( - I=TensorDef(T1, S.N, S.IH, S.IW, S.IC), + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), K=TensorDef(T2, S.KH, S.KW, S.IC), O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True), strides=AttributeDef(S.SH, S.SW), @@ -306,7 +307,7 @@ def 
depthwise_conv2D_nhw( @linalg_structured_op def depthwise_conv2D_nhw_q( - I=TensorDef(T1, S.N, S.IH, S.IW, S.IC), + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), K=TensorDef(T2, S.KH, S.KW, S.IC), IZp=ScalarDef(I32), KZp=ScalarDef(I32), @@ -326,7 +327,7 @@ def depthwise_conv2D_nhw_q( @linalg_structured_op def depthwise_conv2D_nhwc( - I=TensorDef(T1, S.N, S.IH, S.IW, S.IC), + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), strides=AttributeDef(S.SH, S.SW), @@ -343,7 +344,7 @@ def depthwise_conv2D_nhwc( @linalg_structured_op def depthwise_conv2D_nhwc_q( - I=TensorDef(T1, S.N, S.IH, S.IW, S.IC), + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), IZp=ScalarDef(I32), KZp=ScalarDef(I32), @@ -364,7 +365,7 @@ def depthwise_conv2D_nhwc_q( @linalg_structured_op def pooling_nhwc_sum( - I=TensorDef(T1, S.N, S.H, S.W, S.C), + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), strides=AttributeDef(S.SH, S.SW), @@ -381,7 +382,7 @@ def pooling_nhwc_sum( @linalg_structured_op def pooling_nhwc_max( - I=TensorDef(T1, S.N, S.H, S.W, S.C), + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), strides=AttributeDef(S.SH, S.SW), @@ -398,7 +399,7 @@ def pooling_nhwc_max( @linalg_structured_op def pooling_nchw_max( - I=TensorDef(T1, S.N, S.C, S.H, S.W), + I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True), strides=AttributeDef(S.SH, S.SW), @@ -415,7 +416,7 
@@ def pooling_nchw_max( @linalg_structured_op def pooling_nhwc_min( - I=TensorDef(T1, S.N, S.H, S.W, S.C), + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), strides=AttributeDef(S.SH, S.SW), @@ -433,7 +434,8 @@ def pooling_nhwc_min( @linalg_structured_op def pooling_ndhwc_sum( - I=TensorDef(T1, S.N, S.D, S.H, S.W, S.C), + I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), strides=AttributeDef(S.SD, S.SH, S.SW), @@ -451,7 +453,8 @@ def pooling_ndhwc_sum( @linalg_structured_op def pooling_ndhwc_max( - I=TensorDef(T1, S.N, S.D, S.H, S.W, S.C), + I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), strides=AttributeDef(S.SD, S.SH, S.SW), @@ -470,7 +473,8 @@ def pooling_ndhwc_max( @linalg_structured_op def pooling_ndhwc_min( - I=TensorDef(T1, S.N, S.D, S.H, S.W, S.C), + I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), strides=AttributeDef(S.SD, S.SH, S.SW), From a89a49855f65aefe27109c5dc12088ffc4133a2e Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Sat, 18 Sep 2021 06:57:51 -0700 Subject: [PATCH 112/915] [mlir-c] Add getting fused loc For creating a fused loc using array of locations and metadata. 
Differential Revision: https://reviews.llvm.org/D110022 --- mlir/include/mlir-c/IR.h | 5 +++++ mlir/lib/CAPI/IR/IR.cpp | 8 ++++++++ 2 files changed, 13 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index d875e3807..28a83cba0 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -162,6 +162,11 @@ MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet( MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller); +/// Creates a fused location with an array of locations and metadata. +MLIR_CAPI_EXPORTED MlirLocation +mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, + MlirLocation const *locations, MlirAttribute metadata); + /// Creates a name location owned by the given context. Providing null location /// for childLoc is allowed and if childLoc is null location, then the behavior /// is the same as having unknown child location. diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index bbadc351d..eda176300 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -132,6 +132,14 @@ MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) { return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller)))); } +MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, + MlirLocation const *locations, + MlirAttribute metadata) { + SmallVector locs; + ArrayRef unwrappedLocs = unwrapList(nLocations, locations, locs); + return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx))); +} + MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, MlirLocation childLoc) { if (mlirLocationIsNull(childLoc)) From 6eb555a641277abb58c01afbd1b5d71b0e4990ab Mon Sep 17 00:00:00 2001 From: MaheshRavishankar Date: Mon, 20 Sep 2021 10:40:31 -0700 Subject: [PATCH 113/915] [mlir][Linalg] Add ConvolutionOpInterface. 
Add an interface that allows grouping together all covolution and pooling ops within Linalg named ops. The interface currently - the indexing map used for input/image access is valid - the filter and output are accessed using projected permutations - that all loops are charecterizable as one iterating over - batch dimension, - output image dimensions, - filter convolved dimensions, - output channel dimensions, - input channel dimensions, - depth multiplier (for depthwise convolutions) Differential Revision: https://reviews.llvm.org/D109793 --- .../linalg/opdsl/lang/comprehension.py | 2 +- .../linalg/opdsl/ops/core_named_ops.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index f54d2a585..c38940029 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -484,7 +484,7 @@ def __init__(self, cpp_name: str): ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface") - +ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface") class OpMetadataDef(YAMLObject): """Metadata about the op (generally not behavior impacting).""" diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 7a804b7a9..b78a21797 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -154,6 +154,7 @@ def conv_1d( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. 
""" + implements(ConvolutionOpInterface) domain(D.ow, D.kw) O[D.ow] += cast( U, I[D.ow + D.kw]) * cast(U, K[D.kw]) @@ -168,6 +169,7 @@ def conv_2d( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.oh, D.ow, D.kh, D.kw) O[D.oh, D.ow] += cast( U, I[D.oh + D.kh, D.ow + D.kw]) * cast(U, K[D.kh, D.kw]) @@ -182,6 +184,7 @@ def conv_3d( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw) O[D.od, D.oh, D.ow] += cast( U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]) * cast(U, K[D.kd, D.kh, D.kw]) @@ -198,6 +201,7 @@ def conv_1d_nwc_wcf( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.ow, D.f, D.kw, D.c) O[D.n, D.ow, D.f] += cast( U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c @@ -219,6 +223,7 @@ def conv_2d_nhwc_hwcf( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.f] += cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c @@ -243,6 +248,7 @@ def conv_2d_nhwc_hwcf_q( them to the same data type as the accumulator/output. This includes the zero point offsets common to quantized operations. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.f] += (cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c @@ -264,6 +270,7 @@ def conv_2d_nchw_fchw( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. 
""" + implements(ConvolutionOpInterface) domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) O[D.n, D.f, D.oh, D.ow] += cast( U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW @@ -282,6 +289,7 @@ def conv_3d_ndhwc_dhwcf( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) O[D.n, D.od, D.oh, D.ow, D.f] += cast( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c @@ -300,6 +308,7 @@ def depthwise_conv2D_nhw( them to the same data type as the accumulator/output. Multiplier is set to 1 which is a special case for most dpethwise convolutions. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) O[D.n, D.oh, D.ow, D.ic] += cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -319,6 +328,7 @@ def depthwise_conv2D_nhw_q( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) O[D.n, D.oh, D.ow, D.ic] += ( (cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -337,6 +347,7 @@ def depthwise_conv2D_nhwc( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) O[D.n, D.oh, D.ow, D.ic, D.cm] += cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -356,6 +367,7 @@ def depthwise_conv2D_nhwc_q( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. 
""" + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) O[D.n, D.oh, D.ow, D.ic, D.cm] += ( (cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -375,6 +387,7 @@ def pooling_nhwc_sum( Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] += cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) @@ -392,6 +405,7 @@ def pooling_nhwc_max( Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)( cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -409,6 +423,7 @@ def pooling_nchw_max( Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) O[D.n, D.c, D.oh, D.ow] = ReduceFn.max(D.kh, D.kw)( cast(U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -426,6 +441,7 @@ def pooling_nhwc_min( Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)( cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, @@ -445,6 +461,7 @@ def pooling_ndhwc_sum( Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. 
""" + implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) O[D.n, D.od, D.oh, D.ow, D.c] += cast( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, @@ -464,6 +481,7 @@ def pooling_ndhwc_max( Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max(D.kd, D.kh, D.kw)( cast( @@ -484,6 +502,7 @@ def pooling_ndhwc_min( Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ + implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min(D.kd, D.kh, D.kw)( cast( From 733da3032a9df8c83b0220474b9d4188e17ab414 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sat, 18 Sep 2021 21:12:15 -0700 Subject: [PATCH 114/915] [mlir][python] Forward _OperationBase _CAPIPtr to the Operation. * ODS generated operations extend _OperationBase and without this, cannot be marshalled to CAPI functions. * No test case updates: this kind of interop is quite hard to verify with in-tree tests. Differential Revision: https://reviews.llvm.org/D110030 --- mlir/lib/Bindings/Python/IRCore.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 8672b772e..7763f4671 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2026,6 +2026,10 @@ void mlir::python::populateIRCore(py::module &m) { // Mapping of Operation. 
//---------------------------------------------------------------------------- py::class_(m, "_OperationBase", py::module_local()) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + [](PyOperationBase &self) { + return self.getOperation().getCapsule(); + }) .def("__eq__", [](PyOperationBase &self, PyOperationBase &other) { return &self.getOperation() == &other.getOperation(); From 6fa381814414382c5239414ba786e46427d6a1be Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 21 Sep 2021 01:40:22 +0000 Subject: [PATCH 115/915] [mlir] Add value_begin/value_end methods to DenseElementsAttr Currently DenseElementsAttr only exposes the ability to get the full range of values for a given type T, but there are many situations where we just want the beginning/end iterator. This revision adds proper value_begin/value_end methods for all of the supported T types, and also cleans up a bit of the interface. Differential Revision: https://reviews.llvm.org/D104173 --- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 32 ++++++++------------------ 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 4ae54d4ca..a2ee06722 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -505,48 +505,36 @@ MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) { // Indexed accessors. 
bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { - return *(unwrap(attr).cast().getValues().begin() + - pos); + return unwrap(attr).cast().getFlatValue(pos); } int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) { - return *(unwrap(attr).cast().getValues().begin() + - pos); + return unwrap(attr).cast().getFlatValue(pos); } uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) { - return *(unwrap(attr).cast().getValues().begin() + - pos); + return unwrap(attr).cast().getFlatValue(pos); } int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) { - return *(unwrap(attr).cast().getValues().begin() + - pos); + return unwrap(attr).cast().getFlatValue(pos); } uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) { - return *( - unwrap(attr).cast().getValues().begin() + - pos); + return unwrap(attr).cast().getFlatValue(pos); } int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { - return *(unwrap(attr).cast().getValues().begin() + - pos); + return unwrap(attr).cast().getFlatValue(pos); } uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { - return *( - unwrap(attr).cast().getValues().begin() + - pos); + return unwrap(attr).cast().getFlatValue(pos); } float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { - return *(unwrap(attr).cast().getValues().begin() + - pos); + return unwrap(attr).cast().getFlatValue(pos); } double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) { - return *(unwrap(attr).cast().getValues().begin() + - pos); + return unwrap(attr).cast().getFlatValue(pos); } MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos) { return wrap( - *(unwrap(attr).cast().getValues().begin() + - pos)); + unwrap(attr).cast().getFlatValue(pos)); } //===----------------------------------------------------------------------===// From 
bd8ea5d4e3ed22d2f8628038e4bc14fb1172f9cc Mon Sep 17 00:00:00 2001 From: John Demme Date: Wed, 22 Sep 2021 19:50:22 -0700 Subject: [PATCH 116/915] [MLIR] [Python] Make Attribute and Type hashable Enables putting types and attributes in sets and in dicts as keys. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D110301 --- mlir/lib/Bindings/Python/IRCore.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 7763f4671..473c94c90 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2364,6 +2364,7 @@ void mlir::python::populateIRCore(py::module &m) { .def("__eq__", [](PyAttribute &self, PyAttribute &other) { return self == other; }) .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) + .def("__hash__", [](PyAttribute &self) { return (size_t)self.get().ptr; }) .def( "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, kDumpDocstring) @@ -2457,6 +2458,7 @@ void mlir::python::populateIRCore(py::module &m) { "Context that owns the Type") .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) .def("__eq__", [](PyType &self, py::object &other) { return false; }) + .def("__hash__", [](PyType &self) { return (size_t)self.get().ptr; }) .def( "dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) .def( From 1c2d8195c0926b75399c40f83fd140b775ebe331 Mon Sep 17 00:00:00 2001 From: Sean Silva Date: Tue, 28 Sep 2021 21:58:51 +0000 Subject: [PATCH 117/915] [mlir][Python] Fix lifetime of ExecutionEngine runtime functions. We weren't retaining the ctypes closures that the ExecutionEngine was calling back into, leading to mysterious errors. Open to feedback about how to test this. And an extra pair of eyes to make sure I caught all the places that need to be aware of this. 
Differential Revision: https://reviews.llvm.org/D110661 --- .../Bindings/Python/ExecutionEngineModule.cpp | 25 ++++++++++++++++--- mlir/python/mlir/execution_engine.py | 2 +- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp index 765ff2826..07c35163c 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -31,12 +31,21 @@ class PyExecutionEngine { } MlirExecutionEngine get() { return executionEngine; } - void release() { executionEngine.ptr = nullptr; } + void release() { + executionEngine.ptr = nullptr; + referencedObjects.clear(); + } pybind11::object getCapsule() { return py::reinterpret_steal( mlirPythonExecutionEngineToCapsule(get())); } + // Add an object to the list of referenced objects whose lifetime must exceed + // those of the ExecutionEngine. + void addReferencedObject(pybind11::object obj) { + referencedObjects.push_back(obj); + } + static pybind11::object createFromCapsule(pybind11::object capsule) { MlirExecutionEngine rawPm = mlirPythonCapsuleToExecutionEngine(capsule.ptr()); @@ -47,6 +56,10 @@ class PyExecutionEngine { private: MlirExecutionEngine executionEngine; + // We support Python ctypes closures as callbacks. Keep a list of the objects + // so that they don't get garbage collected. (The ExecutionEngine itself + // just holds raw pointers with no lifetime semantics). 
+ std::vector referencedObjects; }; } // anonymous namespace @@ -96,13 +109,17 @@ PYBIND11_MODULE(_mlirExecutionEngine, m) { .def( "raw_register_runtime", [](PyExecutionEngine &executionEngine, const std::string &name, - uintptr_t sym) { + py::object callbackObj) { + executionEngine.addReferencedObject(callbackObj); + uintptr_t rawSym = + py::cast(py::getattr(callbackObj, "value")); mlirExecutionEngineRegisterSymbol( executionEngine.get(), mlirStringRefCreate(name.c_str(), name.size()), - reinterpret_cast(sym)); + reinterpret_cast(rawSym)); }, - "Lookup function `func` in the ExecutionEngine.") + py::arg("name"), py::arg("callback"), + "Register `callback` as the runtime symbol `name`.") .def( "dump_to_object_file", [](PyExecutionEngine &executionEngine, const std::string &fileName) { diff --git a/mlir/python/mlir/execution_engine.py b/mlir/python/mlir/execution_engine.py index 1c516ae5a..262545b9c 100644 --- a/mlir/python/mlir/execution_engine.py +++ b/mlir/python/mlir/execution_engine.py @@ -39,5 +39,5 @@ def register_runtime(self, name, ctypes_callback): under the provided `name`. The `ctypes_callback` must be a `CFuncType` that outlives the execution engine. """ - callback = ctypes.cast(ctypes_callback, ctypes.c_void_p).value + callback = ctypes.cast(ctypes_callback, ctypes.c_void_p) self.raw_register_runtime("_mlir_ciface_" + name, callback) From 9a114dcc6e2a0b079a116d0c754ff6462cac53ce Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Wed, 29 Sep 2021 21:42:10 +0200 Subject: [PATCH 118/915] [mlir][python] provide access to function argument/result attributes Without this change, these attributes can only be accessed through the generic operation attribute dictionary provided the caller knows the special operation attribute names used for this purpose. Add some Python wrapping to support this use case. 
Also provide access to function arguments usable inside the function along with a couple of quality-of-life improvements in using block arguments (function arguments being the arguments of its entry block). Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D110758 --- mlir/lib/Bindings/Python/IRCore.cpp | 39 ++++++++++--------- mlir/lib/Bindings/Python/PybindUtils.h | 18 ++++++++- mlir/python/mlir/dialects/_builtin_ops_ext.py | 22 +++++++++++ 3 files changed, 60 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 473c94c90..0434ac37c 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1594,32 +1594,35 @@ class PyOpResult : public PyConcreteValue { /// elements, random access is cheap. The argument list is associated with the /// operation that contains the block (detached blocks are not allowed in /// Python bindings) and extends its lifetime. -class PyBlockArgumentList { +class PyBlockArgumentList + : public Sliceable { public: - PyBlockArgumentList(PyOperationRef operation, MlirBlock block) - : operation(std::move(operation)), block(block) {} + static constexpr const char *pyClassName = "BlockArgumentList"; - /// Returns the length of the block argument list. - intptr_t dunderLen() { + PyBlockArgumentList(PyOperationRef operation, MlirBlock block, + intptr_t startIndex = 0, intptr_t length = -1, + intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirBlockGetNumArguments(block) : length, + step), + operation(std::move(operation)), block(block) {} + + /// Returns the number of arguments in the list. + intptr_t getNumElements() { operation->checkValid(); return mlirBlockGetNumArguments(block); } - /// Returns `index`-th element of the block argument list. 
- PyBlockArgument dunderGetItem(intptr_t index) { - if (index < 0 || index >= dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds region"); - } - PyValue value(operation, mlirBlockGetArgument(block, index)); - return PyBlockArgument(value); + /// Returns `pos`-the element in the list. Asserts on out-of-bounds. + PyBlockArgument getElement(intptr_t pos) { + MlirValue argument = mlirBlockGetArgument(block, pos); + return PyBlockArgument(operation, argument); } - /// Defines a Python class in the bindings. - static void bind(py::module &m) { - py::class_(m, "BlockArgumentList", py::module_local()) - .def("__len__", &PyBlockArgumentList::dunderLen) - .def("__getitem__", &PyBlockArgumentList::dunderGetItem); + /// Returns a sublist of this list. + PyBlockArgumentList slice(intptr_t startIndex, intptr_t length, + intptr_t step) { + return PyBlockArgumentList(operation, block, startIndex, length, step); } private: diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h index 9fecc7ce3..2fdcf695b 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -260,13 +260,29 @@ class Sliceable { sliceLength, step * extraStep); } + /// Returns a new vector (mapped to Python list) containing elements from two + /// slices. The new vector is necessary because slices may not be contiguous + /// or even come from the same original sequence. + std::vector dunderAdd(Derived &other) { + std::vector elements; + elements.reserve(length + other.length); + for (intptr_t i = 0; i < length; ++i) { + elements.push_back(dunderGetItem(i)); + } + for (intptr_t i = 0; i < other.length; ++i) { + elements.push_back(other.dunderGetItem(i)); + } + return elements; + } + /// Binds the indexing and length methods in the Python class. 
static void bind(pybind11::module &m) { auto clazz = pybind11::class_(m, Derived::pyClassName, pybind11::module_local()) .def("__len__", &Sliceable::dunderLen) .def("__getitem__", &Sliceable::dunderGetItem) - .def("__getitem__", &Sliceable::dunderGetItemSlice); + .def("__getitem__", &Sliceable::dunderGetItemSlice) + .def("__add__", &Sliceable::dunderAdd); Derived::bindDerived(clazz); } diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py index 99783d833..d464819f2 100644 --- a/mlir/python/mlir/dialects/_builtin_ops_ext.py +++ b/mlir/python/mlir/dialects/_builtin_ops_ext.py @@ -11,6 +11,8 @@ except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e +ARGUMENT_ATTRIBUTE_NAME = "arg_attrs" +RESULT_ATTRIBUTE_NAME = "res_attrs" class ModuleOp: """Specialization for the module op class.""" @@ -100,6 +102,26 @@ def add_entry_block(self): self.body.blocks.append(*self.type.inputs) return self.body.blocks[0] + @property + def arg_attrs(self): + return self.attributes[ARGUMENT_ATTRIBUTE_NAME] + + @arg_attrs.setter + def arg_attrs(self, attribute: ArrayAttr): + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute + + @property + def arguments(self): + return self.entry_block.arguments + + @property + def result_attrs(self): + return self.attributes[RESULT_ATTRIBUTE_NAME] + + @result_attrs.setter + def result_attrs(self, attribute: ArrayAttr): + self.attributes[RESULT_ATTRIBUTE_NAME] = attribute + @classmethod def from_py_func(FuncOp, *inputs: Type, From 119c5174972145da43316a4d765162c8fb7618c5 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Wed, 29 Sep 2021 21:42:53 +0200 Subject: [PATCH 119/915] [mlir][python] provide bindings for the SCF dialect This is an important core dialect that has not been exposed previously. Set up the default bindings generation and provide a nicer wrapper for the `for` loop with access to the loop configuration and body. 
Depends On D110758 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D110759 --- mlir/python/CMakeLists.txt | 9 ++++ mlir/python/mlir/dialects/SCFOps.td | 15 ++++++ mlir/python/mlir/dialects/_scf_ops_ext.py | 57 +++++++++++++++++++++++ mlir/python/mlir/dialects/scf.py | 5 ++ 4 files changed, 86 insertions(+) create mode 100644 mlir/python/mlir/dialects/SCFOps.td create mode 100644 mlir/python/mlir/dialects/_scf_ops_ext.py create mode 100644 mlir/python/mlir/dialects/scf.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 506d8ead2..2ab3a9af1 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -109,6 +109,15 @@ declare_mlir_dialect_python_bindings( SOURCES dialects/python_test.py DIALECT_NAME python_test) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/SCFOps.td + SOURCES + dialects/scf.py + dialects/_scf_ops_ext.py + DIALECT_NAME scf) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/SCFOps.td b/mlir/python/mlir/dialects/SCFOps.td new file mode 100644 index 000000000..855482d4a --- /dev/null +++ b/mlir/python/mlir/dialects/SCFOps.td @@ -0,0 +1,15 @@ +//===-- SCFOps.td - Entry point for SCF dialect bindings ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_SCF_OPS +#define PYTHON_BINDINGS_SCF_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/SCF/SCFOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py new file mode 100644 index 000000000..c6532a756 --- /dev/null +++ b/mlir/python/mlir/dialects/_scf_ops_ext.py @@ -0,0 +1,57 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Any, Sequence + + +class ForOp: + """Specialization for the SCF for op class.""" + + def __init__(self, + lower_bound, + upper_bound, + step, + iter_args: Sequence[Any] = [], + *, + loc=None, + ip=None): + """Creates an SCF `for` operation. + + - `lower_bound` is the value to use as lower bound of the loop. + - `upper_bound` is the value to use as upper bound of the loop. + - `step` is the value to use as loop step. + - `iter_args` is a list of additional loop-carried arguments. + """ + results = [arg.type for arg in iter_args] + super().__init__( + self.build_generic( + regions=1, + results=results, + operands=[lower_bound, upper_bound, step] + list(iter_args), + loc=loc, + ip=ip)) + self.regions[0].blocks.append(IndexType.get(), *results) + + @property + def body(self): + """Returns the body (block) of the loop.""" + return self.regions[0].blocks[0] + + @property + def induction_variable(self): + """Returns the induction variable of the loop.""" + return self.body.arguments[0] + + @property + def inner_iter_args(self): + """Returns the loop-carried arguments usable within the loop. 
+ + To obtain the loop-carried operands, use `iter_args`. + """ + return self.body.arguments[1:] diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py new file mode 100644 index 000000000..302a49d56 --- /dev/null +++ b/mlir/python/mlir/dialects/scf.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._scf_ops_gen import * From e66cf7d05ac2fcc7e96e21bde795673eb7abf06d Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Thu, 30 Sep 2021 13:50:31 +0200 Subject: [PATCH 120/915] [mlir] Remove unused namespace alias. --- mlir/lib/Bindings/Python/Conversions/Conversions.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/Conversions/Conversions.cpp b/mlir/lib/Bindings/Python/Conversions/Conversions.cpp index f8b3b2041..c9d380178 100644 --- a/mlir/lib/Bindings/Python/Conversions/Conversions.cpp +++ b/mlir/lib/Bindings/Python/Conversions/Conversions.cpp @@ -10,8 +10,6 @@ #include -namespace py = pybind11; - // ----------------------------------------------------------------------------- // Module initialization. // ----------------------------------------------------------------------------- From b3e0f0889f0951a7510388a53d20d6ca878c272f Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 30 Sep 2021 15:09:30 +0200 Subject: [PATCH 121/915] [mlir][python] provide bindings for ops from the sparse_tensor dialect Previously, the dialect was exposed for linking and pass management purposes, but we did not generate op classes for it. Generate them. 
Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D110819 --- mlir/python/CMakeLists.txt | 1 + mlir/python/mlir/dialects/SparseTensorOps.td | 15 +++++++++++++++ mlir/python/mlir/dialects/sparse_tensor.py | 1 + 3 files changed, 17 insertions(+) create mode 100644 mlir/python/mlir/dialects/SparseTensorOps.td diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 2ab3a9af1..eb7e1e40d 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -128,6 +128,7 @@ declare_mlir_dialect_python_bindings( declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/SparseTensorOps.td SOURCES dialects/sparse_tensor.py DIALECT_NAME sparse_tensor) diff --git a/mlir/python/mlir/dialects/SparseTensorOps.td b/mlir/python/mlir/dialects/SparseTensorOps.td new file mode 100644 index 000000000..b3b4846db --- /dev/null +++ b/mlir/python/mlir/dialects/SparseTensorOps.td @@ -0,0 +1,15 @@ +//===-- SparseTensorOps.td - Entry point for bindings ------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_SPARSE_TENSOR_OPS +#define PYTHON_BINDINGS_SPARSE_TENSOR_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/sparse_tensor.py b/mlir/python/mlir/dialects/sparse_tensor.py index 4a89ef8ae..4f6b675ec 100644 --- a/mlir/python/mlir/dialects/sparse_tensor.py +++ b/mlir/python/mlir/dialects/sparse_tensor.py @@ -2,5 +2,6 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from ._sparse_tensor_ops_gen import * from .._mlir_libs._mlir.dialects.sparse_tensor import * from .._mlir_libs import _mlirSparseTensorPasses as _cextSparseTensorPasses From 1e8d7e19d9aaf71dfd24fdd0ed1b5e516b4307c4 Mon Sep 17 00:00:00 2001 From: Daniel Resnick Date: Thu, 30 Sep 2021 18:14:00 -0600 Subject: [PATCH 122/915] [mlir][capi] Add TypeID to MLIR C-API Exposes mlir::TypeID to the C API as MlirTypeID along with various accessors and helper functions. Differential Revision: https://reviews.llvm.org/D110897 --- mlir/include/mlir-c/IR.h | 27 +++++++++++++++++++++++++++ mlir/include/mlir/CAPI/IR.h | 1 + mlir/lib/CAPI/IR/IR.cpp | 28 ++++++++++++++++++++++++++++ 3 files changed, 56 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 28a83cba0..92697a248 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -60,6 +60,7 @@ DEFINE_C_API_STRUCT(MlirIdentifier, const void); DEFINE_C_API_STRUCT(MlirLocation, const void); DEFINE_C_API_STRUCT(MlirModule, const void); DEFINE_C_API_STRUCT(MlirType, const void); +DEFINE_C_API_STRUCT(MlirTypeID, const void); DEFINE_C_API_STRUCT(MlirValue, const void); #undef DEFINE_C_API_STRUCT @@ -356,6 +357,11 @@ MLIR_CAPI_EXPORTED bool mlirOperationEqual(MlirOperation op, /// Gets the context this operation is associated with MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op); +/// Gets the type id of the operation. +/// Returns null if the operation does not have a registered operation +/// description. +MLIR_CAPI_EXPORTED MlirTypeID mlirOperationGetTypeID(MlirOperation op); + /// Gets the name of the operation as an identifier. MLIR_CAPI_EXPORTED MlirIdentifier mlirOperationGetName(MlirOperation op); @@ -626,6 +632,9 @@ MLIR_CAPI_EXPORTED MlirType mlirTypeParseGet(MlirContext context, /// Gets the context that a type was created with. 
MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type); +/// Gets the type ID of the type. +MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type); + /// Checks whether a type is null. static inline bool mlirTypeIsNull(MlirType type) { return !type.ptr; } @@ -655,6 +664,9 @@ MLIR_CAPI_EXPORTED MlirContext mlirAttributeGetContext(MlirAttribute attribute); /// Gets the type of this attribute. MLIR_CAPI_EXPORTED MlirType mlirAttributeGetType(MlirAttribute attribute); +/// Gets the type id of the attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute); + /// Checks whether an attribute is null. static inline bool mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; } @@ -693,6 +705,21 @@ MLIR_CAPI_EXPORTED bool mlirIdentifierEqual(MlirIdentifier ident, /// Gets the string value of the identifier. MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident); +//===----------------------------------------------------------------------===// +// TypeID API. +//===----------------------------------------------------------------------===// + +/// Checks whether a type id is null. +MLIR_CAPI_EXPORTED static inline bool mlirTypeIDIsNull(MlirTypeID typeID) { + return !typeID.ptr; +} + +/// Checks if two type ids are equal. +MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2); + +/// Returns the hash value of the type id. 
+MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID); + #ifdef __cplusplus } #endif diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h index ea7b265dd..d5e961367 100644 --- a/mlir/include/mlir/CAPI/IR.h +++ b/mlir/include/mlir/CAPI/IR.h @@ -33,6 +33,7 @@ DEFINE_C_API_METHODS(MlirIdentifier, mlir::Identifier) DEFINE_C_API_METHODS(MlirLocation, mlir::Location) DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp) DEFINE_C_API_METHODS(MlirType, mlir::Type) +DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID) DEFINE_C_API_METHODS(MlirValue, mlir::Value) #endif // MLIR_CAPI_IR_H diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index eda176300..ee5a55511 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -23,6 +23,7 @@ #include "mlir/Parser.h" #include "llvm/Support/Debug.h" +#include using namespace mlir; @@ -345,6 +346,13 @@ MlirContext mlirOperationGetContext(MlirOperation op) { return wrap(unwrap(op)->getContext()); } +MlirTypeID mlirOperationGetTypeID(MlirOperation op) { + if (const auto *abstractOp = unwrap(op)->getAbstractOperation()) { + return wrap(abstractOp->typeID); + } + return {nullptr}; +} + MlirIdentifier mlirOperationGetName(MlirOperation op) { return wrap(unwrap(op)->getName().getIdentifier()); } @@ -658,6 +666,10 @@ MlirContext mlirTypeGetContext(MlirType type) { return wrap(unwrap(type).getContext()); } +MlirTypeID mlirTypeGetTypeID(MlirType type) { + return wrap(unwrap(type).getTypeID()); +} + bool mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); } @@ -685,6 +697,10 @@ MlirType mlirAttributeGetType(MlirAttribute attribute) { return wrap(unwrap(attribute).getType()); } +MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) { + return wrap(unwrap(attr).getTypeID()); +} + bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { return unwrap(a1) == unwrap(a2); } @@ -721,3 +737,15 @@ bool mlirIdentifierEqual(MlirIdentifier ident, MlirIdentifier other) { MlirStringRef 
mlirIdentifierStr(MlirIdentifier ident) { return wrap(unwrap(ident).strref()); } + +//===----------------------------------------------------------------------===// +// TypeID API. +//===----------------------------------------------------------------------===// + +bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) { + return unwrap(typeID1) == unwrap(typeID2); +} + +size_t mlirTypeIDHashValue(MlirTypeID typeID) { + return hash_value(unwrap(typeID)); +} From bf3b6e95f9b59fbbb5edb2ccee48e3a2c640c6ed Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 4 Oct 2021 11:38:20 +0200 Subject: [PATCH 123/915] [mlir][python] Usability improvements for Python bindings Provide a couple of quality-of-life usability improvements for Python bindings, in particular: * give access to the list of types for the list of op results or block arguments, similarly to ValueRange->TypeRange, * allow for constructing empty dictionary arrays, * support construction of array attributes by concatenating an existing attribute with a Python list of attributes. All these are required for the upcoming customization of builtin and standard ops. 
Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D110946 --- mlir/lib/Bindings/Python/IRAttributes.cpp | 56 +++++++++++++++-------- mlir/lib/Bindings/Python/IRCore.cpp | 25 ++++++++++ 2 files changed, 62 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index bb4b5f4f0..2ff75ceed 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -18,7 +18,6 @@ using namespace mlir; using namespace mlir::python; using llvm::SmallVector; -using llvm::StringRef; using llvm::Twine; namespace { @@ -44,6 +43,24 @@ class PyAffineMapAttribute : public PyConcreteAttribute { } }; +template +static T pyTryCast(py::handle object) { + try { + return object.cast(); + } catch (py::cast_error &err) { + std::string msg = + std::string( + "Invalid attribute when attempting to create an ArrayAttribute (") + + err.what() + ")"; + throw py::cast_error(msg); + } catch (py::reference_cast_error &err) { + std::string msg = std::string("Invalid attribute (None?) 
when attempting " + "to create an ArrayAttribute (") + + err.what() + ")"; + throw py::cast_error(msg); + } +} + class PyArrayAttribute : public PyConcreteAttribute { public: static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; @@ -76,6 +93,10 @@ class PyArrayAttribute : public PyConcreteAttribute { int nextIndex = 0; }; + PyAttribute getItem(intptr_t i) { + return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i)); + } + static void bindDerived(ClassTy &c) { c.def_static( "get", @@ -83,21 +104,7 @@ class PyArrayAttribute : public PyConcreteAttribute { SmallVector mlirAttributes; mlirAttributes.reserve(py::len(attributes)); for (auto attribute : attributes) { - try { - mlirAttributes.push_back(attribute.cast()); - } catch (py::cast_error &err) { - std::string msg = std::string("Invalid attribute when attempting " - "to create an ArrayAttribute (") + - err.what() + ")"; - throw py::cast_error(msg); - } catch (py::reference_cast_error &err) { - // This exception seems thrown when the value is "None". - std::string msg = - std::string("Invalid attribute (None?) 
when attempting to " - "create an ArrayAttribute (") + - err.what() + ")"; - throw py::cast_error(msg); - } + mlirAttributes.push_back(pyTryCast(attribute)); } MlirAttribute attr = mlirArrayAttrGet( context->get(), mlirAttributes.size(), mlirAttributes.data()); @@ -109,8 +116,7 @@ class PyArrayAttribute : public PyConcreteAttribute { [](PyArrayAttribute &arr, intptr_t i) { if (i >= mlirArrayAttrGetNumElements(arr)) throw py::index_error("ArrayAttribute index out of range"); - return PyAttribute(arr.getContext(), - mlirArrayAttrGetElement(arr, i)); + return arr.getItem(i); }) .def("__len__", [](const PyArrayAttribute &arr) { @@ -119,6 +125,18 @@ class PyArrayAttribute : public PyConcreteAttribute { .def("__iter__", [](const PyArrayAttribute &arr) { return PyArrayAttributeIterator(arr); }); + c.def("__add__", [](PyArrayAttribute arr, py::list extras) { + std::vector attributes; + intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); + attributes.reserve(numOldElements + py::len(extras)); + for (intptr_t i = 0; i < numOldElements; ++i) + attributes.push_back(arr.getItem(i)); + for (py::handle attr : extras) + attributes.push_back(pyTryCast(attr)); + MlirAttribute arrayAttr = mlirArrayAttrGet( + arr.getContext()->get(), attributes.size(), attributes.data()); + return PyArrayAttribute(arr.getContext(), arrayAttr); + }); } }; @@ -602,7 +620,7 @@ class PyDictAttribute : public PyConcreteAttribute { mlirNamedAttributes.data()); return PyDictAttribute(context->getRef(), attr); }, - py::arg("value"), py::arg("context") = py::none(), + py::arg("value") = py::dict(), py::arg("context") = py::none(), "Gets an uniqued dict attribute"); c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { MlirAttribute attr = diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 0434ac37c..8ed3bd5ed 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1590,6 +1590,19 @@ class PyOpResult : 
public PyConcreteValue { } }; +/// Returns the list of types of the values held by container. +template +static std::vector getValueTypes(Container &container, + PyMlirContextRef &context) { + std::vector result; + result.reserve(container.getNumElements()); + for (int i = 0, e = container.getNumElements(); i < e; ++i) { + result.push_back( + PyType(context, mlirValueGetType(container.getElement(i).get()))); + } + return result; +} + /// A list of block arguments. Internally, these are stored as consecutive /// elements, random access is cheap. The argument list is associated with the /// operation that contains the block (detached blocks are not allowed in @@ -1625,6 +1638,12 @@ class PyBlockArgumentList return PyBlockArgumentList(operation, block, startIndex, length, step); } + static void bindDerived(ClassTy &c) { + c.def_property_readonly("types", [](PyBlockArgumentList &self) { + return getValueTypes(self, self.operation->getContext()); + }); + } + private: PyOperationRef operation; MlirBlock block; @@ -1712,6 +1731,12 @@ class PyOpResultList : public Sliceable { return PyOpResultList(operation, startIndex, length, step); } + static void bindDerived(ClassTy &c) { + c.def_property_readonly("types", [](PyOpResultList &self) { + return getValueTypes(self, self.operation->getContext()); + }); + } + private: PyOperationRef operation; }; From 21383b2e06fd372dd0ff06412427910910c5ee8f Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 4 Oct 2021 11:38:53 +0200 Subject: [PATCH 124/915] [mlir][python] Provide more convenient wrappers for std.ConstantOp Constructing a ConstantOp using the default-generated API is verbose and requires to specify the constant type twice: for the result type of the operation and for the type of the attribute. It also requires to explicitly construct the attribute. Provide custom constructors that take the type once and accept a raw value instead of the attribute. This requires dynamic dispatch based on type in the constructor. 
Also provide the corresponding accessors to raw values. In addition, provide a "refinement" class ConstantIndexOp similar to what exists in C++. Unlike other "op view" Python classes, operations cannot be automatically downcasted to this class since it does not correspond to a specific operation name. It only exists to simplify construction of the operation. Depends On D110946 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D110947 --- mlir/python/CMakeLists.txt | 4 +- mlir/python/mlir/dialects/_std_ops_ext.py | 71 +++++++++++++++++++++++ 2 files changed, 74 insertions(+), 1 deletion(-) create mode 100644 mlir/python/mlir/dialects/_std_ops_ext.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index eb7e1e40d..4f0d1548e 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -136,7 +136,9 @@ declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/StandardOps.td - SOURCES dialects/std.py + SOURCES + dialects/std.py + dialects/_std_ops_ext.py DIALECT_NAME std) declare_mlir_dialect_python_bindings( diff --git a/mlir/python/mlir/dialects/_std_ops_ext.py b/mlir/python/mlir/dialects/_std_ops_ext.py new file mode 100644 index 000000000..bb67fe44d --- /dev/null +++ b/mlir/python/mlir/dialects/_std_ops_ext.py @@ -0,0 +1,71 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * + from .builtin import FuncOp + from ._ods_common import get_default_loc_context as _get_default_loc_context + + from typing import Any, List, Optional, Union +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + + +def _isa(obj: Any, cls: type): + try: + cls(obj) + except ValueError: + return False + return True + + +def _is_any_of(obj: Any, classes: List[type]): + return any(_isa(obj, cls) for cls in classes) + + +def _is_integer_like_type(type: Type): + return _is_any_of(type, [IntegerType, IndexType]) + + +def _is_float_type(type: Type): + return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type]) + + +class ConstantOp: + """Specialization for the constant op class.""" + + def __init__(self, + result: Type, + value: Union[int, float, Attribute], + *, + loc=None, + ip=None): + if isinstance(value, int): + super().__init__(result, IntegerAttr.get(result, value), loc=loc, ip=ip) + elif isinstance(value, float): + super().__init__(result, FloatAttr.get(result, value), loc=loc, ip=ip) + else: + super().__init__(result, value, loc=loc, ip=ip) + + @classmethod + def create_index(cls, value: int, *, loc=None, ip=None): + """Create an index-typed constant.""" + return cls( + IndexType.get(context=_get_default_loc_context(loc)), + value, + loc=loc, + ip=ip) + + @property + def type(self): + return self.results[0].type + + @property + def literal_value(self) -> Union[int, float]: + if _is_integer_like_type(self.type): + return IntegerAttr(self.value).value + elif _is_float_type(self.type): + return FloatAttr(self.value).value + else: + raise ValueError("only integer and float constants have literal values") From 123dd34203ae618deaa6e27034c79255088c8b59 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 4 Oct 2021 11:39:19 +0200 Subject: [PATCH 125/915] [mlir][python] Provide more convenient constructors for std.CallOp The new 
constructor relies on type-based dynamic dispatch and allows one to construct call operations given an object representing a FuncOp or its name as a string, as opposed to requiring an explicitly constructed attribute. Depends On D110947 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D110948 --- mlir/python/mlir/dialects/_builtin_ops_ext.py | 16 +++-- mlir/python/mlir/dialects/_std_ops_ext.py | 70 +++++++++++++++++++ 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py index d464819f2..462850d63 100644 --- a/mlir/python/mlir/dialects/_builtin_ops_ext.py +++ b/mlir/python/mlir/dialects/_builtin_ops_ext.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from typing import Optional, Sequence + from typing import Optional, Sequence, Union import inspect @@ -82,8 +82,8 @@ def visibility(self): return self.attributes["sym_visibility"] @property - def name(self): - return self.attributes["sym_name"] + def name(self) -> StringAttr: + return StringAttr(self.attributes["sym_name"]) @property def entry_block(self): @@ -104,11 +104,15 @@ def add_entry_block(self): @property def arg_attrs(self): - return self.attributes[ARGUMENT_ATTRIBUTE_NAME] + return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) @arg_attrs.setter - def arg_attrs(self, attribute: ArrayAttr): - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute + def arg_attrs(self, attribute: Union[ArrayAttr, list]): + if isinstance(attribute, ArrayAttr): + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute + else: + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( + attribute, context=self.context) @property def arguments(self): diff --git a/mlir/python/mlir/dialects/_std_ops_ext.py b/mlir/python/mlir/dialects/_std_ops_ext.py index bb67fe44d..39da1c83f 100644 --- a/mlir/python/mlir/dialects/_std_ops_ext.py +++ 
b/mlir/python/mlir/dialects/_std_ops_ext.py @@ -69,3 +69,73 @@ def literal_value(self) -> Union[int, float]: return FloatAttr(self.value).value else: raise ValueError("only integer and float constants have literal values") + + +class CallOp: + """Specialization for the call op class.""" + + def __init__(self, + calleeOrResults: Union[FuncOp, List[Type]], + argumentsOrCallee: Union[List, FlatSymbolRefAttr, str], + arguments: Optional[List] = None, + *, + loc=None, + ip=None): + """Creates an call operation. + + The constructor accepts three different forms: + + 1. A function op to be called followed by a list of arguments. + 2. A list of result types, followed by the name of the function to be + called as string, following by a list of arguments. + 3. A list of result types, followed by the name of the function to be + called as symbol reference attribute, followed by a list of arguments. + + For example + + f = builtin.FuncOp("foo", ...) + std.CallOp(f, [args]) + std.CallOp([result_types], "foo", [args]) + + In all cases, the location and insertion point may be specified as keyword + arguments if not provided by the surrounding context managers. + """ + + # TODO: consider supporting constructor "overloads", e.g., through a custom + # or pybind-provided metaclass. 
+ if isinstance(calleeOrResults, FuncOp): + if not isinstance(argumentsOrCallee, list): + raise ValueError( + "when constructing a call to a function, expected " + + "the second argument to be a list of call arguments, " + + f"got {type(argumentsOrCallee)}") + if arguments is not None: + raise ValueError("unexpected third argument when constructing a call" + + "to a function") + + super().__init__( + calleeOrResults.type.results, + FlatSymbolRefAttr.get( + calleeOrResults.name.value, + context=_get_default_loc_context(loc)), + argumentsOrCallee, + loc=loc, + ip=ip) + return + + if isinstance(argumentsOrCallee, list): + raise ValueError("when constructing a call to a function by name, " + + "expected the second argument to be a string or a " + + f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}") + + if isinstance(argumentsOrCallee, FlatSymbolRefAttr): + super().__init__( + calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip) + elif isinstance(argumentsOrCallee, str): + super().__init__( + calleeOrResults, + FlatSymbolRefAttr.get( + argumentsOrCallee, context=_get_default_loc_context(loc)), + arguments, + loc=loc, + ip=ip) From 38f40691c229dc2b38b68d3e5ce3e7f1a69e0b57 Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Wed, 6 Oct 2021 06:45:42 +0000 Subject: [PATCH 126/915] [mlir][linalg] Update OpDSL to use the newly introduced min and max ops. Implement min and max using the newly introduced std operations instead of relying on compare and select. 
Reviewed By: dcaballe Differential Revision: https://reviews.llvm.org/D111170 --- .../dialects/linalg/opdsl/lang/emitter.py | 22 ++++--------------- 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index b151a9ba9..4a883e790 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -319,20 +319,16 @@ def _eval_mul(self, lhs: Value, rhs: Value) -> Value: def _eval_max(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - ogt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2) - return _emit_cmpf_and_select(lhs, rhs, ogt_attr) + return std.MaxFOp(lhs.type, lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - sgt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4) - return _emit_cmpi_and_select(lhs, rhs, sgt_attr) + return std.MaxSIOp(lhs.type, lhs, rhs).result raise NotImplementedError("Unsupported 'max' operand: {lhs}") def _eval_min(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - olt_attr = IntegerAttr.get(IntegerType.get_signless(64), 4) - return _emit_cmpf_and_select(lhs, rhs, olt_attr) + return std.MinFOp(lhs.type, lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - slt_attr = IntegerAttr.get(IntegerType.get_signless(64), 2) - return _emit_cmpi_and_select(lhs, rhs, slt_attr) + return std.MinSIOp(lhs.type, lhs, rhs).result raise NotImplementedError("Unsupported 'min' operand: {lhs}") @@ -413,13 +409,3 @@ def _get_floating_point_width(t: Type) -> int: if BF16Type.isinstance(t): return 16 raise NotImplementedError(f"Unhandled floating point type switch {t}") - - -def _emit_cmpf_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value: - cond = std.CmpFOp(IntegerType.get_signless(1), pred, lhs, rhs).result - return std.SelectOp(lhs.type, cond, lhs, 
rhs).result - - -def _emit_cmpi_and_select(lhs: Value, rhs: Value, pred: IntegerAttr) -> Value: - cond = std.CmpIOp(IntegerType.get_signless(1), pred, lhs, rhs).result - return std.SelectOp(lhs.type, cond, lhs, rhs).result From 070ac1c4c47edc6dc4b57dd2edbcd46ffe80b9f0 Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Thu, 7 Oct 2021 06:26:38 +0000 Subject: [PATCH 127/915] [mlir][linalg] Add unsigned min/max/cast function to OpDSL. Update OpDSL to support unsigned integers by adding unsigned min/max/cast signatures. Add tests in OpDSL and on the C++ side to verify the proper signed and unsigned operations are emitted. The patch addresses an issue brought up in https://reviews.llvm.org/D111170. Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D111230 --- .../linalg/opdsl/lang/comprehension.py | 19 ++++++- .../dialects/linalg/opdsl/lang/emitter.py | 35 ++++++++++--- .../dialects/linalg/opdsl/lang/scalar_expr.py | 9 ++-- .../linalg/opdsl/ops/core_named_ops.py | 49 +++++++++++++++++++ 4 files changed, 101 insertions(+), 11 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index c38940029..732cacfff 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -340,6 +340,8 @@ class PrimFn: max = PrimFnType("max") min = PrimFnType("min") sub = PrimFnType("sub") + max_unsigned = PrimFnType("max_unsigned") + min_unsigned = PrimFnType("min_unsigned") class ReduceFnType: @@ -365,6 +367,8 @@ class ReduceFn: mul = PrimFn.mul.reduce max = PrimFn.max.reduce min = PrimFn.min.reduce + max_unsigned = PrimFn.max_unsigned.reduce + min_unsigned = PrimFn.min_unsigned.reduce class PrimApply(TensorExpression): @@ -438,8 +442,8 @@ def __init__(self, to_type: TypeVar, operand: TensorExpression): self.operand = operand def to_scalar_expression(self) -> ScalarExpression: - return 
ScalarSymbolicCast(self.to_type, - self.operand.to_scalar_expression()).expr() + return ScalarSymbolicCast(self.to_type, self.operand.to_scalar_expression(), + False).expr() def visit_tensor_exprs(self, callback): super().visit_tensor_exprs(callback) @@ -449,6 +453,17 @@ def __repr__(self): return f"cast({self.to_type}, {repr(self.operand)})" +class cast_unsigned(cast): + """Casts the element type to an unsigned type (typically symbolic TypeVar).""" + + def to_scalar_expression(self) -> ScalarExpression: + return ScalarSymbolicCast(self.to_type, self.operand.to_scalar_expression(), + True).expr() + + def __repr__(self): + return f"cast_unsigned({self.to_type}, {repr(self.operand)})" + + class ReduceApply(TensorExpression): """Application of a reduction. diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 4a883e790..7feea040a 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -230,10 +230,12 @@ def expression(self, expr: ScalarExpression) -> Value: return fn(*operand_values) elif expr.symbolic_cast: operand_value = self.expression(expr.symbolic_cast.operand) - return self.cast(expr.symbolic_cast.to_type.name, operand_value) + return self.cast(expr.symbolic_cast.to_type.name, operand_value, + expr.symbolic_cast.is_unsigned_cast) raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") - def cast(self, type_var_name: str, operand: Value) -> Value: + def cast(self, type_var_name: str, operand: Value, + is_unsigned_cast: bool) -> Value: try: to_type = self.type_mapping[type_var_name] except KeyError: @@ -242,29 +244,37 @@ def cast(self, type_var_name: str, operand: Value) -> Value: if operand.type == to_type: return operand if _is_integer_type(to_type): - return self._cast_to_integer(to_type, operand) + return self._cast_to_integer(to_type, operand, is_unsigned_cast) elif 
_is_floating_point_type(to_type): - return self._cast_to_floating_point(to_type, operand) + return self._cast_to_floating_point(to_type, operand, is_unsigned_cast) - def _cast_to_integer(self, to_type: Type, operand: Value) -> Value: + def _cast_to_integer(self, to_type: Type, operand: Value, + is_unsigned_cast: bool) -> Value: to_width = IntegerType(to_type).width operand_type = operand.type if _is_floating_point_type(operand_type): + if is_unsigned_cast: + return std.FPToUIOp(to_type, operand).result return std.FPToSIOp(to_type, operand).result if _is_index_type(operand_type): return std.IndexCastOp(to_type, operand).result # Assume integer. from_width = IntegerType(operand_type).width if to_width > from_width: + if is_unsigned_cast: + return std.ZeroExtendIOp(to_type, operand).result return std.SignExtendIOp(to_type, operand).result elif to_width < from_width: return std.TruncateIOp(to_type, operand).result raise ValueError(f"Unable to cast body expression from {operand_type} to " f"{to_type}") - def _cast_to_floating_point(self, to_type: Type, operand: Value) -> Value: + def _cast_to_floating_point(self, to_type: Type, operand: Value, + is_unsigned_cast: bool) -> Value: operand_type = operand.type if _is_integer_type(operand_type): + if is_unsigned_cast: + return std.UIToFPOp(to_type, operand).result return std.SIToFPOp(to_type, operand).result # Assume FloatType. 
to_width = _get_floating_point_width(to_type) @@ -324,6 +334,13 @@ def _eval_max(self, lhs: Value, rhs: Value) -> Value: return std.MaxSIOp(lhs.type, lhs, rhs).result raise NotImplementedError("Unsupported 'max' operand: {lhs}") + def _eval_max_unsigned(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return std.MaxFOp(lhs.type, lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return std.MaxUIOp(lhs.type, lhs, rhs).result + raise NotImplementedError("Unsupported 'max_unsigned' operand: {lhs}") + def _eval_min(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return std.MinFOp(lhs.type, lhs, rhs).result @@ -331,6 +348,12 @@ def _eval_min(self, lhs: Value, rhs: Value) -> Value: return std.MinSIOp(lhs.type, lhs, rhs).result raise NotImplementedError("Unsupported 'min' operand: {lhs}") + def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return std.MinFOp(lhs.type, lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return std.MinUIOp(lhs.type, lhs, rhs).result + raise NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}") def _infer_structured_outs(op_config: LinalgStructuredOpConfig, in_arg_defs: Sequence[OperandDefConfig], diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py index 48627bfab..6de3333fb 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py @@ -85,15 +85,17 @@ def __repr__(self): class ScalarSymbolicCast: """A type of ScalarExpression that symbolically casts an operand to a TypeVar.""" - def __init__(self, to_type: TypeVar, operand: "ScalarExpression"): + def __init__(self, to_type: TypeVar, operand: "ScalarExpression", + is_unsigned_cast: bool): self.to_type = to_type self.operand = operand + 
self.is_unsigned_cast = is_unsigned_cast def expr(self) -> "ScalarExpression": return ScalarExpression(symbolic_cast=self) def __repr__(self): - return f"ScalarSymbolicCast({self.to_type}, {self.operand})" + return f"ScalarSymbolicCast({self.to_type}, {self.operand}, {self.is_unsigned_cast})" class ScalarExpression(YAMLObject): @@ -144,7 +146,8 @@ def to_yaml_custom_dict(self): return dict( symbolic_cast=dict( type_var=self.symbolic_cast.to_type.name, - operands=[self.symbolic_cast.operand])) + operands=[self.symbolic_cast.operand], + is_unsigned_cast=self.symbolic_cast.is_unsigned_cast)) else: raise ValueError(f"Unexpected ScalarExpression type: {self}") diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index b78a21797..9f5b27ea0 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -20,6 +20,20 @@ def matmul( implements(ContractionOpInterface) C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) +@linalg_structured_op +def matmul_unsigned( + A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True)): + """Performs an unsigned matrix multiplication of two 2D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + domain(D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += cast_unsigned(U, A[D.m, D.k]) * cast_unsigned(U, B[D.k, D.n]) + @linalg_structured_op def quantized_matmul( A=TensorDef(T1, S.M, S.K), @@ -411,6 +425,24 @@ def pooling_nhwc_max( cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) +@linalg_structured_op +def pooling_nhwc_max_unsigned( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + """Performs unsigned max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)( + cast_unsigned( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) + @linalg_structured_op def pooling_nchw_max( I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), @@ -447,6 +479,23 @@ def pooling_nhwc_min( cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) +@linalg_structured_op +def pooling_nhwc_min_unsigned( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=AttributeDef(S.SH, S.SW), + dilations=AttributeDef(S.DH, S.DW)): + """Performs unsigned min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)( + cast_unsigned( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @linalg_structured_op def pooling_ndhwc_sum( From 4f649496cda695870564373e473930f7d6cca793 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 6 Oct 2021 18:41:22 -0700 Subject: [PATCH 128/915] [mlir] Extend C and Python API to support bulk loading of DenseElementsAttr. * This already half existed in terms of reading the raw buffer backing a DenseElementsAttr. * Documented the precise expectations of the buffer layout. * Extended the Python API to support construction from bitcasted buffers, allowing construction of all primitive element types (even those that lack a compatible representation in Python). * Specifically, the Python API can now load all integer types at all bit widths and all floating point types (f16, f32, f64, bf16). Differential Revision: https://reviews.llvm.org/D111284 --- mlir/include/mlir-c/BuiltinAttributes.h | 17 ++ mlir/lib/Bindings/Python/IRAttributes.cpp | 243 +++++++++++++++------- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 15 ++ 3 files changed, 203 insertions(+), 72 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 247de5cc0..5839cd3d2 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -306,6 +306,23 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseFPElements(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGet( MlirType shapedType, intptr_t numElements, MlirAttribute const *elements); +/// Creates a dense elements attribute with the given Shaped type and elements +/// populated from a packed, row-major opaque buffer of contents. 
+/// +/// The format of the raw buffer is a densely packed array of values that +/// can be bitcast to the storage format of the element type specified. +/// Types that are not byte aligned will be: +/// - For bitwidth > 1: Rounded up to the next byte. +/// - For bitwidth = 1: Packed into 8bit bytes with bits corresponding to +/// the linear order of the shape type from MSB to LSB, padded to on the +/// right. +/// +/// A raw buffer of a single element (or for 1-bit, a byte of value 0 or 255) +/// will be interpreted as a splat. User code should be prepared for additional, +/// conformant patterns to be identified as splats in the future. +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrRawBufferGet( + MlirType shapedType, size_t rawBufferSize, const void *rawBuffer); + /// Creates a dense elements attribute with the given Shaped type containing a /// single replicated element (splat). MLIR_CAPI_EXPORTED MlirAttribute diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 2ff75ceed..47f73ecae 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -17,9 +17,57 @@ namespace py = pybind11; using namespace mlir; using namespace mlir::python; +using llvm::None; +using llvm::Optional; using llvm::SmallVector; using llvm::Twine; +//------------------------------------------------------------------------------ +// Docstrings (trivial, non-duplicated docstrings are included inline). +//------------------------------------------------------------------------------ + +static const char kDenseElementsAttrGetDocstring[] = + R"(Gets a DenseElementsAttr from a Python buffer or array. + +When `type` is not provided, then some limited type inferencing is done based +on the buffer format. Support presently exists for 8/16/32/64 signed and +unsigned integers and float16/float32/float64. DenseElementsAttrs of these +types can also be converted back to a corresponding buffer. 
+ +For conversions outside of these types, a `type=` must be explicitly provided +and the buffer contents must be bit-castable to the MLIR internal +representation: + + * Integer types (except for i1): the buffer must be byte aligned to the + next byte boundary. + * Floating point types: Must be bit-castable to the given floating point + size. + * i1 (bool): Bit packed into 8bit words where the bit pattern matches a + row major ordering. An arbitrary Numpy `bool_` array can be bit packed to + this specification with: `np.packbits(ary, axis=None, bitorder='little')`. + +If a single element buffer is passed (or for i1, a single byte with value 0 +or 255), then a splat will be created. + +Args: + array: The array or buffer to convert. + signless: If inferring an appropriate MLIR type, use signless types for + integers (defaults True). + type: Skips inference of the MLIR element type and uses this instead. The + storage size must be consistent with the actual contents of the buffer. + shape: Overrides the shape of the buffer when constructing the MLIR + shaped type. This is needed when the physical and logical shape differ (as + for i1). + context: Explicit context, if not from context manager. + +Returns: + DenseElementsAttr on success. + +Raises: + ValueError: If the type of the buffer or array cannot be matched to an MLIR + type or if the buffer does not meet expectations. +)"; + namespace { static MlirStringRef toMlirStringRef(const std::string &s) { @@ -301,7 +349,6 @@ class PyStringAttribute : public PyConcreteAttribute { } }; -// TODO: Support construction of bool elements. // TODO: Support construction of string elements. 
class PyDenseElementsAttribute : public PyConcreteAttribute { @@ -311,7 +358,8 @@ class PyDenseElementsAttribute using PyConcreteAttribute::PyConcreteAttribute; static PyDenseElementsAttribute - getFromBuffer(py::buffer array, bool signless, + getFromBuffer(py::buffer array, bool signless, Optional explicitType, + Optional> explicitShape, DefaultingPyMlirContext contextWrapper) { // Request a contiguous view. In exotic cases, this will cause a copy. int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; @@ -321,69 +369,95 @@ class PyDenseElementsAttribute throw py::error_already_set(); } py::buffer_info arrayInfo(view); + SmallVector shape; + if (explicitShape) { + shape.append(explicitShape->begin(), explicitShape->end()); + } else { + shape.append(arrayInfo.shape.begin(), + arrayInfo.shape.begin() + arrayInfo.ndim); + } + MlirAttribute encodingAttr = mlirAttributeGetNull(); MlirContext context = contextWrapper->get(); - // Switch on the types that can be bulk loaded between the Python and - // MLIR-C APIs. - // See: https://docs.python.org/3/library/struct.html#format-characters - if (arrayInfo.format == "f") { + + // Detect format codes that are suitable for bulk loading. This includes + // all byte aligned integer and floating point types up to 8 bytes. + // Notably, this excludes, bool (which needs to be bit-packed) and + // other exotics which do not have a direct representation in the buffer + // protocol (i.e. complex, etc). 
+ Optional bulkLoadElementType; + if (explicitType) { + bulkLoadElementType = *explicitType; + } else if (arrayInfo.format == "f") { // f32 assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); - return PyDenseElementsAttribute( - contextWrapper->getRef(), - bulkLoad(context, mlirDenseElementsAttrFloatGet, - mlirF32TypeGet(context), arrayInfo)); + bulkLoadElementType = mlirF32TypeGet(context); } else if (arrayInfo.format == "d") { // f64 assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); - return PyDenseElementsAttribute( - contextWrapper->getRef(), - bulkLoad(context, mlirDenseElementsAttrDoubleGet, - mlirF64TypeGet(context), arrayInfo)); + bulkLoadElementType = mlirF64TypeGet(context); + } else if (arrayInfo.format == "e") { + // f16 + assert(arrayInfo.itemsize == 2 && "mismatched array itemsize"); + bulkLoadElementType = mlirF16TypeGet(context); } else if (isSignedIntegerFormat(arrayInfo.format)) { if (arrayInfo.itemsize == 4) { // i32 - MlirType elementType = signless ? mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeSignedGet(context, 32); - return PyDenseElementsAttribute(contextWrapper->getRef(), - bulkLoad(context, - mlirDenseElementsAttrInt32Get, - elementType, arrayInfo)); + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeSignedGet(context, 32); } else if (arrayInfo.itemsize == 8) { // i64 - MlirType elementType = signless ? mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeSignedGet(context, 64); - return PyDenseElementsAttribute(contextWrapper->getRef(), - bulkLoad(context, - mlirDenseElementsAttrInt64Get, - elementType, arrayInfo)); + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeSignedGet(context, 64); + } else if (arrayInfo.itemsize == 1) { + // i8 + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) + : mlirIntegerTypeSignedGet(context, 8); + } else if (arrayInfo.itemsize == 2) { + // i16 + bulkLoadElementType = signless ? 
mlirIntegerTypeGet(context, 16) + : mlirIntegerTypeSignedGet(context, 16); } } else if (isUnsignedIntegerFormat(arrayInfo.format)) { if (arrayInfo.itemsize == 4) { // unsigned i32 - MlirType elementType = signless - ? mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeUnsignedGet(context, 32); - return PyDenseElementsAttribute(contextWrapper->getRef(), - bulkLoad(context, - mlirDenseElementsAttrUInt32Get, - elementType, arrayInfo)); + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeUnsignedGet(context, 32); } else if (arrayInfo.itemsize == 8) { // unsigned i64 - MlirType elementType = signless - ? mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeUnsignedGet(context, 64); - return PyDenseElementsAttribute(contextWrapper->getRef(), - bulkLoad(context, - mlirDenseElementsAttrUInt64Get, - elementType, arrayInfo)); + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeUnsignedGet(context, 64); + } else if (arrayInfo.itemsize == 1) { + // i8 + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) + : mlirIntegerTypeUnsignedGet(context, 8); + } else if (arrayInfo.itemsize == 2) { + // i16 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 16) + : mlirIntegerTypeUnsignedGet(context, 16); } } + if (bulkLoadElementType) { + auto shapedType = mlirRankedTensorTypeGet( + shape.size(), shape.data(), *bulkLoadElementType, encodingAttr); + size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize; + MlirAttribute attr = mlirDenseElementsAttrRawBufferGet( + shapedType, rawBufferSize, arrayInfo.ptr); + if (mlirAttributeIsNull(attr)) { + throw std::invalid_argument( + "DenseElementsAttr could not be constructed from the given buffer. " + "This may mean that the Python buffer layout does not match that " + "MLIR expected layout and is a bug."); + } + return PyDenseElementsAttribute(contextWrapper->getRef(), attr); + } - // TODO: Fall back to string-based get. 
- std::string message = "unimplemented array format conversion from format: "; - message.append(arrayInfo.format); - throw SetPyError(PyExc_ValueError, message); + throw std::invalid_argument( + std::string("unimplemented array format conversion from format: ") + + arrayInfo.format); } static PyDenseElementsAttribute getSplat(PyType shapedType, @@ -422,47 +496,82 @@ class PyDenseElementsAttribute intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } py::buffer_info accessBuffer() { + if (mlirDenseElementsAttrIsSplat(*this)) { + // TODO: Raise an exception. + // Reported as https://github.com/pybind/pybind11/issues/3336 + return py::buffer_info(); + } + MlirType shapedType = mlirAttributeGetType(*this); MlirType elementType = mlirShapedTypeGetElementType(shapedType); + std::string format; if (mlirTypeIsAF32(elementType)) { // f32 - return bufferInfo(shapedType, mlirDenseElementsAttrGetFloatValue); + return bufferInfo(shapedType); } else if (mlirTypeIsAF64(elementType)) { // f64 - return bufferInfo(shapedType, mlirDenseElementsAttrGetDoubleValue); + return bufferInfo(shapedType); + } else if (mlirTypeIsAF16(elementType)) { + // f16 + return bufferInfo(shapedType, "e"); } else if (mlirTypeIsAInteger(elementType) && mlirIntegerTypeGetWidth(elementType) == 32) { if (mlirIntegerTypeIsSignless(elementType) || mlirIntegerTypeIsSigned(elementType)) { // i32 - return bufferInfo(shapedType, mlirDenseElementsAttrGetInt32Value); + return bufferInfo(shapedType); } else if (mlirIntegerTypeIsUnsigned(elementType)) { // unsigned i32 - return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt32Value); + return bufferInfo(shapedType); } } else if (mlirTypeIsAInteger(elementType) && mlirIntegerTypeGetWidth(elementType) == 64) { if (mlirIntegerTypeIsSignless(elementType) || mlirIntegerTypeIsSigned(elementType)) { // i64 - return bufferInfo(shapedType, mlirDenseElementsAttrGetInt64Value); + return bufferInfo(shapedType); } else if 
(mlirIntegerTypeIsUnsigned(elementType)) { // unsigned i64 - return bufferInfo(shapedType, mlirDenseElementsAttrGetUInt64Value); + return bufferInfo(shapedType); + } + } else if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 8) { + if (mlirIntegerTypeIsSignless(elementType) || + mlirIntegerTypeIsSigned(elementType)) { + // i8 + return bufferInfo(shapedType); + } else if (mlirIntegerTypeIsUnsigned(elementType)) { + // unsigned i8 + return bufferInfo(shapedType); + } + } else if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 16) { + if (mlirIntegerTypeIsSignless(elementType) || + mlirIntegerTypeIsSigned(elementType)) { + // i16 + return bufferInfo(shapedType); + } else if (mlirIntegerTypeIsUnsigned(elementType)) { + // unsigned i16 + return bufferInfo(shapedType); } } - std::string message = "unimplemented array format."; - throw SetPyError(PyExc_ValueError, message); + // TODO: Currently crashes the program. Just returning an empty buffer + // for now. 
+ // Reported as https://github.com/pybind/pybind11/issues/3336 + // throw std::invalid_argument( + // "unsupported data type for conversion to Python buffer"); + return py::buffer_info(); } static void bindDerived(ClassTy &c) { c.def("__len__", &PyDenseElementsAttribute::dunderLen) .def_static("get", PyDenseElementsAttribute::getFromBuffer, py::arg("array"), py::arg("signless") = true, + py::arg("type") = py::none(), py::arg("shape") = py::none(), py::arg("context") = py::none(), - "Gets from a buffer or ndarray") + kDenseElementsAttrGetDocstring) .def_static("get_splat", PyDenseElementsAttribute::getSplat, py::arg("shaped_type"), py::arg("element_attr"), "Gets a DenseElementsAttr where all values are the same") @@ -474,21 +583,6 @@ class PyDenseElementsAttribute } private: - template - static MlirAttribute - bulkLoad(MlirContext context, - MlirAttribute (*ctor)(MlirType, intptr_t, ElementTy *), - MlirType mlirElementType, py::buffer_info &arrayInfo) { - SmallVector shape(arrayInfo.shape.begin(), - arrayInfo.shape.begin() + arrayInfo.ndim); - MlirAttribute encodingAttr = mlirAttributeGetNull(); - auto shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), - mlirElementType, encodingAttr); - intptr_t numElements = arrayInfo.size; - const ElementTy *contents = static_cast(arrayInfo.ptr); - return ctor(shapedType, numElements, contents); - } - static bool isUnsignedIntegerFormat(const std::string &format) { if (format.empty()) return false; @@ -507,7 +601,7 @@ class PyDenseElementsAttribute template py::buffer_info bufferInfo(MlirType shapedType, - Type (*value)(MlirAttribute, intptr_t)) { + const char *explicitFormat = nullptr) { intptr_t rank = mlirShapedTypeGetRank(shapedType); // Prepare the data for the buffer_info. // Buffer is configured for read-only access below. 
@@ -528,9 +622,14 @@ class PyDenseElementsAttribute strides.push_back(sizeof(Type) * strideFactor); } strides.push_back(sizeof(Type)); - return py::buffer_info(data, sizeof(Type), - py::format_descriptor::format(), rank, shape, - strides, /*readonly=*/true); + std::string format; + if (explicitFormat) { + format = explicitFormat; + } else { + format = py::format_descriptor::format(); + } + return py::buffer_info(data, sizeof(Type), format, rank, shape, strides, + /*readonly=*/true); } }; // namespace diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index a2ee06722..3b15212e3 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -331,6 +331,21 @@ MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType, unwrapList(numElements, elements, attributes))); } +MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType, + size_t rawBufferSize, + const void *rawBuffer) { + auto shapedTypeCpp = unwrap(shapedType).cast(); + ArrayRef rawBufferCpp(static_cast(rawBuffer), + rawBufferSize); + bool isSplat = false; + if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp, + isSplat)) { + return mlirAttributeGetNull(); + } + return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp, + isSplat)); +} + MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType, MlirAttribute element) { return wrap(DenseElementsAttr::get(unwrap(shapedType).cast(), From 8c4ce5906d7563b7d6474d313dc8a209e868d859 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Thu, 7 Oct 2021 11:47:05 -0700 Subject: [PATCH 129/915] [mlir][python] Temporarily disable test for converting unsupported DenseElementsAttr types to a buffer. * Need to investigate the proper solution to https://github.com/pybind/pybind11/issues/3336 or engineer something different. * The attempt to produce an empty buffer_info as a workaround triggers asan/ubsan. 
* Usage of this API does not arise naturally in practice yet, and it is more important to be asan/crash clean than have a solution right now. * Switching back to raising an exception, even though that triggers terminate(). --- mlir/lib/Bindings/Python/IRAttributes.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 47f73ecae..066350a0a 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -497,9 +497,10 @@ class PyDenseElementsAttribute py::buffer_info accessBuffer() { if (mlirDenseElementsAttrIsSplat(*this)) { - // TODO: Raise an exception. + // TODO: Currently crashes the program. // Reported as https://github.com/pybind/pybind11/issues/3336 - return py::buffer_info(); + throw std::invalid_argument( + "unsupported data type for conversion to Python buffer"); } MlirType shapedType = mlirAttributeGetType(*this); @@ -557,12 +558,10 @@ class PyDenseElementsAttribute } } - // TODO: Currently crashes the program. Just returning an empty buffer - // for now. + // TODO: Currently crashes the program. // Reported as https://github.com/pybind/pybind11/issues/3336 - // throw std::invalid_argument( - // "unsupported data type for conversion to Python buffer"); - return py::buffer_info(); + throw std::invalid_argument( + "unsupported data type for conversion to Python buffer"); } static void bindDerived(ClassTy &c) { From 9f894fa051fbfb198f3f0c2e00899ab45904140b Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 7 Oct 2021 18:29:03 +0200 Subject: [PATCH 130/915] [mlir][python] support taking ops instead of values in op constructors Introduce support for accepting ops instead of values when constructing ops. A single-result op can be used instead of a value, including in lists of values, and any op can be used instead of a list of values. 
This is similar to, but more powerful than, the C++ API that allows for
implicitly casting an OpType to Value if it is statically known to have a
single result - the cast in Python is based on the op dynamically having a
single result, and also handles the multi-result case.

This allows building IR in a more concise way:

  op = dialect.produce_multiple_results()
  other = dialect.produce_single_result()
  dialect.consume_multiple_results(other, op)

instead of having to access the results manually

  op = dialect.produce_multiple_results()
  other = dialect.produce_single_result()
  dialect.consume_multiple_results(other.result, op.operation.results)

The dispatch is implemented directly in Python and is triggered automatically
for autogenerated OpView subclasses. Extension OpView classes should use the
functions provided in ods_common.py if they want to implement this behavior.

An alternative could be to implement the dispatch in the C++ bindings code,
but it would require forwarding opaque types through all Python functions down
to a binding call, which makes it hard to inspect them in Python, e.g., to
obtain the types of values.
Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D111306 --- mlir/python/mlir/dialects/_linalg_ops_ext.py | 12 +++--- mlir/python/mlir/dialects/_ods_common.py | 38 +++++++++++++++++++ mlir/python/mlir/dialects/_scf_ops_ext.py | 19 +++++++--- .../mlir/dialects/linalg/opdsl/lang/dsl.py | 25 +++++++++--- .../dialects/linalg/opdsl/lang/emitter.py | 32 +++++++++------- 5 files changed, 97 insertions(+), 29 deletions(-) diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py index 536096749..b7641c0a4 100644 --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -10,6 +10,7 @@ except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e +from ._ods_common import get_op_result_or_value as _get_op_result_or_value def isa(cls: Type, ty: Type): try: @@ -26,11 +27,12 @@ def __init__(self, output: Value, value: Value, *, loc=None, ip=None): results = [] if isa(RankedTensorType, output.type): results = [output.type] - op = self.build_generic(results=results, - operands=[value, output], - attributes=None, - loc=loc, - ip=ip) + op = self.build_generic( + results=results, + operands=[_get_op_result_or_value(o) for o in [value, output]], + attributes=None, + loc=loc, + ip=ip) OpView.__init__(self, op) linalgDialect = Context.current.get_dialect_descriptor("linalg") fill_builtin_region(linalgDialect, self.operation) diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 2fbf3545f..95c441865 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -5,11 +5,14 @@ # Provide a convenient name for sub-packages to resolve the main C-extension # with a relative import. 
from .._mlir_libs import _mlir as _cext +from typing import Sequence as _Sequence, Union as _Union __all__ = [ "equally_sized_accessor", "extend_opview_class", "get_default_loc_context", + "get_op_result_or_value", + "get_op_results_or_values", "segmented_accessor", ] @@ -118,3 +121,38 @@ def get_default_loc_context(location=None): # Location.current raises ValueError if there is no current location. return _cext.ir.Location.current.context return location.context + + +def get_op_result_or_value( + arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value] +) -> _cext.ir.Value: + """Returns the given value or the single result of the given op. + + This is useful to implement op constructors so that they can take other ops as + arguments instead of requiring the caller to extract results for every op. + Raises ValueError if provided with an op that doesn't have a single result. + """ + if isinstance(arg, _cext.ir.OpView): + return arg.operation.result + elif isinstance(arg, _cext.ir.Operation): + return arg.result + else: + assert isinstance(arg, _cext.ir.Value) + return arg + + +def get_op_results_or_values( + arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _Sequence[_cext.ir.Value]] +) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]: + """Returns the given sequence of values or the results of the given op. + + This is useful to implement op constructors so that they can take other ops as + lists of arguments instead of requiring the caller to extract results for + every op. 
+ """ + if isinstance(arg, _cext.ir.OpView): + return arg.operation.results + elif isinstance(arg, _cext.ir.Operation): + return arg.results + else: + return arg diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py index c6532a756..a8924a750 100644 --- a/mlir/python/mlir/dialects/_scf_ops_ext.py +++ b/mlir/python/mlir/dialects/_scf_ops_ext.py @@ -7,8 +7,8 @@ except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e -from typing import Any, Sequence - +from typing import Any, Optional, Sequence, Union +from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values class ForOp: """Specialization for the SCF for op class.""" @@ -17,7 +17,8 @@ def __init__(self, lower_bound, upper_bound, step, - iter_args: Sequence[Any] = [], + iter_args: Optional[Union[Operation, OpView, + Sequence[Value]]] = None, *, loc=None, ip=None): @@ -26,14 +27,22 @@ def __init__(self, - `lower_bound` is the value to use as lower bound of the loop. - `upper_bound` is the value to use as upper bound of the loop. - `step` is the value to use as loop step. - - `iter_args` is a list of additional loop-carried arguments. + - `iter_args` is a list of additional loop-carried arguments or an operation + producing them as results. 
""" + if iter_args is None: + iter_args = [] + iter_args = _get_op_results_or_values(iter_args) + results = [arg.type for arg in iter_args] super().__init__( self.build_generic( regions=1, results=results, - operands=[lower_bound, upper_bound, step] + list(iter_args), + operands=[ + _get_op_result_or_value(o) + for o in [lower_bound, upper_bound, step] + ] + list(iter_args), loc=loc, ip=ip)) self.regions[0].blocks.append(IndexType.get(), *results) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py index 047bde245..1acae7a7a 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -2,7 +2,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Dict, List +from typing import Dict, List, Sequence, Union from contextlib import contextmanager import functools @@ -10,12 +10,15 @@ import threading from ..... import ir +from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values from .comprehension import * from .config import * from .emitter import * _CONTEXT = threading.local() +StructuredOpOuts = Union[ir.Operation, ir.OpView, ir.OpResultList, + Sequence[Union[ir.Value, ir.Operation, ir.OpView]]] @contextmanager def bind_op_def(model: LinalgOpDef): @@ -37,6 +40,15 @@ def current_op_def() -> LinalgOpDef: "but none is set. 
Did you mean to call this in an op definition?") +def _prepare_structured_op_outs(outs: StructuredOpOuts) -> ValueList: + if isinstance(outs, (ir.Operation, ir.OpView)): + return _get_op_results_or_values(outs) + elif isinstance(outs, ir.OpResultList): + return outs + + return [_get_op_result_or_value(o) for o in outs] + + class DefinedOpCallable: """Callable that wraps any defined op function.""" @@ -44,7 +56,8 @@ def __init__(self, op_name: str, model: LinalgOpDef): self.op_name = op_name self.model = model - def __call__(self, *ins: ir.Value, outs: Sequence[ir.Value], **kwargs): + def __call__(self, *ins: Union[ir.Operation, ir.OpView, ir.Value], + outs: StructuredOpOuts, **kwargs): """Emits the corresponding op definition as IR. Most arguments are passed through to the underlying emitter. The following @@ -73,17 +86,19 @@ def __call__(self, *ins: ir.Value, outs: Sequence[ir.Value], **kwargs): emit_generic or not ctx.is_registered_operation(fully_qualified_name)) op_config = op_configs[0] + out_values = _prepare_structured_op_outs(outs) + in_values = [_get_op_result_or_value(i) for i in ins] if op_config.structured_op: if emit_generic: return emit_generic_structured_op( - op_config.structured_op, *ins, outs=outs, **kwargs) + op_config.structured_op, *in_values, outs=out_values, **kwargs) else: return emit_named_structured_op( op_config.structured_op, self.op_name, self.model.metadata.cpp_class_name, - *ins, - outs=outs, + *in_values, + outs=out_values, **kwargs) raise NotImplementedError( diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 7feea040a..021fe8328 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -2,7 +2,7 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Dict, Sequence +from typing import Dict, List, Sequence, Tuple, Union from .....ir import * from ....._mlir_libs._mlir.dialects.linalg import fill_builtin_region @@ -10,6 +10,7 @@ from .... import linalg from .... import std from .... import math +from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values from .scalar_expr import * from .config import * @@ -18,8 +19,10 @@ __all__ = [ "emit_generic_structured_op", "emit_named_structured_op", + "ValueList", ] +ValueList = Union[Sequence[Value], OpResultList] def isa(cls: Type, ty: Type): try: @@ -30,17 +33,18 @@ def isa(cls: Type, ty: Type): def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, - *ins: Value, outs: Sequence[Value], + *ins: Value, outs: ValueList, **attrs: Sequence[int]): all_arg_defs = op_config.ordered_operands in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "InputOperand"] out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "OutputOperand"] attr_arg_defs = [arg for arg in all_arg_defs if arg.usage == "IndexAttribute"] - # Verify outs is a sequence. - if not isinstance(outs, Sequence): - raise ValueError(f"Expected named argument outs to have type Sequence " - f"but got {type(outs)}") + # Verify outs is a sequence or a list of results. + if not isinstance(outs, (Sequence, OpResultList)): + raise ValueError( + f"Expected named argument outs to have type Sequence or OpResultLis but got {type(outs)}" + ) # Arity validation. 
if len(ins) != len(in_arg_defs): @@ -122,7 +126,7 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, - outs: Sequence[Value], **attrs: Sequence[int]): + outs: ValueList, **attrs: Sequence[int]): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \ prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) @@ -153,8 +157,8 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str, - op_class_name: str, *ins: Value, - outs: Sequence[Value], **attrs: Sequence[int]): + op_class_name: str, *ins: Value, outs: ValueList, + **attrs: Sequence[int]): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \ prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) @@ -355,11 +359,11 @@ def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value: return std.MinUIOp(lhs.type, lhs, rhs).result raise NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}") -def _infer_structured_outs(op_config: LinalgStructuredOpConfig, - in_arg_defs: Sequence[OperandDefConfig], - ins: Sequence[Value], - out_arg_defs: Sequence[OperandDefConfig], - outs: Sequence[Value]): +def _infer_structured_outs( + op_config: LinalgStructuredOpConfig, + in_arg_defs: Sequence[OperandDefConfig], ins: Sequence[Value], + out_arg_defs: Sequence[OperandDefConfig], + outs: Union[Sequence[Value], OpResultList]) -> Tuple[ValueList, List[Type]]: """Infers implicit outs and output types. Respects existing contents of outs if not empty. 
From 5fe30d9b9c2a1f130512ef05e03225122b99ff22 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 12 Oct 2021 12:45:57 -0700 Subject: [PATCH 131/915] [mlir][python] Add nameloc getter Expose the nameloc getter to Python API. Differential Revision: https://reviews.llvm.org/D111663 --- mlir/lib/Bindings/Python/IRCore.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 8ed3bd5ed..3226b7fe1 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -43,6 +43,9 @@ See also: https://mlir.llvm.org/docs/LangRef/#type-system static const char kContextGetFileLocationDocstring[] = R"(Gets a Location representing a file, line and column)"; +static const char kContextGetNameLocationDocString[] = + R"(Gets a Location representing a named location with optional child location)"; + static const char kModuleParseDocstring[] = R"(Parses a module's assembly format from a string. @@ -1970,6 +1973,19 @@ void mlir::python::populateIRCore(py::module &m) { }, py::arg("filename"), py::arg("line"), py::arg("col"), py::arg("context") = py::none(), kContextGetFileLocationDocstring) + .def_static( + "name", + [](std::string name, llvm::Optional childLoc, + DefaultingPyMlirContext context) { + return PyLocation( + context->getRef(), + mlirLocationNameGet( + context->get(), toMlirStringRef(name), + childLoc ? 
childLoc->get() + : mlirLocationUnknownGet(context->get()))); + }, + py::arg("name"), py::arg("childLoc") = py::none(), + py::arg("context") = py::none(), kContextGetNameLocationDocString) .def_property_readonly( "context", [](PyLocation &self) { return self.getContext().getObject(); }, From 89e66179476cb50ee1f8a47cf62b1dee637a68ed Mon Sep 17 00:00:00 2001 From: Mogball Date: Tue, 12 Oct 2021 23:14:57 +0000 Subject: [PATCH 132/915] [MLIR] Replace std ops with arith dialect ops Precursor: https://reviews.llvm.org/D110200 Removed redundant ops from the standard dialect that were moved to the `arith` or `math` dialects. Renamed all instances of operations in the codebase and in tests. Reviewed By: rriddle, jpienaar Differential Revision: https://reviews.llvm.org/D110797 --- mlir/python/CMakeLists.txt | 9 +++ mlir/python/mlir/dialects/ArithmeticOps.td | 15 ++++ mlir/python/mlir/dialects/_arith_ops_ext.py | 70 +++++++++++++++++++ mlir/python/mlir/dialects/_std_ops_ext.py | 52 +------------- mlir/python/mlir/dialects/arith.py | 5 ++ .../dialects/linalg/opdsl/lang/emitter.py | 40 ++++++----- 6 files changed, 122 insertions(+), 69 deletions(-) create mode 100644 mlir/python/mlir/dialects/ArithmeticOps.td create mode 100644 mlir/python/mlir/dialects/_arith_ops_ext.py create mode 100644 mlir/python/mlir/dialects/arith.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 4f0d1548e..ed4eb6c1a 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -95,6 +95,15 @@ declare_mlir_dialect_python_bindings( SOURCES dialects/math.py DIALECT_NAME math) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/ArithmeticOps.td + SOURCES + dialects/arith.py + dialects/_arith_ops_ext.py + DIALECT_NAME arith) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git 
a/mlir/python/mlir/dialects/ArithmeticOps.td b/mlir/python/mlir/dialects/ArithmeticOps.td new file mode 100644 index 000000000..d14b24a09 --- /dev/null +++ b/mlir/python/mlir/dialects/ArithmeticOps.td @@ -0,0 +1,15 @@ +//===-- ArithmeticOps.td - Entry point for ArithmeticOps bindings ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_ARITHMETIC_OPS +#define PYTHON_BINDINGS_ARITHMETIC_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/_arith_ops_ext.py b/mlir/python/mlir/dialects/_arith_ops_ext.py new file mode 100644 index 000000000..e35f5f2a4 --- /dev/null +++ b/mlir/python/mlir/dialects/_arith_ops_ext.py @@ -0,0 +1,70 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * + from ._ods_common import get_default_loc_context as _get_default_loc_context + + from typing import Any, List, Union +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + + +def _isa(obj: Any, cls: type): + try: + cls(obj) + except ValueError: + return False + return True + + +def _is_any_of(obj: Any, classes: List[type]): + return any(_isa(obj, cls) for cls in classes) + + +def _is_integer_like_type(type: Type): + return _is_any_of(type, [IntegerType, IndexType]) + + +def _is_float_type(type: Type): + return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type]) + + +class ConstantOp: + """Specialization for the constant op class.""" + + def __init__(self, + result: Type, + value: Union[int, float, Attribute], + *, + loc=None, + ip=None): + if isinstance(value, int): + super().__init__(result, IntegerAttr.get(result, value), loc=loc, ip=ip) + elif isinstance(value, float): + super().__init__(result, FloatAttr.get(result, value), loc=loc, ip=ip) + else: + super().__init__(result, value, loc=loc, ip=ip) + + @classmethod + def create_index(cls, value: int, *, loc=None, ip=None): + """Create an index-typed constant.""" + return cls( + IndexType.get(context=_get_default_loc_context(loc)), + value, + loc=loc, + ip=ip) + + @property + def type(self): + return self.results[0].type + + @property + def literal_value(self) -> Union[int, float]: + if _is_integer_like_type(self.type): + return IntegerAttr(self.value).value + elif _is_float_type(self.type): + return FloatAttr(self.value).value + else: + raise ValueError("only integer and float constants have literal values") diff --git a/mlir/python/mlir/dialects/_std_ops_ext.py b/mlir/python/mlir/dialects/_std_ops_ext.py index 39da1c83f..f4cb6186b 100644 --- a/mlir/python/mlir/dialects/_std_ops_ext.py +++ b/mlir/python/mlir/dialects/_std_ops_ext.py @@ -12,64 +12,16 @@ raise 
RuntimeError("Error loading imports from extension module") from e -def _isa(obj: Any, cls: type): - try: - cls(obj) - except ValueError: - return False - return True - - -def _is_any_of(obj: Any, classes: List[type]): - return any(_isa(obj, cls) for cls in classes) - - -def _is_integer_like_type(type: Type): - return _is_any_of(type, [IntegerType, IndexType]) - - -def _is_float_type(type: Type): - return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type]) - - class ConstantOp: """Specialization for the constant op class.""" - def __init__(self, - result: Type, - value: Union[int, float, Attribute], - *, - loc=None, - ip=None): - if isinstance(value, int): - super().__init__(result, IntegerAttr.get(result, value), loc=loc, ip=ip) - elif isinstance(value, float): - super().__init__(result, FloatAttr.get(result, value), loc=loc, ip=ip) - else: - super().__init__(result, value, loc=loc, ip=ip) - - @classmethod - def create_index(cls, value: int, *, loc=None, ip=None): - """Create an index-typed constant.""" - return cls( - IndexType.get(context=_get_default_loc_context(loc)), - value, - loc=loc, - ip=ip) + def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None): + super().__init__(result, value, loc=loc, ip=ip) @property def type(self): return self.results[0].type - @property - def literal_value(self) -> Union[int, float]: - if _is_integer_like_type(self.type): - return IntegerAttr(self.value).value - elif _is_float_type(self.type): - return FloatAttr(self.value).value - else: - raise ValueError("only integer and float constants have literal values") - class CallOp: """Specialization for the call op class.""" diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py new file mode 100644 index 000000000..77318b286 --- /dev/null +++ b/mlir/python/mlir/dialects/arith.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._arith_ops_gen import * diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 021fe8328..1215d0358 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -10,6 +10,7 @@ from .... import linalg from .... import std from .... import math +from .... import arith from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values from .scalar_expr import * @@ -216,10 +217,10 @@ def expression(self, expr: ScalarExpression) -> Value: f"this structured op.") elif expr.scalar_const: value_attr = Attribute.parse(expr.scalar_const.value) - return std.ConstantOp(value_attr.type, value_attr).result + return arith.ConstantOp(value_attr.type, value_attr).result elif expr.scalar_index: - dim_attr = IntegerAttr.get(IntegerType.get_signless(64), - expr.scalar_index.dim) + dim_attr = IntegerAttr.get( + IntegerType.get_signless(64), expr.scalar_index.dim) return linalg.IndexOp(IndexType.get(), dim_attr).result elif expr.scalar_apply: try: @@ -258,18 +259,18 @@ def _cast_to_integer(self, to_type: Type, operand: Value, operand_type = operand.type if _is_floating_point_type(operand_type): if is_unsigned_cast: - return std.FPToUIOp(to_type, operand).result - return std.FPToSIOp(to_type, operand).result + return arith.FPToUIOp(to_type, operand).result + return arith.FPToSIOp(to_type, operand).result if _is_index_type(operand_type): - return std.IndexCastOp(to_type, operand).result + return arith.IndexCastOp(to_type, operand).result # Assume integer. 
from_width = IntegerType(operand_type).width if to_width > from_width: if is_unsigned_cast: - return std.ZeroExtendIOp(to_type, operand).result - return std.SignExtendIOp(to_type, operand).result + return arith.ExtUIOp(to_type, operand).result + return arith.ExtSIOp(to_type, operand).result elif to_width < from_width: - return std.TruncateIOp(to_type, operand).result + return arith.TruncIOp(to_type, operand).result raise ValueError(f"Unable to cast body expression from {operand_type} to " f"{to_type}") @@ -278,15 +279,15 @@ def _cast_to_floating_point(self, to_type: Type, operand: Value, operand_type = operand.type if _is_integer_type(operand_type): if is_unsigned_cast: - return std.UIToFPOp(to_type, operand).result - return std.SIToFPOp(to_type, operand).result + return arith.UIToFPOp(to_type, operand).result + return arith.SIToFPOp(to_type, operand).result # Assume FloatType. to_width = _get_floating_point_width(to_type) from_width = _get_floating_point_width(operand_type) if to_width > from_width: - return std.FPExtOp(to_type, operand).result + return arith.ExtFOp(to_type, operand).result elif to_width < from_width: - return std.FPTruncOp(to_type, operand).result + return arith.TruncFOp(to_type, operand).result raise ValueError(f"Unable to cast body expression from {operand_type} to " f"{to_type}") @@ -302,9 +303,9 @@ def yield_outputs(self, *output_names: str): def _eval_add(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return std.AddFOp(lhs.type, lhs, rhs).result + return arith.AddFOp(lhs.type, lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.AddIOp(lhs.type, lhs, rhs).result + return arith.AddIOp(lhs.type, lhs, rhs).result raise NotImplementedError("Unsupported 'add' operand: {lhs}") def _eval_exp(self, x: Value) -> Value: @@ -319,16 +320,16 @@ def _eval_log(self, x: Value) -> Value: def _eval_sub(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return 
std.SubFOp(lhs.type, lhs, rhs).result + return arith.SubFOp(lhs.type, lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.SubIOp(lhs.type, lhs, rhs).result + return arith.SubIOp(lhs.type, lhs, rhs).result raise NotImplementedError("Unsupported 'sub' operand: {lhs}") def _eval_mul(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return std.MulFOp(lhs.type, lhs, rhs).result + return arith.MulFOp(lhs.type, lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.MulIOp(lhs.type, lhs, rhs).result + return arith.MulIOp(lhs.type, lhs, rhs).result raise NotImplementedError("Unsupported 'mul' operand: {lhs}") def _eval_max(self, lhs: Value, rhs: Value) -> Value: @@ -359,6 +360,7 @@ def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value: return std.MinUIOp(lhs.type, lhs, rhs).result raise NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}") + def _infer_structured_outs( op_config: LinalgStructuredOpConfig, in_arg_defs: Sequence[OperandDefConfig], ins: Sequence[Value], From f5227851054e70d4b780c21999e031e8b5ec7381 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 13 Oct 2021 10:02:12 +0200 Subject: [PATCH 133/915] [mlir][python] Expose CallSiteLoc Python side This exposes creating a CallSiteLoc with a callee & list of frames for callers. Follows the creation approach in C++ side where a list of frames may be provided. 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D111670 --- mlir/lib/Bindings/Python/IRCore.cpp | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 3226b7fe1..d53efd9c7 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -17,6 +17,7 @@ #include "mlir-c/Debug.h" #include "mlir-c/IR.h" #include "mlir-c/Registration.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include @@ -40,6 +41,9 @@ Returns a Type object or raises a ValueError if the type cannot be parsed. See also: https://mlir.llvm.org/docs/LangRef/#type-system )"; +static const char kContextGetCallSiteLocationDocstring[] = + R"(Gets a Location representing a caller and callsite)"; + static const char kContextGetFileLocationDocstring[] = R"(Gets a Location representing a file, line and column)"; @@ -1962,6 +1966,21 @@ void mlir::python::populateIRCore(py::module &m) { }, py::arg("context") = py::none(), "Gets a Location representing an unknown location") + .def_static( + "callsite", + [](PyLocation callee, const std::vector &frames, + DefaultingPyMlirContext context) { + if (frames.empty()) + throw py::value_error("No caller frames provided"); + MlirLocation caller = frames.back().get(); + for (PyLocation frame : + llvm::reverse(llvm::makeArrayRef(frames).drop_back())) + caller = mlirLocationCallSiteGet(frame.get(), caller); + return PyLocation(context->getRef(), + mlirLocationCallSiteGet(callee.get(), caller)); + }, + py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(), + kContextGetCallSiteLocationDocstring) .def_static( "file", [](std::string filename, int line, int col, From 0f58e0f6375abd3a3851f61322227cfef77af49c Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 11 Oct 2021 18:24:48 +0200 Subject: [PATCH 134/915] [mlir][python] Provide some methods and properties for API completeness When writing the 
user-facing documentation, I noticed several inconsistencies and asymmetries in the Python API we provide. Fix them by adding: - the `owner` property to regions, similarly to blocks; - the `isinstance` method to any class derived from `PyConcreteAttr`, `PyConcreteValue` and `PyConreteAffineExpr`, similar to `PyConcreteType` to enable `isa`-like calls without having to handle exceptions; - a mechanism to create the first block in the region as we could only create blocks relative to other blocks, with is impossible in an empty region. Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D111556 --- mlir/lib/Bindings/Python/IRAffine.cpp | 3 +++ mlir/lib/Bindings/Python/IRCore.cpp | 26 ++++++++++++++++++++++++++ mlir/lib/Bindings/Python/IRModule.h | 5 +++++ 3 files changed, 34 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index 5314badba..0027b68ee 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -99,6 +99,9 @@ class PyConcreteAffineExpr : public BaseTy { static void bind(py::module &m) { auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); cls.def(py::init()); + cls.def_static("isinstance", [](PyAffineExpr &otherAffineExpr) -> bool { + return DerivedTy::isaFunction(otherAffineExpr); + }); DerivedTy::bindDerived(cls); } diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index d53efd9c7..7b1c99829 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1548,6 +1548,9 @@ class PyConcreteValue : public PyValue { static void bind(py::module &m) { auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); cls.def(py::init(), py::keep_alive<0, 1>()); + cls.def_static("isinstance", [](PyValue &otherValue) -> bool { + return DerivedTy::isaFunction(otherValue); + }); DerivedTy::bindDerived(cls); } @@ -2248,6 +2251,12 @@ void 
mlir::python::populateIRCore(py::module &m) { return PyBlockList(self.getParentOperation(), self.get()); }, "Returns a forward-optimized sequence of blocks.") + .def_property_readonly( + "owner", + [](PyRegion &self) { + return self.getParentOperation()->createOpView(); + }, + "Returns the operation owning this region.") .def( "__iter__", [](PyRegion &self) { @@ -2291,6 +2300,23 @@ void mlir::python::populateIRCore(py::module &m) { return PyOperationList(self.getParentOperation(), self.get()); }, "Returns a forward-optimized sequence of operations.") + .def_static( + "create_at_start", + [](PyRegion &parent, py::list pyArgTypes) { + parent.checkValid(); + llvm::SmallVector argTypes; + argTypes.reserve(pyArgTypes.size()); + for (auto &pyArg : pyArgTypes) { + argTypes.push_back(pyArg.cast()); + } + + MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); + mlirRegionInsertOwnedBlock(parent, 0, block); + return PyBlock(parent.getParentOperation(), block); + }, + py::arg("parent"), py::arg("pyArgTypes") = py::list(), + "Creates and returns a new Block at the beginning of the given " + "region (with given argument types).") .def( "create_before", [](PyBlock &self, py::args pyArgTypes) { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 702870487..ae85ef850 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -533,6 +533,7 @@ class PyRegion { : parentOperation(std::move(parentOperation)), region(region) { assert(!mlirRegionIsNull(region) && "python region cannot be null"); } + operator MlirRegion() const { return region; } MlirRegion get() { return region; } PyOperationRef &getParentOperation() { return parentOperation; } @@ -681,6 +682,9 @@ class PyConcreteAttribute : public BaseTy { auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(), pybind11::module_local()); cls.def(pybind11::init(), pybind11::keep_alive<0, 1>()); + cls.def_static("isinstance", 
[](PyAttribute &otherAttr) -> bool { + return DerivedTy::isaFunction(otherAttr); + }); DerivedTy::bindDerived(cls); } @@ -764,6 +768,7 @@ class PyValue { public: PyValue(PyOperationRef parentOperation, MlirValue value) : parentOperation(parentOperation), value(value) {} + operator MlirValue() const { return value; } MlirValue get() { return value; } PyOperationRef &getParentOperation() { return parentOperation; } From 32e2a55864c1da82b43dd9374c3ff6c2b969f8fb Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Wed, 13 Oct 2021 15:20:31 +0200 Subject: [PATCH 135/915] [mlir][python] Add custom constructor for memref load The type can be inferred trivially, but it is currently done as string stitching between ODS and C++ and is not easily exposed to Python. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D111712 --- mlir/python/mlir/dialects/_memref_ops_ext.py | 37 ++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 mlir/python/mlir/dialects/_memref_ops_ext.py diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py new file mode 100644 index 000000000..cb25ef105 --- /dev/null +++ b/mlir/python/mlir/dialects/_memref_ops_ext.py @@ -0,0 +1,37 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * + from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Sequence, Union + + +class LoadOp: + """Specialization for the MemRef load operation.""" + + def __init__(self, + memref: Union[Operation, OpView, Value], + indices: Optional[Union[Operation, OpView, + Sequence[Value]]] = None, + *, + loc=None, + ip=None): + """Creates a memref load operation. + + Args: + memref: the buffer to load from. + indices: the list of subscripts, may be empty for zero-dimensional + buffers. + loc: user-visible location of the operation. + ip: insertion point. + """ + memref_resolved = _get_op_result_or_value(memref) + indices_resolved = [] if indices is None else _get_op_results_or_values( + indices) + return_type = memref_resolved.type + super().__init__(return_type, memref, indices_resolved, loc=loc, ip=ip) From ccc49e2cba1ab19d6fbb36db5b5bc8dc5220cadc Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Wed, 13 Oct 2021 17:29:19 +0200 Subject: [PATCH 136/915] [mlir] fix python bindings cmake --- mlir/python/CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index ed4eb6c1a..8c60a31b0 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -108,7 +108,9 @@ declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/MemRefOps.td - SOURCES dialects/memref.py + SOURCES + dialects/memref.py + dialects/_memref_ops_ext.py DIALECT_NAME memref) declare_mlir_dialect_python_bindings( From 9e57e84f7f94ec64fd63eaa7894c87223a0c8ecf Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 14 Oct 2021 11:33:28 +0200 Subject: 
[PATCH 137/915] [mlir][python] Fix MemRefType IsAFunction in Python bindings MemRefType was using a wrong `isa` function in the bindings code, which could lead to invalid IR being constructed. Also run the verifier in memref dialect tests. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D111784 --- mlir/lib/Bindings/Python/IRTypes.cpp | 2 +- mlir/python/mlir/dialects/_memref_ops_ext.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 568cca160..fd9f3efe7 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -406,7 +406,7 @@ class PyMemRefLayoutMapList; /// Ranked MemRef Type subclass - MemRefType. class PyMemRefType : public PyConcreteType { public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; static constexpr const char *pyClassName = "MemRefType"; using PyConcreteType::PyConcreteType; diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py index cb25ef105..9cc22a21c 100644 --- a/mlir/python/mlir/dialects/_memref_ops_ext.py +++ b/mlir/python/mlir/dialects/_memref_ops_ext.py @@ -33,5 +33,5 @@ def __init__(self, memref_resolved = _get_op_result_or_value(memref) indices_resolved = [] if indices is None else _get_op_results_or_values( indices) - return_type = memref_resolved.type + return_type = MemRefType(memref_resolved.type).element_type super().__init__(return_type, memref, indices_resolved, loc=loc, ip=ip) From e0de9bc65b51062a377f540be325e7d134e5be7b Mon Sep 17 00:00:00 2001 From: rkayaith Date: Mon, 18 Oct 2021 16:00:39 +0200 Subject: [PATCH 138/915] [mlir][python] Add 'loc' property to ops Add a read-only `loc` property to Operation and OpView Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D111972 --- mlir/include/mlir-c/IR.h | 3 
+++ mlir/lib/Bindings/Python/IRCore.cpp | 9 +++++++++ mlir/lib/CAPI/IR/IR.cpp | 4 ++++ 3 files changed, 16 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 92697a248..2fec0be4d 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -357,6 +357,9 @@ MLIR_CAPI_EXPORTED bool mlirOperationEqual(MlirOperation op, /// Gets the context this operation is associated with MLIR_CAPI_EXPORTED MlirContext mlirOperationGetContext(MlirOperation op); +/// Gets the location of the operation. +MLIR_CAPI_EXPORTED MlirLocation mlirOperationGetLocation(MlirOperation op); + /// Gets the type id of the operation. /// Returns null if the operation does not have a registered operation /// description. diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 7b1c99829..ed96bafe4 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2143,6 +2143,15 @@ void mlir::python::populateIRCore(py::module &m) { }, "Shortcut to get an op result if it has only one (throws an error " "otherwise).") + .def_property_readonly( + "location", + [](PyOperationBase &self) { + PyOperation &operation = self.getOperation(); + return PyLocation(operation.getContext(), + mlirOperationGetLocation(operation.get())); + }, + "Returns the source location the operation was defined or derived " + "from.") .def("__iter__", [](PyOperationBase &self) { return PyRegionIterator(self.getOperation().getRef()); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index ee5a55511..c738198f7 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -346,6 +346,10 @@ MlirContext mlirOperationGetContext(MlirOperation op) { return wrap(unwrap(op)->getContext()); } +MlirLocation mlirOperationGetLocation(MlirOperation op) { + return wrap(unwrap(op)->getLoc()); +} + MlirTypeID mlirOperationGetTypeID(MlirOperation op) { if (const auto *abstractOp = unwrap(op)->getAbstractOperation()) { 
return wrap(abstractOp->typeID); From bde614907a4f71bf2fe137bbfbdcddf27d64759d Mon Sep 17 00:00:00 2001 From: Vladislav Vinogradov Date: Mon, 11 Oct 2021 18:25:14 +0300 Subject: [PATCH 139/915] [mlir][RFC] Refactor layout representation in MemRefType The change is based on the proposal from the following discussion: https://llvm.discourse.group/t/rfc-memreftype-affine-maps-list-vs-single-item/3968 * Introduce `MemRefLayoutAttr` interface to get `AffineMap` from an `Attribute` (`AffineMapAttr` implements this interface). * Store layout as a single generic `MemRefLayoutAttr`. This change removes the affine map composition feature and related API. Actually, while the `MemRefType` itself supported it, almost none of the upstream can work with more than 1 affine map in `MemRefType`. The introduced `MemRefLayoutAttr` allows to re-implement this feature in a more stable way - via separate attribute class. Also the interface allows to use different layout representations rather than affine maps. For example, the described "stride + offset" form, which is currently supported in ASM parser only, can now be expressed as separate attribute. Reviewed By: ftynse, bondhugula Differential Revision: https://reviews.llvm.org/D111553 --- mlir/include/mlir-c/BuiltinTypes.h | 20 +++---- mlir/lib/Bindings/Python/IRTypes.cpp | 78 ++++++++-------------------- mlir/lib/CAPI/IR/BuiltinTypes.cpp | 42 +++++++-------- 3 files changed, 53 insertions(+), 87 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index a677d4d36..2983627a5 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -229,16 +229,17 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type); /// Creates a MemRef type with the given rank and shape, a potentially empty /// list of affine layout maps, the given memory space and element type, in the /// same context as element type. The type is owned by the context. 
-MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet( - MlirType elementType, intptr_t rank, const int64_t *shape, intptr_t numMaps, - MlirAffineMap const *affineMaps, MlirAttribute memorySpace); +MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGet(MlirType elementType, + intptr_t rank, + const int64_t *shape, + MlirAttribute layout, + MlirAttribute memorySpace); /// Same as "mlirMemRefTypeGet" but returns a nullptr-wrapping MlirType o /// illegal arguments, emitting appropriate diagnostics. MLIR_CAPI_EXPORTED MlirType mlirMemRefTypeGetChecked( MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, - intptr_t numMaps, MlirAffineMap const *affineMaps, - MlirAttribute memorySpace); + MlirAttribute layout, MlirAttribute memorySpace); /// Creates a MemRef type with the given rank, shape, memory space and element /// type in the same context as the element type. The type has no affine maps, @@ -264,12 +265,11 @@ mlirUnrankedMemRefTypeGet(MlirType elementType, MlirAttribute memorySpace); MLIR_CAPI_EXPORTED MlirType mlirUnrankedMemRefTypeGetChecked( MlirLocation loc, MlirType elementType, MlirAttribute memorySpace); -/// Returns the number of affine layout maps in the given MemRef type. -MLIR_CAPI_EXPORTED intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type); +/// Returns the layout of the given MemRef type. +MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetLayout(MlirType type); -/// Returns the pos-th affine map of the given MemRef type. -MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, - intptr_t pos); +/// Returns the affine map of the given MemRef type. +MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type); /// Returns the memory space of the given MemRef type. 
MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index fd9f3efe7..1cfd799bf 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -401,8 +401,6 @@ class PyUnrankedTensorType } }; -class PyMemRefLayoutMapList; - /// Ranked MemRef Type subclass - MemRefType. class PyMemRefType : public PyConcreteType { public: @@ -410,26 +408,18 @@ class PyMemRefType : public PyConcreteType { static constexpr const char *pyClassName = "MemRefType"; using PyConcreteType::PyConcreteType; - PyMemRefLayoutMapList getLayout(); - static void bindDerived(ClassTy &c) { c.def_static( "get", [](std::vector shape, PyType &elementType, - std::vector layout, PyAttribute *memorySpace, + PyAttribute *layout, PyAttribute *memorySpace, DefaultingPyLocation loc) { - SmallVector maps; - maps.reserve(layout.size()); - for (PyAffineMap &map : layout) - maps.push_back(map); - - MlirAttribute memSpaceAttr = {}; - if (memorySpace) - memSpaceAttr = *memorySpace; - - MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(), - shape.data(), maps.size(), - maps.data(), memSpaceAttr); + MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull(); + MlirAttribute memSpaceAttr = + memorySpace ? *memorySpace : mlirAttributeGetNull(); + MlirType t = + mlirMemRefTypeGetChecked(loc, elementType, shape.size(), + shape.data(), layoutAttr, memSpaceAttr); // TODO: Rework error reporting once diagnostic engine is exposed // in C API. 
if (mlirTypeIsNull(t)) { @@ -444,10 +434,22 @@ class PyMemRefType : public PyConcreteType { return PyMemRefType(elementType.getContext(), t); }, py::arg("shape"), py::arg("element_type"), - py::arg("layout") = py::list(), py::arg("memory_space") = py::none(), + py::arg("layout") = py::none(), py::arg("memory_space") = py::none(), py::arg("loc") = py::none(), "Create a memref type") - .def_property_readonly("layout", &PyMemRefType::getLayout, - "The list of layout maps of the MemRef type.") + .def_property_readonly( + "layout", + [](PyMemRefType &self) -> PyAttribute { + MlirAttribute layout = mlirMemRefTypeGetLayout(self); + return PyAttribute(self.getContext(), layout); + }, + "The layout of the MemRef type.") + .def_property_readonly( + "affine_map", + [](PyMemRefType &self) -> PyAffineMap { + MlirAffineMap map = mlirMemRefTypeGetAffineMap(self); + return PyAffineMap(self.getContext(), map); + }, + "The layout of the MemRef type as an affine map.") .def_property_readonly( "memory_space", [](PyMemRefType &self) -> PyAttribute { @@ -458,41 +460,6 @@ class PyMemRefType : public PyConcreteType { } }; -/// A list of affine layout maps in a memref type. Internally, these are stored -/// as consecutive elements, random access is cheap. Both the type and the maps -/// are owned by the context, no need to worry about lifetime extension. -class PyMemRefLayoutMapList - : public Sliceable { -public: - static constexpr const char *pyClassName = "MemRefLayoutMapList"; - - PyMemRefLayoutMapList(PyMemRefType type, intptr_t startIndex = 0, - intptr_t length = -1, intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? 
mlirMemRefTypeGetNumAffineMaps(type) : length, - step), - memref(type) {} - - intptr_t getNumElements() { return mlirMemRefTypeGetNumAffineMaps(memref); } - - PyAffineMap getElement(intptr_t index) { - return PyAffineMap(memref.getContext(), - mlirMemRefTypeGetAffineMap(memref, index)); - } - - PyMemRefLayoutMapList slice(intptr_t startIndex, intptr_t length, - intptr_t step) { - return PyMemRefLayoutMapList(memref, startIndex, length, step); - } - -private: - PyMemRefType memref; -}; - -PyMemRefLayoutMapList PyMemRefType::getLayout() { - return PyMemRefLayoutMapList(*this); -} - /// Unranked MemRef Type subclass - UnrankedMemRefType. class PyUnrankedMemRefType : public PyConcreteType { @@ -640,7 +607,6 @@ void mlir::python::populateIRTypes(py::module &m) { PyRankedTensorType::bind(m); PyUnrankedTensorType::bind(m); PyMemRefType::bind(m); - PyMemRefLayoutMapList::bind(m); PyUnrankedMemRefType::bind(m); PyTupleType::bind(m); PyFunctionType::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index d978f17b9..318b8eb10 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -226,34 +226,35 @@ MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc, bool mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa(); } MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, - const int64_t *shape, intptr_t numMaps, - MlirAffineMap const *affineMaps, + const int64_t *shape, MlirAttribute layout, MlirAttribute memorySpace) { - SmallVector maps; - (void)unwrapList(numMaps, affineMaps, maps); - return wrap( - MemRefType::get(llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), maps, unwrap(memorySpace))); + return wrap(MemRefType::get( + llvm::makeArrayRef(shape, static_cast(rank)), unwrap(elementType), + mlirAttributeIsNull(layout) + ? 
MemRefLayoutAttrInterface() + : unwrap(layout).cast(), + unwrap(memorySpace))); } MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, intptr_t rank, const int64_t *shape, - intptr_t numMaps, - MlirAffineMap const *affineMaps, + MlirAttribute layout, MlirAttribute memorySpace) { - SmallVector maps; - (void)unwrapList(numMaps, affineMaps, maps); return wrap(MemRefType::getChecked( unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), maps, unwrap(memorySpace))); + unwrap(elementType), + mlirAttributeIsNull(layout) + ? MemRefLayoutAttrInterface() + : unwrap(layout).cast(), + unwrap(memorySpace))); } MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute memorySpace) { - return wrap( - MemRefType::get(llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), llvm::None, unwrap(memorySpace))); + return wrap(MemRefType::get( + llvm::makeArrayRef(shape, static_cast(rank)), unwrap(elementType), + MemRefLayoutAttrInterface(), unwrap(memorySpace))); } MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc, @@ -262,16 +263,15 @@ MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc, MlirAttribute memorySpace) { return wrap(MemRefType::getChecked( unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType), llvm::None, unwrap(memorySpace))); + unwrap(elementType), MemRefLayoutAttrInterface(), unwrap(memorySpace))); } -intptr_t mlirMemRefTypeGetNumAffineMaps(MlirType type) { - return static_cast( - unwrap(type).cast().getAffineMaps().size()); +MlirAttribute mlirMemRefTypeGetLayout(MlirType type) { + return wrap(unwrap(type).cast().getLayout()); } -MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type, intptr_t pos) { - return wrap(unwrap(type).cast().getAffineMaps()[pos]); +MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type) { + return wrap(unwrap(type).cast().getLayout().getAffineMap()); } MlirAttribute 
mlirMemRefTypeGetMemorySpace(MlirType type) { From 7e4eff0fb44b059aded414506b12e923fbe94997 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Tue, 19 Oct 2021 17:13:54 +0000 Subject: [PATCH 140/915] Fix clang-tidy warnings in MLIR Python bindings (NFC) --- mlir/lib/Bindings/Python/IRAffine.cpp | 1 + mlir/lib/Bindings/Python/IRCore.cpp | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index 0027b68ee..50a96c8c8 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -555,6 +555,7 @@ void mlir::python::populateIRAffine(py::module &m) { mlirAffineMapCompressUnusedSymbols( maps.data(), maps.size(), compressed.data(), populate); std::vector res; + res.reserve(compressed.size()); for (auto m : compressed) res.push_back(PyAffineMap(context->getRef(), m)); return res; diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index ed96bafe4..4fc581b5d 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1976,7 +1976,7 @@ void mlir::python::populateIRCore(py::module &m) { if (frames.empty()) throw py::value_error("No caller frames provided"); MlirLocation caller = frames.back().get(); - for (PyLocation frame : + for (const PyLocation &frame : llvm::reverse(llvm::makeArrayRef(frames).drop_back())) caller = mlirLocationCallSiteGet(frame.get(), caller); return PyLocation(context->getRef(), From 0b39b2f8eca0afc984abdf7437d2a20888d9e913 Mon Sep 17 00:00:00 2001 From: Kazu Hirata Date: Sat, 23 Oct 2021 08:45:29 -0700 Subject: [PATCH 141/915] Ensure newlines at the end of files (NFC) --- mlir/lib/Bindings/Python/Pass.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h index 550ff47c3..3a500d5e8 100644 --- a/mlir/lib/Bindings/Python/Pass.h +++ b/mlir/lib/Bindings/Python/Pass.h @@ -19,4 +19,4 @@ 
void populatePassManagerSubmodule(pybind11::module &m); } // namespace python } // namespace mlir -#endif // MLIR_BINDINGS_PYTHON_PASS_H \ No newline at end of file +#endif // MLIR_BINDINGS_PYTHON_PASS_H From 944e86714f186fc602a02afec5cbb79dac586d43 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 14 Oct 2021 17:18:28 +0200 Subject: [PATCH 142/915] [mlir] support interfaces in Python bindings Introduce the initial support for operation interfaces in C API and Python bindings. Interfaces are a key component of MLIR's extensibility and should be available in bindings to make use of full potential of MLIR. This initial implementation exposes InferTypeOpInterface all the way to the Python bindings since it can be later used to simplify the operation construction methods by inferring their return types instead of requiring the user to do so. The general infrastructure for binding interfaces is defined and InferTypeOpInterface can be used as an example for binding other interfaces. Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D111656 --- mlir/include/mlir-c/Interfaces.h | 67 ++++++ mlir/include/mlir/CAPI/Interfaces.h | 18 ++ mlir/lib/Bindings/Python/IRInterfaces.cpp | 240 ++++++++++++++++++++++ mlir/lib/Bindings/Python/IRModule.h | 1 + mlir/lib/Bindings/Python/MainModule.cpp | 1 + mlir/lib/CAPI/CMakeLists.txt | 1 + mlir/lib/CAPI/Interfaces/CMakeLists.txt | 5 + mlir/lib/CAPI/Interfaces/Interfaces.cpp | 82 ++++++++ mlir/python/CMakeLists.txt | 40 +++- mlir/python/mlir/dialects/PythonTest.td | 33 --- mlir/python/mlir/dialects/python_test.py | 5 + 11 files changed, 454 insertions(+), 39 deletions(-) create mode 100644 mlir/include/mlir-c/Interfaces.h create mode 100644 mlir/include/mlir/CAPI/Interfaces.h create mode 100644 mlir/lib/Bindings/Python/IRInterfaces.cpp create mode 100644 mlir/lib/CAPI/Interfaces/CMakeLists.txt create mode 100644 mlir/lib/CAPI/Interfaces/Interfaces.cpp delete mode 100644 mlir/python/mlir/dialects/PythonTest.td diff 
--git a/mlir/include/mlir-c/Interfaces.h b/mlir/include/mlir-c/Interfaces.h new file mode 100644 index 000000000..f03dd6ea5 --- /dev/null +++ b/mlir/include/mlir-c/Interfaces.h @@ -0,0 +1,67 @@ +//===-- mlir-c/Interfaces.h - C API to Core MLIR IR interfaces ----*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface to MLIR interface classes. It is +// intended to contain interfaces defined in lib/Interfaces. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_H +#define MLIR_C_DIALECT_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// Returns `true` if the given operation implements an interface identified by +/// its TypeID. +MLIR_CAPI_EXPORTED bool +mlirOperationImplementsInterface(MlirOperation operation, + MlirTypeID interfaceTypeID); + +/// Returns `true` if the operation identified by its canonical string name +/// implements the interface identified by its TypeID in the given context. +/// Note that interfaces may be attached to operations in some contexts and not +/// others. +MLIR_CAPI_EXPORTED bool +mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, + MlirContext context, + MlirTypeID interfaceTypeID); + +//===----------------------------------------------------------------------===// +// InferTypeOpInterface. +//===----------------------------------------------------------------------===// + +/// Returns the interface TypeID of the InferTypeOpInterface. 
+MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID(); + +/// These callbacks are used to return multiple types from functions while +/// transferring ownership to the caller. The first argument is the number of +/// consecutive elements pointed to by the second argument. The third argument +/// is an opaque pointer forwarded to the callback by the caller. +typedef void (*MlirTypesCallback)(intptr_t, MlirType *, void *); + +/// Infers the return types of the operation identified by its canonical given +/// the arguments that will be supplied to its generic builder. Calls `callback` +/// with the types of inferred arguments, potentially several times, on success. +/// Returns failure otherwise. +MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( + MlirStringRef opName, MlirContext context, MlirLocation location, + intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, + intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, + void *userData); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_H diff --git a/mlir/include/mlir/CAPI/Interfaces.h b/mlir/include/mlir/CAPI/Interfaces.h new file mode 100644 index 000000000..4154b8c9e --- /dev/null +++ b/mlir/include/mlir/CAPI/Interfaces.h @@ -0,0 +1,18 @@ +//===- Interfaces.h - C API Utils for MLIR interfaces -----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains declarations of implementation details of the C API for +// MLIR interface classes. This file should not be included from C++ code other +// than C API implementation nor from C code. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CAPI_INTERFACES_H +#define MLIR_CAPI_INTERFACES_H + +#endif // MLIR_CAPI_INTERFACES_H diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp new file mode 100644 index 000000000..c3d41c4d8 --- /dev/null +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -0,0 +1,240 @@ +//===- IRInterfaces.cpp - MLIR IR interfaces pybind -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "IRModule.h" +#include "mlir-c/BuiltinAttributes.h" +#include "mlir-c/Interfaces.h" + +namespace py = pybind11; + +namespace mlir { +namespace python { + +constexpr static const char *constructorDoc = + R"(Creates an interface from a given operation/opview object or from a +subclass of OpView. Raises ValueError if the operation does not implement the +interface.)"; + +constexpr static const char *operationDoc = + R"(Returns an Operation for which the interface was constructed.)"; + +constexpr static const char *opviewDoc = + R"(Returns an OpView subclass _instance_ for which the interface was +constructed)"; + +constexpr static const char *inferReturnTypesDoc = + R"(Given the arguments required to build an operation, attempts to infer +its return types. Raises ValueError on failure.)"; + +/// CRTP base class for Python classes representing MLIR Op interfaces. +/// Interface hierarchies are flat so no base class is expected here. The +/// derived class is expected to define the following static fields: +/// - `const char *pyClassName` - the name of the Python class to create; +/// - `GetTypeIDFunctionTy getInterfaceID` - the function producing the TypeID +/// of the interface. 
+/// Derived classes may redefine the `bindDerived(ClassTy &)` method to bind
+/// interface-specific methods.
+///
+/// An interface class may be constructed from either an Operation/OpView object
+/// or from a subclass of OpView. In the latter case, only the static interface
+/// methods are available, similarly to calling ConcreteOp::staticMethod on the
+/// C++ side. Implementations of concrete interfaces can use the `isStatic`
+/// method to check whether the interface object was constructed from a class or
+/// an operation/opview instance. The `getOpName` always succeeds and returns a
+/// canonical name of the operation suitable for lookups.
+template
+class PyConcreteOpInterface {
+protected:
+  using ClassTy = py::class_;
+  using GetTypeIDFunctionTy = MlirTypeID (*)();
+
+public:
+  /// Constructs an interface instance from an object that is either an
+  /// operation or a subclass of OpView. In the latter case, only the static
+  /// methods of the interface are accessible to the caller.
+  PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context)
+      : obj(object) {
+    try {
+      operation = &py::cast(obj);
+    } catch (py::cast_error &err) {
+      // Do nothing.
+    }
+
+    try {
+      operation = &py::cast(obj).getOperation();
+    } catch (py::cast_error &err) {
+      // Do nothing. 
+ } + + if (operation != nullptr) { + if (!mlirOperationImplementsInterface(*operation, + ConcreteIface::getInterfaceID())) { + std::string msg = "the operation does not implement "; + throw py::value_error(msg + ConcreteIface::pyClassName); + } + + MlirIdentifier identifier = mlirOperationGetName(*operation); + MlirStringRef stringRef = mlirIdentifierStr(identifier); + opName = std::string(stringRef.data, stringRef.length); + } else { + try { + opName = obj.attr("OPERATION_NAME").template cast(); + } catch (py::cast_error &err) { + throw py::type_error( + "Op interface does not refer to an operation or OpView class"); + } + + if (!mlirOperationImplementsInterfaceStatic( + mlirStringRefCreate(opName.data(), opName.length()), + context.resolve().get(), ConcreteIface::getInterfaceID())) { + std::string msg = "the operation does not implement "; + throw py::value_error(msg + ConcreteIface::pyClassName); + } + } + } + + /// Creates the Python bindings for this class in the given module. + static void bind(py::module &m) { + py::class_ cls(m, "InferTypeOpInterface", + py::module_local()); + cls.def(py::init(), py::arg("object"), + py::arg("context") = py::none(), constructorDoc) + .def_property_readonly("operation", + &PyConcreteOpInterface::getOperationObject, + operationDoc) + .def_property_readonly("opview", &PyConcreteOpInterface::getOpView, + opviewDoc); + ConcreteIface::bindDerived(cls); + } + + /// Hook for derived classes to add class-specific bindings. + static void bindDerived(ClassTy &cls) {} + + /// Returns `true` if this object was constructed from a subclass of OpView + /// rather than from an operation instance. + bool isStatic() { return operation == nullptr; } + + /// Returns the operation instance from which this object was constructed. + /// Throws a type error if this object was constructed from a subclass of + /// OpView. 
+ py::object getOperationObject() { + if (operation == nullptr) { + throw py::type_error("Cannot get an operation from a static interface"); + } + + return operation->getRef().releaseObject(); + } + + /// Returns the opview of the operation instance from which this object was + /// constructed. Throws a type error if this object was constructed from a + /// subclass of OpView. + py::object getOpView() { + if (operation == nullptr) { + throw py::type_error("Cannot get an opview from a static interface"); + } + + return operation->createOpView(); + } + + /// Returns the canonical name of the operation this interface is constructed + /// from. + const std::string &getOpName() { return opName; } + +private: + PyOperation *operation = nullptr; + std::string opName; + py::object obj; +}; + +/// Python wrapper for InferTypeOpInterface. This interface has only static +/// methods. +class PyInferTypeOpInterface + : public PyConcreteOpInterface { +public: + using PyConcreteOpInterface::PyConcreteOpInterface; + + constexpr static const char *pyClassName = "InferTypeOpInterface"; + constexpr static GetTypeIDFunctionTy getInterfaceID = + &mlirInferTypeOpInterfaceTypeID; + + /// C-style user-data structure for type appending callback. + struct AppendResultsCallbackData { + std::vector &inferredTypes; + PyMlirContext &pyMlirContext; + }; + + /// Appends the types provided as the two first arguments to the user-data + /// structure (expects AppendResultsCallbackData). + static void appendResultsCallback(intptr_t nTypes, MlirType *types, + void *userData) { + auto *data = static_cast(userData); + data->inferredTypes.reserve(data->inferredTypes.size() + nTypes); + for (intptr_t i = 0; i < nTypes; ++i) { + data->inferredTypes.push_back( + PyType(data->pyMlirContext.getRef(), types[i])); + } + } + + /// Given the arguments required to build an operation, attempts to infer its + /// return types. Throws value_error on failure. 
+ std::vector + inferReturnTypes(llvm::Optional> operands, + llvm::Optional attributes, + llvm::Optional> regions, + DefaultingPyMlirContext context, + DefaultingPyLocation location) { + llvm::SmallVector mlirOperands; + llvm::SmallVector mlirRegions; + + if (operands) { + mlirOperands.reserve(operands->size()); + for (PyValue &value : *operands) { + mlirOperands.push_back(value); + } + } + + if (regions) { + mlirRegions.reserve(regions->size()); + for (PyRegion ®ion : *regions) { + mlirRegions.push_back(region); + } + } + + std::vector inferredTypes; + PyMlirContext &pyContext = context.resolve(); + AppendResultsCallbackData data{inferredTypes, pyContext}; + MlirStringRef opNameRef = + mlirStringRefCreate(getOpName().data(), getOpName().length()); + MlirAttribute attributeDict = + attributes ? attributes->get() : mlirAttributeGetNull(); + + MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes( + opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), + mlirOperands.data(), attributeDict, mlirRegions.size(), + mlirRegions.data(), &appendResultsCallback, &data); + + if (mlirLogicalResultIsFailure(result)) { + throw py::value_error("Failed to infer result types"); + } + + return inferredTypes; + } + + static void bindDerived(ClassTy &cls) { + cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, + py::arg("operands") = py::none(), + py::arg("attributes") = py::none(), py::arg("regions") = py::none(), + py::arg("context") = py::none(), py::arg("loc") = py::none(), + inferReturnTypesDoc); + } +}; + +void populateIRInterfaces(py::module &m) { PyInferTypeOpInterface::bind(m); } + +} // namespace python +} // namespace mlir diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index ae85ef850..59285c01a 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -859,6 +859,7 @@ class PyIntegerSet : public BaseContextObject { void 
populateIRAffine(pybind11::module &m); void populateIRAttributes(pybind11::module &m); void populateIRCore(pybind11::module &m); +void populateIRInterfaces(pybind11::module &m); void populateIRTypes(pybind11::module &m); } // namespace python diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index cbade532e..5489a4d3e 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -85,6 +85,7 @@ PYBIND11_MODULE(_mlir, m) { populateIRCore(irModule); populateIRAffine(irModule); populateIRAttributes(irModule); + populateIRInterfaces(irModule); populateIRTypes(irModule); // Define and populate PassManager submodule. diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt index eed3f38d1..30ccbe94a 100644 --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(Debug) add_subdirectory(Dialect) add_subdirectory(Conversion) add_subdirectory(ExecutionEngine) +add_subdirectory(Interfaces) add_subdirectory(IR) add_subdirectory(Registration) add_subdirectory(Transforms) diff --git a/mlir/lib/CAPI/Interfaces/CMakeLists.txt b/mlir/lib/CAPI/Interfaces/CMakeLists.txt new file mode 100644 index 000000000..1de5f21d8 --- /dev/null +++ b/mlir/lib/CAPI/Interfaces/CMakeLists.txt @@ -0,0 +1,5 @@ +add_mlir_public_c_api_library(MLIRCAPIInterfaces + Interfaces.cpp + + LINK_LIBS PUBLIC + MLIRInferTypeOpInterface) diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp new file mode 100644 index 000000000..315adb5fb --- /dev/null +++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp @@ -0,0 +1,82 @@ +//===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Interfaces.h" + +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Wrap.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "llvm/ADT/ScopeExit.h" + +using namespace mlir; + +bool mlirOperationImplementsInterface(MlirOperation operation, + MlirTypeID interfaceTypeID) { + const AbstractOperation *abstractOp = + unwrap(operation)->getAbstractOperation(); + return abstractOp && abstractOp->hasInterface(unwrap(interfaceTypeID)); +} + +bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, + MlirContext context, + MlirTypeID interfaceTypeID) { + const AbstractOperation *abstractOp = AbstractOperation::lookup( + StringRef(operationName.data, operationName.length), unwrap(context)); + return abstractOp && abstractOp->hasInterface(unwrap(interfaceTypeID)); +} + +MlirTypeID mlirInferTypeOpInterfaceTypeID() { + return wrap(InferTypeOpInterface::getInterfaceID()); +} + +MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( + MlirStringRef opName, MlirContext context, MlirLocation location, + intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, + intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, + void *userData) { + StringRef name(opName.data, opName.length); + const AbstractOperation *abstractOp = + AbstractOperation::lookup(name, unwrap(context)); + if (!abstractOp) + return mlirLogicalResultFailure(); + + llvm::Optional maybeLocation = llvm::None; + if (!mlirLocationIsNull(location)) + maybeLocation = unwrap(location); + SmallVector unwrappedOperands; + (void)unwrapList(nOperands, operands, unwrappedOperands); + DictionaryAttr attributeDict; + if (!mlirAttributeIsNull(attributes)) + attributeDict = unwrap(attributes).cast(); + + // Create a vector of unique pointers to regions and make sure they are not + // deleted when exiting the scope. 
This is a hack caused by C++ API expecting + // an list of unique pointers to regions (without ownership transfer + // semantics) and C API making ownership transfer explicit. + SmallVector> unwrappedRegions; + unwrappedRegions.reserve(nRegions); + for (intptr_t i = 0; i < nRegions; ++i) + unwrappedRegions.emplace_back(unwrap(*(regions + i))); + auto cleaner = llvm::make_scope_exit([&]() { + for (auto ®ion : unwrappedRegions) + region.release(); + }); + + SmallVector inferredTypes; + if (failed(abstractOp->getInterface()->inferReturnTypes( + unwrap(context), maybeLocation, unwrappedOperands, attributeDict, + unwrappedRegions, inferredTypes))) + return mlirLogicalResultFailure(); + + SmallVector wrappedInferredTypes; + wrappedInferredTypes.reserve(inferredTypes.size()); + for (Type t : inferredTypes) + wrappedInferredTypes.push_back(wrap(t)); + callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData); + return mlirLogicalResultSuccess(); +} diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 8c60a31b0..54cc51f0b 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -113,12 +113,25 @@ declare_mlir_dialect_python_bindings( dialects/_memref_ops_ext.py DIALECT_NAME memref) -declare_mlir_dialect_python_bindings( - ADD_TO_PARENT MLIRPythonTestSources.Dialects +# TODO: this uses a tablegen file from the test directory and should be +# decoupled from here. 
+declare_mlir_python_sources( + MLIRPythonSources.Dialects.PythonTest ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" - TD_FILE dialects/PythonTest.td - SOURCES dialects/python_test.py - DIALECT_NAME python_test) + ADD_TO_PARENT MLIRPythonSources.Dialects + SOURCES dialects/python_test.py) +set(LLVM_TARGET_DEFINITIONS + "${MLIR_MAIN_SRC_DIR}/test/python/python_test_ops.td") +mlir_tablegen( + "dialects/_python_test_ops_gen.py" + -gen-python-op-bindings + -bind-dialect=python_test) +add_public_tablegen_target(PythonTestDialectPyIncGen) +declare_mlir_python_sources( + MLIRPythonSources.Dialects.PythonTest.ops_gen + ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" + ADD_TO_PARENT MLIRPythonSources.Dialects.PythonTest + SOURCES "dialects/_python_test_ops_gen.py") declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -192,6 +205,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core ${PYTHON_SOURCE_DIR}/IRAffine.cpp ${PYTHON_SOURCE_DIR}/IRAttributes.cpp ${PYTHON_SOURCE_DIR}/IRCore.cpp + ${PYTHON_SOURCE_DIR}/IRInterfaces.cpp ${PYTHON_SOURCE_DIR}/IRModule.cpp ${PYTHON_SOURCE_DIR}/IRTypes.cpp ${PYTHON_SOURCE_DIR}/PybindUtils.cpp @@ -201,6 +215,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core EMBED_CAPI_LINK_LIBS MLIRCAPIDebug MLIRCAPIIR + MLIRCAPIInterfaces MLIRCAPIRegistration # TODO: See about dis-aggregating # Dialects @@ -297,6 +312,20 @@ declare_mlir_python_extension(MLIRPythonExtension.Transforms MLIRCAPITransforms ) +# TODO: This should not be included in the main Python extension. However, +# putting it into MLIRPythonTestSources along with the dialect declaration +# above confuses Python module loader when running under lit. 
+declare_mlir_python_extension(MLIRPythonExtension.PythonTest + MODULE_NAME _mlirPythonTest + ADD_TO_PARENT MLIRPythonSources.Dialects + SOURCES + ${MLIR_SOURCE_DIR}/test/python/lib/PythonTestModule.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIPythonTestDialect +) + ################################################################################ # Common CAPI dependency DSO. # All python extensions must link through one DSO which exports the CAPI, and @@ -336,7 +365,6 @@ add_mlir_python_modules(MLIRPythonModules MLIRPythonCAPI ) - add_mlir_python_modules(MLIRPythonTestModules ROOT_PREFIX "${MLIR_BINARY_DIR}/python_packages/mlir_test/mlir" INSTALL_PREFIX "python_packages/mlir_test/mlir" diff --git a/mlir/python/mlir/dialects/PythonTest.td b/mlir/python/mlir/dialects/PythonTest.td deleted file mode 100644 index d3d49395a..000000000 --- a/mlir/python/mlir/dialects/PythonTest.td +++ /dev/null @@ -1,33 +0,0 @@ -//===-- python_test_ops.td - Python test Op definitions ----*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef PYTHON_TEST_OPS -#define PYTHON_TEST_OPS - -include "mlir/Bindings/Python/Attributes.td" -include "mlir/IR/OpBase.td" - -def Python_Test_Dialect : Dialect { - let name = "python_test"; - let cppNamespace = "PythonTest"; -} -class TestOp traits = []> - : Op; - -def AttributedOp : TestOp<"attributed_op"> { - let arguments = (ins I32Attr:$mandatory_i32, - OptionalAttr:$optional_i32, - UnitAttr:$unit); -} - -def PropertyOp : TestOp<"property_op"> { - let arguments = (ins I32Attr:$property, - I32:$idx); -} - -#endif // PYTHON_TEST_OPS diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py index 524db4317..82c01d5a0 100644 --- a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -3,3 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._python_test_ops_gen import * + + +def register_python_test_dialect(context, load=True): + from .._mlir_libs import _mlirPythonTest + _mlirPythonTest.register_python_test_dialect(context, load) From 69080130dd62181ea931b49c31152dd5f4be448f Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 14 Oct 2021 17:19:06 +0200 Subject: [PATCH 143/915] [mlir][python] Infer result types in generated constructors whenever possible In several cases, operation result types can be unambiguously inferred from operands and attributes at operation construction time. Stop requiring the user to provide these types as arguments in the ODS-generated constructors in Python bindings. In particular, handle the SameOperandAndResultTypes and FirstAttrDerivedResultType traits as well as InferTypeOpInterface using the recently added interface support. This is a significant usability improvement for IR construction, similar to what C++ ODS provides. 
Depends On D111656 Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D111811 --- mlir/lib/Bindings/Python/IRModule.h | 132 +++++++++--------- .../dialects/linalg/opdsl/lang/emitter.py | 34 ++--- 2 files changed, 84 insertions(+), 82 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 59285c01a..dac9486c4 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -601,6 +601,71 @@ class PyInsertionPoint { llvm::Optional refOperation; PyBlock block; }; +/// Wrapper around the generic MlirType. +/// The lifetime of a type is bound by the PyContext that created it. +class PyType : public BaseContextObject { +public: + PyType(PyMlirContextRef contextRef, MlirType type) + : BaseContextObject(std::move(contextRef)), type(type) {} + bool operator==(const PyType &other); + operator MlirType() const { return type; } + MlirType get() const { return type; } + + /// Gets a capsule wrapping the void* within the MlirType. + pybind11::object getCapsule(); + + /// Creates a PyType from the MlirType wrapped by a capsule. + /// Note that PyType instances are uniqued, so the returned object + /// may be a pre-existing object. Ownership of the underlying MlirType + /// is taken by calling this function. + static PyType createFromCapsule(pybind11::object capsule); + +private: + MlirType type; +}; + +/// CRTP base classes for Python types that subclass Type and should be +/// castable from it (i.e. via something like IntegerType(t)). +/// By default, type class hierarchies are one level deep (i.e. a +/// concrete type class extends PyType); however, intermediate python-visible +/// base classes can be modeled by specifying a BaseTy. 
+template +class PyConcreteType : public BaseTy { +public: + // Derived classes must define statics for: + // IsAFunctionTy isaFunction + // const char *pyClassName + using ClassTy = pybind11::class_; + using IsAFunctionTy = bool (*)(MlirType); + + PyConcreteType() = default; + PyConcreteType(PyMlirContextRef contextRef, MlirType t) + : BaseTy(std::move(contextRef), t) {} + PyConcreteType(PyType &orig) + : PyConcreteType(orig.getContext(), castFrom(orig)) {} + + static MlirType castFrom(PyType &orig) { + if (!DerivedTy::isaFunction(orig)) { + auto origRepr = pybind11::repr(pybind11::cast(orig)).cast(); + throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") + + DerivedTy::pyClassName + + " (from " + origRepr + ")"); + } + return orig; + } + + static void bind(pybind11::module &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local()); + cls.def(pybind11::init(), pybind11::keep_alive<0, 1>()); + cls.def_static("isinstance", [](PyType &otherType) -> bool { + return DerivedTy::isaFunction(otherType); + }); + DerivedTy::bindDerived(cls); + } + + /// Implemented by derived classes to add methods to the Python subclass. + static void bindDerived(ClassTy &m) {} +}; /// Wrapper around the generic MlirAttribute. /// The lifetime of a type is bound by the PyContext that created it. @@ -685,71 +750,8 @@ class PyConcreteAttribute : public BaseTy { cls.def_static("isinstance", [](PyAttribute &otherAttr) -> bool { return DerivedTy::isaFunction(otherAttr); }); - DerivedTy::bindDerived(cls); - } - - /// Implemented by derived classes to add methods to the Python subclass. - static void bindDerived(ClassTy &m) {} -}; - -/// Wrapper around the generic MlirType. -/// The lifetime of a type is bound by the PyContext that created it. 
-class PyType : public BaseContextObject { -public: - PyType(PyMlirContextRef contextRef, MlirType type) - : BaseContextObject(std::move(contextRef)), type(type) {} - bool operator==(const PyType &other); - operator MlirType() const { return type; } - MlirType get() const { return type; } - - /// Gets a capsule wrapping the void* within the MlirType. - pybind11::object getCapsule(); - - /// Creates a PyType from the MlirType wrapped by a capsule. - /// Note that PyType instances are uniqued, so the returned object - /// may be a pre-existing object. Ownership of the underlying MlirType - /// is taken by calling this function. - static PyType createFromCapsule(pybind11::object capsule); - -private: - MlirType type; -}; - -/// CRTP base classes for Python types that subclass Type and should be -/// castable from it (i.e. via something like IntegerType(t)). -/// By default, type class hierarchies are one level deep (i.e. a -/// concrete type class extends PyType); however, intermediate python-visible -/// base classes can be modeled by specifying a BaseTy. 
-template -class PyConcreteType : public BaseTy { -public: - // Derived classes must define statics for: - // IsAFunctionTy isaFunction - // const char *pyClassName - using ClassTy = pybind11::class_; - using IsAFunctionTy = bool (*)(MlirType); - - PyConcreteType() = default; - PyConcreteType(PyMlirContextRef contextRef, MlirType t) - : BaseTy(std::move(contextRef), t) {} - PyConcreteType(PyType &orig) - : PyConcreteType(orig.getContext(), castFrom(orig)) {} - - static MlirType castFrom(PyType &orig) { - if (!DerivedTy::isaFunction(orig)) { - auto origRepr = pybind11::repr(pybind11::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") + - DerivedTy::pyClassName + - " (from " + origRepr + ")"); - } - return orig; - } - - static void bind(pybind11::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local()); - cls.def(pybind11::init(), pybind11::keep_alive<0, 1>()); - cls.def_static("isinstance", [](PyType &otherType) -> bool { - return DerivedTy::isaFunction(otherType); + cls.def_property_readonly("type", [](PyAttribute &attr) { + return PyType(attr.getContext(), mlirAttributeGetType(attr)); }); DerivedTy::bindDerived(cls); } diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 1215d0358..2ece9eb92 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -221,7 +221,7 @@ def expression(self, expr: ScalarExpression) -> Value: elif expr.scalar_index: dim_attr = IntegerAttr.get( IntegerType.get_signless(64), expr.scalar_index.dim) - return linalg.IndexOp(IndexType.get(), dim_attr).result + return linalg.IndexOp(dim_attr).result elif expr.scalar_apply: try: fn = getattr(self, f"_eval_{expr.scalar_apply.fn_name}") @@ -303,61 +303,61 @@ def yield_outputs(self, *output_names: str): def _eval_add(self, lhs: Value, rhs: Value) -> Value: if 
_is_floating_point_type(lhs.type): - return arith.AddFOp(lhs.type, lhs, rhs).result + return arith.AddFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.AddIOp(lhs.type, lhs, rhs).result + return arith.AddIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'add' operand: {lhs}") def _eval_exp(self, x: Value) -> Value: if _is_floating_point_type(x.type): - return math.ExpOp(x.type, x).result + return math.ExpOp(x).result raise NotImplementedError("Unsupported 'exp' operand: {x}") def _eval_log(self, x: Value) -> Value: if _is_floating_point_type(x.type): - return math.LogOp(x.type, x).result + return math.LogOp(x).result raise NotImplementedError("Unsupported 'log' operand: {x}") def _eval_sub(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return arith.SubFOp(lhs.type, lhs, rhs).result + return arith.SubFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.SubIOp(lhs.type, lhs, rhs).result + return arith.SubIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'sub' operand: {lhs}") def _eval_mul(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return arith.MulFOp(lhs.type, lhs, rhs).result + return arith.MulFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.MulIOp(lhs.type, lhs, rhs).result + return arith.MulIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'mul' operand: {lhs}") def _eval_max(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return std.MaxFOp(lhs.type, lhs, rhs).result + return std.MaxFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.MaxSIOp(lhs.type, lhs, rhs).result + return std.MaxSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max' operand: {lhs}") def _eval_max_unsigned(self, lhs: Value, rhs: Value) -> Value: if 
_is_floating_point_type(lhs.type): - return std.MaxFOp(lhs.type, lhs, rhs).result + return std.MaxFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.MaxUIOp(lhs.type, lhs, rhs).result + return std.MaxUIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max_unsigned' operand: {lhs}") def _eval_min(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return std.MinFOp(lhs.type, lhs, rhs).result + return std.MinFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.MinSIOp(lhs.type, lhs, rhs).result + return std.MinSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'min' operand: {lhs}") def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return std.MinFOp(lhs.type, lhs, rhs).result + return std.MinFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.MinUIOp(lhs.type, lhs, rhs).result + return std.MinUIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}") From 38e639732834f6d894cbc72d3d1dbb295d9f930b Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Mon, 25 Oct 2021 21:11:50 -0700 Subject: [PATCH 144/915] [mlir-c] Avoid compiler warning Setting visibility & static leads to warning about attribute being ignored. Differential Revision: https://reviews.llvm.org/D112507 --- mlir/include/mlir-c/IR.h | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 2fec0be4d..456aa93b2 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -713,9 +713,7 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident); //===----------------------------------------------------------------------===// /// Checks whether a type id is null. 
-MLIR_CAPI_EXPORTED static inline bool mlirTypeIDIsNull(MlirTypeID typeID) { - return !typeID.ptr; -} +static inline bool mlirTypeIDIsNull(MlirTypeID typeID) { return !typeID.ptr; } /// Checks if two type ids are equal. MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2); From 272c8455ed60c3115e8e998d9d70a34e9a372c25 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Wed, 20 Oct 2021 19:51:52 +0000 Subject: [PATCH 145/915] Make Python MLIR Operation not iterable The current behavior is conveniently allowing to iterate on the regions of an operation implicitly by exposing an operation as Iterable. However this is also error prone and code that may intend to iterate on the results or the operands could end up "working" apparently instead of throwing a runtime error. The lack of static type checking in Python contributes to the ambiguity here, it seems safer to not do this and require and explicit qualification to iterate (`op.results`, `op.regions`, ...). Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D111697 --- mlir/lib/Bindings/Python/IRCore.cpp | 4 ---- mlir/python/mlir/dialects/_builtin_ops_ext.py | 9 +++++++++ mlir/python/mlir/dialects/_ods_common.py | 4 +++- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 4fc581b5d..7abd2a1f6 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2152,10 +2152,6 @@ void mlir::python::populateIRCore(py::module &m) { }, "Returns the source location the operation was defined or derived " "from.") - .def("__iter__", - [](PyOperationBase &self) { - return PyRegionIterator(self.getOperation().getRef()); - }) .def( "__str__", [](PyOperationBase &self) { diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py index 462850d63..78f8c95c4 100644 --- a/mlir/python/mlir/dialects/_builtin_ops_ext.py +++ 
b/mlir/python/mlir/dialects/_builtin_ops_ext.py @@ -195,8 +195,17 @@ def decorator(f): # Coerce return values, add ReturnOp and rewrite func type. if return_values is None: return_values = [] + elif isinstance(return_values, tuple): + return_values = list(return_values) elif isinstance(return_values, Value): + # Returning a single value is fine, coerce it into a list. return_values = [return_values] + elif isinstance(return_values, OpView): + # Returning a single operation is fine, coerce its results a list. + return_values = return_values.operation.results + elif isinstance(return_values, Operation): + # Returning a single operation is fine, coerce its results a list. + return_values = return_values.results else: return_values = list(return_values) std.ReturnOp(return_values) diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 95c441865..6bb84e978 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -124,7 +124,7 @@ def get_default_loc_context(location=None): def get_op_result_or_value( - arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value] + arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList] ) -> _cext.ir.Value: """Returns the given value or the single result of the given op. @@ -136,6 +136,8 @@ def get_op_result_or_value( return arg.operation.result elif isinstance(arg, _cext.ir.Operation): return arg.result + elif isinstance(arg, _cext.ir.OpResultList): + return arg[0] else: assert isinstance(arg, _cext.ir.Value) return arg From 62269279a10e9e5836e9ff83392e36dad3aa0ba4 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 26 Oct 2021 17:14:50 +0000 Subject: [PATCH 146/915] [mlir][python] Segment MLIR Python test dialect to avoid testonly dependency. With https://reviews.llvm.org/rG14c9207063bb00823a5126131e50c93f6e288bd3, the build is broken with -DMLIR_INCLUDE_TESTS=OFF. 
This patch fixes the build and we may want to do a better fix to the layering in a followup. Differential Revision: https://reviews.llvm.org/D112560 --- mlir/python/CMakeLists.txt | 84 +++++++++++++++++++------------------- 1 file changed, 42 insertions(+), 42 deletions(-) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 54cc51f0b..6a49c773e 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -8,10 +8,6 @@ declare_mlir_python_sources(MLIRPythonSources) declare_mlir_python_sources(MLIRPythonSources.Dialects ADD_TO_PARENT MLIRPythonSources) -declare_mlir_python_sources(MLIRPythonTestSources) -declare_mlir_python_sources(MLIRPythonTestSources.Dialects - ADD_TO_PARENT MLIRPythonTestSources) - ################################################################################ # Pure python sources and generated code ################################################################################ @@ -113,26 +109,6 @@ declare_mlir_dialect_python_bindings( dialects/_memref_ops_ext.py DIALECT_NAME memref) -# TODO: this uses a tablegen file from the test directory and should be -# decoupled from here. 
-declare_mlir_python_sources( - MLIRPythonSources.Dialects.PythonTest - ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" - ADD_TO_PARENT MLIRPythonSources.Dialects - SOURCES dialects/python_test.py) -set(LLVM_TARGET_DEFINITIONS - "${MLIR_MAIN_SRC_DIR}/test/python/python_test_ops.td") -mlir_tablegen( - "dialects/_python_test_ops_gen.py" - -gen-python-op-bindings - -bind-dialect=python_test) -add_public_tablegen_target(PythonTestDialectPyIncGen) -declare_mlir_python_sources( - MLIRPythonSources.Dialects.PythonTest.ops_gen - ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" - ADD_TO_PARENT MLIRPythonSources.Dialects.PythonTest - SOURCES "dialects/_python_test_ops_gen.py") - declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" @@ -312,19 +288,48 @@ declare_mlir_python_extension(MLIRPythonExtension.Transforms MLIRCAPITransforms ) -# TODO: This should not be included in the main Python extension. However, +# TODO: Figure out how to put this in the test tree. +# This should not be included in the main Python extension. However, # putting it into MLIRPythonTestSources along with the dialect declaration # above confuses Python module loader when running under lit. -declare_mlir_python_extension(MLIRPythonExtension.PythonTest - MODULE_NAME _mlirPythonTest - ADD_TO_PARENT MLIRPythonSources.Dialects - SOURCES - ${MLIR_SOURCE_DIR}/test/python/lib/PythonTestModule.cpp - PRIVATE_LINK_LIBS - LLVMSupport - EMBED_CAPI_LINK_LIBS - MLIRCAPIPythonTestDialect -) +set(_ADDL_TEST_SOURCES) +if(MLIR_INCLUDE_TESTS) + set(_ADDL_TEST_SOURCES MLIRPythonTestSources) + declare_mlir_python_sources(MLIRPythonTestSources) + declare_mlir_python_sources(MLIRPythonTestSources.Dialects + ADD_TO_PARENT MLIRPythonTestSources) + + # TODO: this uses a tablegen file from the test directory and should be + # decoupled from here. 
+ declare_mlir_python_sources( + MLIRPythonTestSources.Dialects.PythonTest + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + ADD_TO_PARENT MLIRPythonTestSources.Dialects + SOURCES dialects/python_test.py) + set(LLVM_TARGET_DEFINITIONS + "${MLIR_MAIN_SRC_DIR}/test/python/python_test_ops.td") + mlir_tablegen( + "dialects/_python_test_ops_gen.py" + -gen-python-op-bindings + -bind-dialect=python_test) + add_public_tablegen_target(PythonTestDialectPyIncGen) + declare_mlir_python_sources( + MLIRPythonTestSources.Dialects.PythonTest.ops_gen + ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" + ADD_TO_PARENT MLIRPythonTestSources.Dialects.PythonTest + SOURCES "dialects/_python_test_ops_gen.py") + + declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtension + MODULE_NAME _mlirPythonTest + ADD_TO_PARENT MLIRPythonTestSources.Dialects + SOURCES + ${MLIR_SOURCE_DIR}/test/python/lib/PythonTestModule.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIPythonTestDialect + ) +endif() ################################################################################ # Common CAPI dependency DSO. 
@@ -347,6 +352,7 @@ add_mlir_python_common_capi_library(MLIRPythonCAPI DECLARED_SOURCES MLIRPythonSources MLIRPythonExtension.AllPassesRegistration + ${_ADDL_TEST_SOURCES} ) ################################################################################ @@ -361,13 +367,7 @@ add_mlir_python_modules(MLIRPythonModules MLIRPythonSources MLIRPythonExtension.AllPassesRegistration MLIRPythonCAPIHeaderSources + ${_ADDL_TEST_SOURCES} COMMON_CAPI_LINK_LIBS MLIRPythonCAPI ) - -add_mlir_python_modules(MLIRPythonTestModules - ROOT_PREFIX "${MLIR_BINARY_DIR}/python_packages/mlir_test/mlir" - INSTALL_PREFIX "python_packages/mlir_test/mlir" - DECLARED_SOURCES - MLIRPythonTestSources - ) From 050f6d85048c361ab72e8800beb494086ae35443 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Fri, 29 Oct 2021 15:11:09 +0200 Subject: [PATCH 147/915] [mlir][python] Add a __contains__ method to the python bindings for DictionaryAttr. This makes it easier to check in python whether a certain attribute is there. Differential Revision: https://reviews.llvm.org/D112814 --- mlir/lib/Bindings/Python/IRAttributes.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 066350a0a..7db90ec8d 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -698,7 +698,13 @@ class PyDictAttribute : public PyConcreteAttribute { intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } + bool dunderContains(const std::string &name) { + return !mlirAttributeIsNull( + mlirDictionaryAttrGetElementByName(*this, toMlirStringRef(name))); + } + static void bindDerived(ClassTy &c) { + c.def("__contains__", &PyDictAttribute::dunderContains); c.def("__len__", &PyDictAttribute::dunderLen); c.def_static( "get", From 740b2daf28001ae5786451658e8075910e6b0870 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Sun, 31 Oct 2021 09:37:20 +0100 Subject: [PATCH 148/915] 
[mlir][python] allow for detaching operations from a block Provide support for removing an operation from the block that contains it and moving it back to detached state. This allows for the operation to be moved to a different block, a common IR manipulation for, e.g., module merging. Also fix a potential one-past-end iterator dereference in Operation::moveAfter discovered in the process. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D112700 --- mlir/include/mlir-c/IR.h | 17 +++++++++ mlir/lib/Bindings/Python/IRCore.cpp | 53 +++++++++++++++++++++++++++-- mlir/lib/Bindings/Python/IRModule.h | 20 ++++++++++- mlir/lib/CAPI/IR/IR.cpp | 10 ++++++ 4 files changed, 97 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 456aa93b2..ca0c45224 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -346,6 +346,10 @@ MLIR_CAPI_EXPORTED MlirOperation mlirOperationClone(MlirOperation op); /// Takes an operation owned by the caller and destroys it. MLIR_CAPI_EXPORTED void mlirOperationDestroy(MlirOperation op); +/// Removes the given operation from its parent block. The operation is not +/// destroyed. The ownership of the operation is transferred to the caller. +MLIR_CAPI_EXPORTED void mlirOperationRemoveFromParent(MlirOperation op); + /// Checks whether the underlying operation is null. static inline bool mlirOperationIsNull(MlirOperation op) { return !op.ptr; } @@ -455,6 +459,19 @@ MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op); /// Verify the operation and return true if it passes, false if it fails. MLIR_CAPI_EXPORTED bool mlirOperationVerify(MlirOperation op); +/// Moves the given operation immediately after the other operation in its +/// parent block. The given operation may be owned by the caller or by its +/// current block. The other operation must belong to a block. In any case, the +/// ownership is transferred to the block of the other operation. 
+MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, + MlirOperation other); + +/// Moves the given operation immediately before the other operation in its +/// parent block. The given operation may be owner by the caller or by its +/// current block. The other operation must belong to a block. In any case, the +/// ownership is transferred to the block of the other operation. +MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, + MlirOperation other); //===----------------------------------------------------------------------===// // Region API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 7abd2a1f6..d47d06a3a 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -875,6 +875,24 @@ py::object PyOperationBase::getAsm(bool binary, return fileObject.attr("getvalue")(); } +void PyOperationBase::moveAfter(PyOperationBase &other) { + PyOperation &operation = getOperation(); + PyOperation &otherOp = other.getOperation(); + operation.checkValid(); + otherOp.checkValid(); + mlirOperationMoveAfter(operation, otherOp); + operation.parentKeepAlive = otherOp.parentKeepAlive; +} + +void PyOperationBase::moveBefore(PyOperationBase &other) { + PyOperation &operation = getOperation(); + PyOperation &otherOp = other.getOperation(); + operation.checkValid(); + otherOp.checkValid(); + mlirOperationMoveBefore(operation, otherOp); + operation.parentKeepAlive = otherOp.parentKeepAlive; +} + llvm::Optional PyOperation::getParentOperation() { checkValid(); if (!isAttached()) @@ -2185,7 +2203,25 @@ void mlir::python::populateIRCore(py::module &m) { return mlirOperationVerify(self.getOperation()); }, "Verify the operation and return true if it passes, false if it " - "fails."); + "fails.") + .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), + "Puts self immediately after the other 
operation in its parent " + "block.") + .def("move_before", &PyOperationBase::moveBefore, py::arg("other"), + "Puts self immediately before the other operation in its parent " + "block.") + .def( + "detach_from_parent", + [](PyOperationBase &self) { + PyOperation &operation = self.getOperation(); + operation.checkValid(); + if (!operation.isAttached()) + throw py::value_error("Detached operation has no parent."); + + operation.detachFromParent(); + return operation.createOpView(); + }, + "Detaches the operation from its parent block."); py::class_(m, "Operation", py::module_local()) .def_static("create", &PyOperation::create, py::arg("name"), @@ -2380,7 +2416,20 @@ void mlir::python::populateIRCore(py::module &m) { printAccum.getUserData()); return printAccum.join(); }, - "Returns the assembly form of the block."); + "Returns the assembly form of the block.") + .def( + "append", + [](PyBlock &self, PyOperationBase &operation) { + if (operation.getOperation().isAttached()) + operation.getOperation().detachFromParent(); + + MlirOperation mlirOperation = operation.getOperation().get(); + mlirBlockAppendOwnedOperation(self.get(), mlirOperation); + operation.getOperation().setAttached( + self.getParentOperation().getObject()); + }, + "Appends an operation to this block. If the operation is currently " + "in another block, it will be moved."); //---------------------------------------------------------------------------- // Mapping of PyInsertionPoint. diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index dac9486c4..73924fc74 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -399,6 +399,10 @@ class PyOperationBase { bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope); + /// Moves the operation before or after the other operation. 
+ void moveAfter(PyOperationBase &other); + void moveBefore(PyOperationBase &other); + /// Each must provide access to the raw Operation. virtual PyOperation &getOperation() = 0; }; @@ -428,6 +432,14 @@ class PyOperation : public PyOperationBase, public BaseContextObject { createDetached(PyMlirContextRef contextRef, MlirOperation operation, pybind11::object parentKeepAlive = pybind11::object()); + /// Detaches the operation from its parent block and updates its state + /// accordingly. + void detachFromParent() { + mlirOperationRemoveFromParent(getOperation()); + setDetached(); + parentKeepAlive = pybind11::object(); + } + /// Gets the backing operation. operator MlirOperation() const { return get(); } MlirOperation get() const { @@ -441,10 +453,14 @@ class PyOperation : public PyOperationBase, public BaseContextObject { } bool isAttached() { return attached; } - void setAttached() { + void setAttached(pybind11::object parent = pybind11::object()) { assert(!attached && "operation already attached"); attached = true; } + void setDetached() { + assert(attached && "operation already detached"); + attached = false; + } void checkValid() const; /// Gets the owning block or raises an exception if the operation has no @@ -495,6 +511,8 @@ class PyOperation : public PyOperationBase, public BaseContextObject { pybind11::object parentKeepAlive; bool attached = true; bool valid = true; + + friend class PyOperationBase; }; /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index c738198f7..6f617dc19 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -338,6 +338,8 @@ MlirOperation mlirOperationClone(MlirOperation op) { void mlirOperationDestroy(MlirOperation op) { unwrap(op)->erase(); } +void mlirOperationRemoveFromParent(MlirOperation op) { unwrap(op)->remove(); } + bool mlirOperationEqual(MlirOperation op, MlirOperation other) { return unwrap(op) == unwrap(other); } @@ 
-451,6 +453,14 @@ bool mlirOperationVerify(MlirOperation op) { return succeeded(verify(unwrap(op))); } +void mlirOperationMoveAfter(MlirOperation op, MlirOperation other) { + return unwrap(op)->moveAfter(unwrap(other)); +} + +void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) { + return unwrap(op)->moveBefore(unwrap(other)); +} + //===----------------------------------------------------------------------===// // Region API. //===----------------------------------------------------------------------===// From d3c063253530bbb995ae2abcad9162d463b8ac95 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Tue, 2 Nov 2021 12:39:36 +0100 Subject: [PATCH 149/915] [mlir] provide C API and Python bindings for symbol tables Symbol tables are a largely useful top-level IR construct, for example, they make it easy to access functions in a module by name instead of traversing the list of module's operations to find the corresponding function. Depends On D112886 Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D112821 --- mlir/include/mlir-c/IR.h | 42 +++++++++++++++++++ mlir/include/mlir-c/Support.h | 4 ++ mlir/include/mlir/CAPI/IR.h | 1 + mlir/lib/Bindings/Python/IRCore.cpp | 65 +++++++++++++++++++++++++++++ mlir/lib/Bindings/Python/IRModule.h | 34 +++++++++++++++ mlir/lib/CAPI/IR/IR.cpp | 33 +++++++++++++++ mlir/lib/CAPI/IR/Support.cpp | 6 +++ 7 files changed, 185 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index ca0c45224..161019125 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -54,6 +54,7 @@ DEFINE_C_API_STRUCT(MlirOperation, void); DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void); DEFINE_C_API_STRUCT(MlirBlock, void); DEFINE_C_API_STRUCT(MlirRegion, void); +DEFINE_C_API_STRUCT(MlirSymbolTable, void); DEFINE_C_API_STRUCT(MlirAttribute, const void); DEFINE_C_API_STRUCT(MlirIdentifier, const void); @@ -738,6 +739,47 @@ MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, 
MlirTypeID typeID2); /// Returns the hash value of the type id. MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID); +//===----------------------------------------------------------------------===// +// Symbol and SymbolTable API. +//===----------------------------------------------------------------------===// + +/// Returns the name of the attribute used to store symbol names compatible with +/// symbol tables. +MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(); + +/// Creates a symbol table for the given operation. If the operation does not +/// have the SymbolTable trait, returns a null symbol table. +MLIR_CAPI_EXPORTED MlirSymbolTable +mlirSymbolTableCreate(MlirOperation operation); + +/// Returns true if the symbol table is null. +static inline bool mlirSymbolTableIsNull(MlirSymbolTable symbolTable) { + return !symbolTable.ptr; +} + +/// Destroys the symbol table created with mlirSymbolTableCreate. This does not +/// affect the operations in the table. +MLIR_CAPI_EXPORTED void mlirSymbolTableDestroy(MlirSymbolTable symbolTable); + +/// Looks up a symbol with the given name in the given symbol table and returns +/// the operation that corresponds to the symbol. If the symbol cannot be found, +/// returns a null operation. +MLIR_CAPI_EXPORTED MlirOperation +mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirStringRef name); + +/// Inserts the given operation into the given symbol table. The operation must +/// have the symbol trait. If the symbol table already has a symbol with the +/// same name, renames the symbol being inserted to ensure name uniqueness. Note +/// that this does not move the operation itself into the block of the symbol +/// table operation, this should be done separately. Returns the name of the +/// symbol after insertion. 
+MLIR_CAPI_EXPORTED MlirAttribute +mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation); + +/// Removes the given operation from the symbol table and erases it. +MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable, + MlirOperation operation); + #ifdef __cplusplus } #endif diff --git a/mlir/include/mlir-c/Support.h b/mlir/include/mlir-c/Support.h index 315f6c456..f20e58fe6 100644 --- a/mlir/include/mlir-c/Support.h +++ b/mlir/include/mlir-c/Support.h @@ -79,6 +79,10 @@ inline static MlirStringRef mlirStringRefCreate(const char *str, MLIR_CAPI_EXPORTED MlirStringRef mlirStringRefCreateFromCString(const char *str); +/// Returns true if two string references are equal, false otherwise. +MLIR_CAPI_EXPORTED bool mlirStringRefEqual(MlirStringRef string, + MlirStringRef other); + /// A callback for returning string references. /// /// This function is called back by the functions that need to return a diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h index d5e961367..a864175d0 100644 --- a/mlir/include/mlir/CAPI/IR.h +++ b/mlir/include/mlir/CAPI/IR.h @@ -27,6 +27,7 @@ DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation) DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block) DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags) DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region) +DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable); DEFINE_C_API_METHODS(MlirAttribute, mlir::Attribute) DEFINE_C_API_METHODS(MlirIdentifier, mlir::Identifier) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index d47d06a3a..8f451cf34 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1530,6 +1530,57 @@ PyValue PyValue::createFromCapsule(pybind11::object capsule) { return PyValue(ownerRef, value); } +//------------------------------------------------------------------------------ +// PySymbolTable. 
+//------------------------------------------------------------------------------ + +PySymbolTable::PySymbolTable(PyOperationBase &operation) + : operation(operation.getOperation().getRef()) { + symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); + if (mlirSymbolTableIsNull(symbolTable)) { + throw py::cast_error("Operation is not a Symbol Table."); + } +} + +py::object PySymbolTable::dunderGetItem(const std::string &name) { + operation->checkValid(); + MlirOperation symbol = mlirSymbolTableLookup( + symbolTable, mlirStringRefCreate(name.data(), name.length())); + if (mlirOperationIsNull(symbol)) + throw py::key_error("Symbol '" + name + "' not in the symbol table."); + + return PyOperation::forOperation(operation->getContext(), symbol, + operation.getObject()) + ->createOpView(); +} + +void PySymbolTable::erase(PyOperationBase &symbol) { + operation->checkValid(); + symbol.getOperation().checkValid(); + mlirSymbolTableErase(symbolTable, symbol.getOperation().get()); + // The operation is also erased, so we must invalidate it. There may be Python + // references to this operation so we don't want to delete it from the list of + // live operations here. + symbol.getOperation().valid = false; +} + +void PySymbolTable::dunderDel(const std::string &name) { + py::object operation = dunderGetItem(name); + erase(py::cast(operation)); +} + +PyAttribute PySymbolTable::insert(PyOperationBase &symbol) { + operation->checkValid(); + symbol.getOperation().checkValid(); + MlirAttribute symbolAttr = mlirOperationGetAttributeByName( + symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); + if (mlirAttributeIsNull(symbolAttr)) + throw py::value_error("Expected operation to have a symbol name."); + return PyAttribute( + symbol.getOperation().getContext(), + mlirSymbolTableInsert(symbolTable, symbol.getOperation().get())); +} + namespace { /// CRTP base class for Python MLIR values that subclass Value and should be /// castable from it. 
The value hierarchy is one level deep and is not supposed @@ -2670,6 +2721,20 @@ void mlir::python::populateIRCore(py::module &m) { PyBlockArgument::bind(m); PyOpResult::bind(m); + //---------------------------------------------------------------------------- + // Mapping of SymbolTable. + //---------------------------------------------------------------------------- + py::class_(m, "SymbolTable", py::module_local()) + .def(py::init()) + .def("__getitem__", &PySymbolTable::dunderGetItem) + .def("insert", &PySymbolTable::insert) + .def("erase", &PySymbolTable::erase) + .def("__delitem__", &PySymbolTable::dunderDel) + .def("__contains__", [](PySymbolTable &table, const std::string &name) { + return !mlirOperationIsNull(mlirSymbolTableLookup( + table, mlirStringRefCreate(name.data(), name.length()))); + }); + // Container bindings. PyBlockArgumentList::bind(m); PyBlockIterator::bind(m); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 73924fc74..eb5c2385a 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -32,6 +32,7 @@ class DefaultingPyMlirContext; class PyModule; class PyOperation; class PyType; +class PySymbolTable; class PyValue; /// Template for a reference to a concrete type which captures a python @@ -513,6 +514,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject { bool valid = true; friend class PyOperationBase; + friend class PySymbolTable; }; /// A PyOpView is equivalent to the C++ "Op" wrappers: these are the basis for @@ -876,6 +878,38 @@ class PyIntegerSet : public BaseContextObject { MlirIntegerSet integerSet; }; +/// Bindings for MLIR symbol tables. +class PySymbolTable { +public: + /// Constructs a symbol table for the given operation. + explicit PySymbolTable(PyOperationBase &operation); + + /// Destroys the symbol table. 
+ ~PySymbolTable() { mlirSymbolTableDestroy(symbolTable); } + + /// Returns the symbol (opview) with the given name, throws if there is no + /// such symbol in the table. + pybind11::object dunderGetItem(const std::string &name); + + /// Removes the given operation from the symbol table and erases it. + void erase(PyOperationBase &symbol); + + /// Removes the operation with the given name from the symbol table and erases + /// it, throws if there is no such symbol in the table. + void dunderDel(const std::string &name); + + /// Inserts the given operation into the symbol table. The operation must have + /// the symbol trait. + PyAttribute insert(PyOperationBase &symbol); + + /// Casts the bindings class into the C API structure. + operator MlirSymbolTable() { return symbolTable; } + +private: + PyOperationRef operation; + MlirSymbolTable symbolTable; +}; + void populateIRAffine(pybind11::module &m); void populateIRAttributes(pybind11::module &m); void populateIRCore(pybind11::module &m); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 6f617dc19..13490b342 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -763,3 +763,36 @@ bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) { size_t mlirTypeIDHashValue(MlirTypeID typeID) { return hash_value(unwrap(typeID)); } + +//===----------------------------------------------------------------------===// +// Symbol and SymbolTable API. 
+//===----------------------------------------------------------------------===// + +MlirStringRef mlirSymbolTableGetSymbolAttributeName() { + return wrap(SymbolTable::getSymbolAttrName()); +} + +MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) { + if (!unwrap(operation)->hasTrait()) + return wrap(static_cast(nullptr)); + return wrap(new SymbolTable(unwrap(operation))); +} + +void mlirSymbolTableDestroy(MlirSymbolTable symbolTable) { + delete unwrap(symbolTable); +} + +MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, + MlirStringRef name) { + return wrap(unwrap(symbolTable)->lookup(StringRef(name.data, name.length))); +} + +MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, + MlirOperation operation) { + return wrap(unwrap(symbolTable)->insert(unwrap(operation))); +} + +void mlirSymbolTableErase(MlirSymbolTable symbolTable, + MlirOperation operation) { + unwrap(symbolTable)->erase(unwrap(operation)); +} diff --git a/mlir/lib/CAPI/IR/Support.cpp b/mlir/lib/CAPI/IR/Support.cpp index e4b409906..b6e1f9180 100644 --- a/mlir/lib/CAPI/IR/Support.cpp +++ b/mlir/lib/CAPI/IR/Support.cpp @@ -7,9 +7,15 @@ //===----------------------------------------------------------------------===// #include "mlir-c/Support.h" +#include "llvm/ADT/StringRef.h" #include MlirStringRef mlirStringRefCreateFromCString(const char *str) { return mlirStringRefCreate(str, strlen(str)); } + +bool mlirStringRefEqual(MlirStringRef string, MlirStringRef other) { + return llvm::StringRef(string.data, string.length) == + llvm::StringRef(other.data, other.length); +} From f75abdc7b84c40f7f0d479ad10594327073c524b Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Tue, 2 Nov 2021 15:15:03 +0100 Subject: [PATCH 150/915] [mlir] drop spurious semicolon --- mlir/include/mlir/CAPI/IR.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h index a864175d0..8366b0bce 100644 --- 
a/mlir/include/mlir/CAPI/IR.h +++ b/mlir/include/mlir/CAPI/IR.h @@ -27,7 +27,7 @@ DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation) DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block) DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags) DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region) -DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable); +DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable) DEFINE_C_API_METHODS(MlirAttribute, mlir::Attribute) DEFINE_C_API_METHODS(MlirIdentifier, mlir::Identifier) From 0c07517374b3e99d274835a09079d3162390353f Mon Sep 17 00:00:00 2001 From: Kirill Stoimenov Date: Tue, 2 Nov 2021 21:38:47 +0000 Subject: [PATCH 151/915] [mlir] Fixed a typo. Reviewed By: kda Differential Revision: https://reviews.llvm.org/D113053 --- mlir/include/mlir-c/Interfaces.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/Interfaces.h b/mlir/include/mlir-c/Interfaces.h index f03dd6ea5..878628342 100644 --- a/mlir/include/mlir-c/Interfaces.h +++ b/mlir/include/mlir-c/Interfaces.h @@ -45,7 +45,7 @@ mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID(); /// These callbacks are used to return multiple types from functions while -/// transferring ownerhsip to the caller. The first argument is the number of +/// transferring ownership to the caller. The first argument is the number of /// consecutive elements pointed to by the second argument. The third argument /// is an opaque pointer forwarded to the callback by the caller. 
typedef void (*MlirTypesCallback)(intptr_t, MlirType *, void *); From 01f936e9e3f2e72d88ae84455d39ea347698eb11 Mon Sep 17 00:00:00 2001 From: rkayaith Date: Tue, 2 Nov 2021 17:04:42 +0100 Subject: [PATCH 152/915] [mlir][python] Make Operation and Value hashable This allows operations and values to be used as dict keys Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D112669 --- mlir/lib/Bindings/Python/IRCore.cpp | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 8f451cf34..d465c1382 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2171,6 +2171,10 @@ void mlir::python::populateIRCore(py::module &m) { }) .def("__eq__", [](PyOperationBase &self, py::object other) { return false; }) + .def("__hash__", + [](PyOperationBase &self) { + return static_cast(llvm::hash_value(&self.getOperation())); + }) .def_property_readonly("attributes", [](PyOperationBase &self) { return PyOpAttributeMap( @@ -2558,7 +2562,10 @@ void mlir::python::populateIRCore(py::module &m) { .def("__eq__", [](PyAttribute &self, PyAttribute &other) { return self == other; }) .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) - .def("__hash__", [](PyAttribute &self) { return (size_t)self.get().ptr; }) + .def("__hash__", + [](PyAttribute &self) { + return static_cast(llvm::hash_value(self.get().ptr)); + }) .def( "dump", [](PyAttribute &self) { mlirAttributeDump(self); }, kDumpDocstring) @@ -2652,7 +2659,10 @@ void mlir::python::populateIRCore(py::module &m) { "Context that owns the Type") .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) .def("__eq__", [](PyType &self, py::object &other) { return false; }) - .def("__hash__", [](PyType &self) { return (size_t)self.get().ptr; }) + .def("__hash__", + [](PyType &self) { + return static_cast(llvm::hash_value(self.get().ptr)); + }) .def( 
"dump", [](PyType &self) { mlirTypeDump(self); }, kDumpDocstring) .def( @@ -2703,6 +2713,10 @@ void mlir::python::populateIRCore(py::module &m) { return self.get().ptr == other.get().ptr; }) .def("__eq__", [](PyValue &self, py::object other) { return false; }) + .def("__hash__", + [](PyValue &self) { + return static_cast(llvm::hash_value(self.get().ptr)); + }) .def( "__str__", [](PyValue &self) { From 7903981d8e2df0a9b6d0a28b30c72d9986aa552c Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Tue, 2 Nov 2021 14:15:25 +0100 Subject: [PATCH 153/915] [mlir][python] improve usability of Python affine construct bindings - Provide the operator overloads for constructing (semi-)affine expressions in Python by combining existing expressions with constants. - Make AffineExpr, AffineMap and IntegerSet hashable in Python. - Expose the AffineExpr composition functionality. Reviewed By: gysit, aoyal Differential Revision: https://reviews.llvm.org/D113010 --- mlir/include/mlir-c/AffineExpr.h | 6 ++ mlir/lib/Bindings/Python/IRAffine.cpp | 142 ++++++++++++++++++++++++-- mlir/lib/CAPI/IR/AffineExpr.cpp | 5 + 3 files changed, 142 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir-c/AffineExpr.h b/mlir/include/mlir-c/AffineExpr.h index 5516f2908..14e951dde 100644 --- a/mlir/include/mlir-c/AffineExpr.h +++ b/mlir/include/mlir-c/AffineExpr.h @@ -39,6 +39,8 @@ DEFINE_C_API_STRUCT(MlirAffineExpr, const void); #undef DEFINE_C_API_STRUCT +struct MlirAffineMap; + /// Gets the context that owns the affine expression. MLIR_CAPI_EXPORTED MlirContext mlirAffineExprGetContext(MlirAffineExpr affineExpr); @@ -86,6 +88,10 @@ MLIR_CAPI_EXPORTED bool mlirAffineExprIsMultipleOf(MlirAffineExpr affineExpr, MLIR_CAPI_EXPORTED bool mlirAffineExprIsFunctionOfDim(MlirAffineExpr affineExpr, intptr_t position); +/// Composes the given map with the given expression. 
+MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineExprCompose( + MlirAffineExpr affineExpr, struct MlirAffineMap affineMap); + //===----------------------------------------------------------------------===// // Affine Dimension Expression. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index 50a96c8c8..da80cda9c 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -205,6 +205,18 @@ class PyAffineAddExpr return PyAffineAddExpr(lhs.getContext(), expr); } + static PyAffineAddExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { + MlirAffineExpr expr = mlirAffineAddExprGet( + lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); + return PyAffineAddExpr(lhs.getContext(), expr); + } + + static PyAffineAddExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineAddExprGet( + mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); + return PyAffineAddExpr(rhs.getContext(), expr); + } + static void bindDerived(ClassTy &c) { c.def_static("get", &PyAffineAddExpr::get); } @@ -222,6 +234,18 @@ class PyAffineMulExpr return PyAffineMulExpr(lhs.getContext(), expr); } + static PyAffineMulExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { + MlirAffineExpr expr = mlirAffineMulExprGet( + lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); + return PyAffineMulExpr(lhs.getContext(), expr); + } + + static PyAffineMulExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineMulExprGet( + mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); + return PyAffineMulExpr(rhs.getContext(), expr); + } + static void bindDerived(ClassTy &c) { c.def_static("get", &PyAffineMulExpr::get); } @@ -239,6 +263,18 @@ class PyAffineModExpr return PyAffineModExpr(lhs.getContext(), expr); } + static PyAffineModExpr 
getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { + MlirAffineExpr expr = mlirAffineModExprGet( + lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); + return PyAffineModExpr(lhs.getContext(), expr); + } + + static PyAffineModExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineModExprGet( + mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); + return PyAffineModExpr(rhs.getContext(), expr); + } + static void bindDerived(ClassTy &c) { c.def_static("get", &PyAffineModExpr::get); } @@ -256,6 +292,18 @@ class PyAffineFloorDivExpr return PyAffineFloorDivExpr(lhs.getContext(), expr); } + static PyAffineFloorDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { + MlirAffineExpr expr = mlirAffineFloorDivExprGet( + lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); + return PyAffineFloorDivExpr(lhs.getContext(), expr); + } + + static PyAffineFloorDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineFloorDivExprGet( + mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); + return PyAffineFloorDivExpr(rhs.getContext(), expr); + } + static void bindDerived(ClassTy &c) { c.def_static("get", &PyAffineFloorDivExpr::get); } @@ -273,6 +321,18 @@ class PyAffineCeilDivExpr return PyAffineCeilDivExpr(lhs.getContext(), expr); } + static PyAffineCeilDivExpr getRHSConstant(PyAffineExpr lhs, intptr_t rhs) { + MlirAffineExpr expr = mlirAffineCeilDivExprGet( + lhs, mlirAffineConstantExprGet(mlirAffineExprGetContext(lhs), rhs)); + return PyAffineCeilDivExpr(lhs.getContext(), expr); + } + + static PyAffineCeilDivExpr getLHSConstant(intptr_t lhs, PyAffineExpr rhs) { + MlirAffineExpr expr = mlirAffineCeilDivExprGet( + mlirAffineConstantExprGet(mlirAffineExprGetContext(rhs), lhs), rhs); + return PyAffineCeilDivExpr(rhs.getContext(), expr); + } + static void bindDerived(ClassTy &c) { c.def_static("get", &PyAffineCeilDivExpr::get); } @@ -435,17 +495,19 @@ 
void mlir::python::populateIRAffine(py::module &m) { .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineExpr::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) - .def("__add__", - [](PyAffineExpr &self, PyAffineExpr &other) { - return PyAffineAddExpr::get(self, other); - }) - .def("__mul__", - [](PyAffineExpr &self, PyAffineExpr &other) { - return PyAffineMulExpr::get(self, other); - }) - .def("__mod__", - [](PyAffineExpr &self, PyAffineExpr &other) { - return PyAffineModExpr::get(self, other); + .def("__add__", &PyAffineAddExpr::get) + .def("__add__", &PyAffineAddExpr::getRHSConstant) + .def("__radd__", &PyAffineAddExpr::getRHSConstant) + .def("__mul__", &PyAffineMulExpr::get) + .def("__mul__", &PyAffineMulExpr::getRHSConstant) + .def("__rmul__", &PyAffineMulExpr::getRHSConstant) + .def("__mod__", &PyAffineModExpr::get) + .def("__mod__", &PyAffineModExpr::getRHSConstant) + .def("__rmod__", + [](PyAffineExpr &self, intptr_t other) { + return PyAffineModExpr::get( + PyAffineConstantExpr::get(other, *self.getContext().get()), + self); }) .def("__sub__", [](PyAffineExpr &self, PyAffineExpr &other) { @@ -454,6 +516,17 @@ void mlir::python::populateIRAffine(py::module &m) { return PyAffineAddExpr::get(self, PyAffineMulExpr::get(negOne, other)); }) + .def("__sub__", + [](PyAffineExpr &self, intptr_t other) { + return PyAffineAddExpr::get( + self, + PyAffineConstantExpr::get(-other, *self.getContext().get())); + }) + .def("__rsub__", + [](PyAffineExpr &self, intptr_t other) { + return PyAffineAddExpr::getLHSConstant( + other, PyAffineMulExpr::getLHSConstant(-1, self)); + }) .def("__eq__", [](PyAffineExpr &self, PyAffineExpr &other) { return self == other; }) .def("__eq__", @@ -474,24 +547,63 @@ void mlir::python::populateIRAffine(py::module &m) { printAccum.parts.append(")"); return printAccum.join(); }) + .def("__hash__", + [](PyAffineExpr &self) { + return static_cast(llvm::hash_value(self.get().ptr)); + }) .def_property_readonly( 
"context", [](PyAffineExpr &self) { return self.getContext().getObject(); }) + .def("compose", + [](PyAffineExpr &self, PyAffineMap &other) { + return PyAffineExpr(self.getContext(), + mlirAffineExprCompose(self, other)); + }) .def_static( "get_add", &PyAffineAddExpr::get, "Gets an affine expression containing a sum of two expressions.") + .def_static("get_add", &PyAffineAddExpr::getLHSConstant, + "Gets an affine expression containing a sum of a constant " + "and another expression.") + .def_static("get_add", &PyAffineAddExpr::getRHSConstant, + "Gets an affine expression containing a sum of an expression " + "and a constant.") .def_static( "get_mul", &PyAffineMulExpr::get, "Gets an affine expression containing a product of two expressions.") + .def_static("get_mul", &PyAffineMulExpr::getLHSConstant, + "Gets an affine expression containing a product of a " + "constant and another expression.") + .def_static("get_mul", &PyAffineMulExpr::getRHSConstant, + "Gets an affine expression containing a product of an " + "expression and a constant.") .def_static("get_mod", &PyAffineModExpr::get, "Gets an affine expression containing the modulo of dividing " "one expression by another.") + .def_static("get_mod", &PyAffineModExpr::getLHSConstant, + "Gets a semi-affine expression containing the modulo of " + "dividing a constant by an expression.") + .def_static("get_mod", &PyAffineModExpr::getRHSConstant, + "Gets an affine expression containing the module of dividing" + "an expression by a constant.") .def_static("get_floor_div", &PyAffineFloorDivExpr::get, "Gets an affine expression containing the rounded-down " "result of dividing one expression by another.") + .def_static("get_floor_div", &PyAffineFloorDivExpr::getLHSConstant, + "Gets a semi-affine expression containing the rounded-down " + "result of dividing a constant by an expression.") + .def_static("get_floor_div", &PyAffineFloorDivExpr::getRHSConstant, + "Gets an affine expression containing the rounded-down " + 
"result of dividing an expression by a constant.") .def_static("get_ceil_div", &PyAffineCeilDivExpr::get, "Gets an affine expression containing the rounded-up result " "of dividing one expression by another.") + .def_static("get_ceil_div", &PyAffineCeilDivExpr::getLHSConstant, + "Gets a semi-affine expression containing the rounded-up " + "result of dividing a constant by an expression.") + .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant, + "Gets an affine expression containing the rounded-up result " + "of dividing an expression by a constant.") .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"), py::arg("context") = py::none(), "Gets a constant affine expression with the given value.") @@ -542,6 +654,10 @@ void mlir::python::populateIRAffine(py::module &m) { printAccum.parts.append(")"); return printAccum.join(); }) + .def("__hash__", + [](PyAffineMap &self) { + return static_cast(llvm::hash_value(self.get().ptr)); + }) .def_static("compress_unused_symbols", [](py::list affineMaps, DefaultingPyMlirContext context) { SmallVector maps; @@ -714,6 +830,10 @@ void mlir::python::populateIRAffine(py::module &m) { printAccum.parts.append(")"); return printAccum.join(); }) + .def("__hash__", + [](PyIntegerSet &self) { + return static_cast(llvm::hash_value(self.get().ptr)); + }) .def_property_readonly( "context", [](PyIntegerSet &self) { return self.getContext().getObject(); }) diff --git a/mlir/lib/CAPI/IR/AffineExpr.cpp b/mlir/lib/CAPI/IR/AffineExpr.cpp index 2d8bc3ce5..5b25ab533 100644 --- a/mlir/lib/CAPI/IR/AffineExpr.cpp +++ b/mlir/lib/CAPI/IR/AffineExpr.cpp @@ -56,6 +56,11 @@ bool mlirAffineExprIsFunctionOfDim(MlirAffineExpr affineExpr, return unwrap(affineExpr).isFunctionOfDim(position); } +MlirAffineExpr mlirAffineExprCompose(MlirAffineExpr affineExpr, + MlirAffineMap affineMap) { + return wrap(unwrap(affineExpr).compose(unwrap(affineMap))); +} + //===----------------------------------------------------------------------===// 
// Affine Dimension Expression. //===----------------------------------------------------------------------===// From 810ccdff6e24509fe56a90cb33da5a8dc12704d3 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Tue, 2 Nov 2021 16:44:37 +0100 Subject: [PATCH 154/915] [mlir][python] expose the shape property of shaped types This has been missing in the original definition of shaped types. Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D113025 --- mlir/lib/Bindings/Python/IRTypes.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 1cfd799bf..89fdb1f06 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -284,6 +284,19 @@ class PyShapedType : public PyConcreteType { }, "Returns whether the given value is used as a placeholder for dynamic " "strides and offsets in shaped types."); + c.def_property_readonly( + "shape", + [](PyShapedType &self) { + self.requireHasRank(); + + std::vector shape; + int64_t rank = mlirShapedTypeGetRank(self); + shape.reserve(rank); + for (int64_t i = 0; i < rank; ++i) + shape.push_back(mlirShapedTypeGetDimSize(self, i)); + return shape; + }, + "Returns the shape of the ranked shaped type as a list of integers."); } private: From 5ae5c64cf5de299e54733b9381b53c12f2706cfd Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 5 Nov 2021 12:05:02 +0100 Subject: [PATCH 155/915] [mlir][python] fix constructor generation for optional operands in presence of segment attribute The ODS-based Python op bindings generator has been generating incorrect specification of the operand segment in presence if both optional and variadic operand groups: optional groups were treated as variadic whereas they require separate treatement. Make sure it is the case. Also harden the tests around generated op constructors as they could hitherto accept the code for both optional and variadic arguments. 
Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D113259 --- mlir/lib/Bindings/Python/IRCore.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index d465c1382..cf59a67f9 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1153,7 +1153,7 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList, throw py::value_error((llvm::Twine("Operation \"") + name + "\" requires " + llvm::Twine(resultSegmentSpec.size()) + - "result segments but was provided " + + " result segments but was provided " + llvm::Twine(resultTypeList.size())) .str()); } @@ -1164,7 +1164,7 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList, if (segmentSpec == 1 || segmentSpec == 0) { // Unpack unary element. try { - auto resultType = py::cast(std::get<0>(it.value())); + auto *resultType = py::cast(std::get<0>(it.value())); if (resultType) { resultTypes.push_back(resultType); resultSegmentLengths.push_back(1); From 53a0b7e99e3769e82b644b5fe6090719fb4bbc8b Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 9 Nov 2021 00:05:55 +0000 Subject: [PATCH 156/915] [mlir] Refactor ElementsAttr's value access API There are several aspects of the API that either aren't easy to use, or are deceptively easy to do the wrong thing. The main change of this commit is to remove all of the `getValue`/`getFlatValue` from ElementsAttr and instead provide operator[] methods on the ranges returned by `getValues`. This provides a much more convenient API for the value ranges. It also removes the easy-to-be-inefficient nature of getValue/getFlatValue, which under the hood would construct a new range for the type `T`. Constructing a range is not necessarily cheap in all cases, and could lead to very poor performance if used within a loop; i.e. 
if you were to naively write something like: ``` DenseElementsAttr attr = ...; for (int i = 0; i < size; ++i) { // We are internally rebuilding the APFloat value range on each iteration!! APFloat it = attr.getFlatValue(i); } ``` Differential Revision: https://reviews.llvm.org/D113229 --- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 28 ++++++++++++++------------ 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 3b15212e3..8d6c4ccf6 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -288,8 +288,9 @@ bool mlirAttributeIsAElements(MlirAttribute attr) { MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank, uint64_t *idxs) { - return wrap(unwrap(attr).cast().getValue( - llvm::makeArrayRef(idxs, rank))); + return wrap(unwrap(attr) + .cast() + .getValues()[llvm::makeArrayRef(idxs, rank)]); } bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, @@ -482,7 +483,8 @@ bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) { } MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getSplatValue()); + return wrap( + unwrap(attr).cast().getSplatValue()); } int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) { return unwrap(attr).cast().getSplatValue(); @@ -520,36 +522,36 @@ MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) { // Indexed accessors. 
bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } double mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getFlatValue(pos); + return unwrap(attr).cast().getValues()[pos]; } MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos) { return wrap( - unwrap(attr).cast().getFlatValue(pos)); + unwrap(attr).cast().getValues()[pos]); } //===----------------------------------------------------------------------===// From 541a78fb6bdcc764d9a4754c768787306c5edf4f Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 9 Nov 2021 17:52:56 -0800 Subject: [PATCH 157/915] 
[mlir-c] Add Region iterators matching Block & Operation ones Enables using the same iterator interface to these even though underlying storage is different. Differential Revision: https://reviews.llvm.org/D113512 --- mlir/include/mlir-c/IR.h | 7 +++++++ mlir/lib/CAPI/IR/IR.cpp | 16 ++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 161019125..6c1a92cea 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -517,6 +517,13 @@ MLIR_CAPI_EXPORTED void mlirRegionInsertOwnedBlockBefore(MlirRegion region, MlirBlock reference, MlirBlock block); +/// Returns first region attached to the operation. +MLIR_CAPI_EXPORTED MlirRegion mlirOperationGetFirstRegion(MlirOperation op); + +/// Returns the region immediately following the given region in its parent +/// operation. +MLIR_CAPI_EXPORTED MlirRegion mlirRegionGetNextInOperation(MlirRegion region); + //===----------------------------------------------------------------------===// // Block API. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 13490b342..8bed10a9d 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -379,6 +379,22 @@ MlirRegion mlirOperationGetRegion(MlirOperation op, intptr_t pos) { return wrap(&unwrap(op)->getRegion(static_cast(pos))); } +MlirRegion mlirOperationGetFirstRegion(MlirOperation op) { + Operation *cppOp = unwrap(op); + if (cppOp->getNumRegions() == 0) + return wrap(static_cast(nullptr)); + return wrap(&cppOp->getRegion(0)); +} + +MlirRegion mlirRegionGetNextInOperation(MlirRegion region) { + Region *cppRegion = unwrap(region); + Operation *parent = cppRegion->getParentOp(); + intptr_t next = cppRegion->getRegionNumber() + 1; + if (parent->getNumRegions() > next) + return wrap(&parent->getRegion(next)); + return wrap(static_cast(nullptr)); +} + MlirOperation mlirOperationGetNextInBlock(MlirOperation op) { return wrap(unwrap(op)->getNextNode()); } From 86ae8f11568bd7fe946715be22c7f9e70f16ec4b Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 11 Nov 2021 01:44:58 +0000 Subject: [PATCH 158/915] [mlir] Replace usages of Identifier with StringAttr Identifier and StringAttr essentially serve the same purpose, i.e. to hold a string value. Keeping these seemingly identical pieces of functionality separate has caused problems in certain situations: * Identifier has nice accessors that StringAttr doesn't * Identifier can't be used as an Attribute, meaning strings are often duplicated between Identifier/StringAttr (e.g. in PDL) The only thing that Identifier has that StringAttr doesn't is support for caching a dialect that is referenced by the string (e.g. dialect.foo). This functionality is added to StringAttr, as this is useful for StringAttr in generally the same ways it was useful for Identifier. 
Differential Revision: https://reviews.llvm.org/D113536 --- mlir/include/mlir/CAPI/IR.h | 2 +- mlir/lib/Bindings/Python/IRCore.cpp | 6 ++++-- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 4 ++-- mlir/lib/CAPI/IR/IR.cpp | 2 +- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h index 8366b0bce..7fd47504a 100644 --- a/mlir/include/mlir/CAPI/IR.h +++ b/mlir/include/mlir/CAPI/IR.h @@ -30,7 +30,7 @@ DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region) DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable) DEFINE_C_API_METHODS(MlirAttribute, mlir::Attribute) -DEFINE_C_API_METHODS(MlirIdentifier, mlir::Identifier) +DEFINE_C_API_METHODS(MlirIdentifier, mlir::StringAttr) DEFINE_C_API_METHODS(MlirLocation, mlir::Location) DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp) DEFINE_C_API_METHODS(MlirType, mlir::Type) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index cf59a67f9..4c25fd450 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1845,7 +1845,8 @@ class PyOpAttributeMap { mlirOperationGetAttribute(operation->get(), index); return PyNamedAttribute( namedAttr.attribute, - std::string(mlirIdentifierStr(namedAttr.name).data)); + std::string(mlirIdentifierStr(namedAttr.name).data, + mlirIdentifierStr(namedAttr.name).length)); } void dunderSetItem(const std::string &name, PyAttribute attr) { @@ -2601,7 +2602,8 @@ void mlir::python::populateIRCore(py::module &m) { PyPrintAccumulator printAccum; printAccum.parts.append("NamedAttribute("); printAccum.parts.append( - mlirIdentifierStr(self.namedAttr.name).data); + py::str(mlirIdentifierStr(self.namedAttr.name).data, + mlirIdentifierStr(self.namedAttr.name).length)); printAccum.parts.append("="); mlirAttributePrint(self.namedAttr.attribute, printAccum.getCallback(), diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 
8d6c4ccf6..7ce428360 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -186,11 +186,11 @@ bool mlirAttributeIsAString(MlirAttribute attr) { } MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) { - return wrap(StringAttr::get(unwrap(ctx), unwrap(str))); + return wrap((Attribute)StringAttr::get(unwrap(ctx), unwrap(str))); } MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) { - return wrap(StringAttr::get(unwrap(str), unwrap(type))); + return wrap((Attribute)StringAttr::get(unwrap(str), unwrap(type))); } MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) { diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 8bed10a9d..7bbbc4a1d 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -805,7 +805,7 @@ MlirOperation mlirSymbolTableLookup(MlirSymbolTable symbolTable, MlirAttribute mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation) { - return wrap(unwrap(symbolTable)->insert(unwrap(operation))); + return wrap((Attribute)unwrap(symbolTable)->insert(unwrap(operation))); } void mlirSymbolTableErase(MlirSymbolTable symbolTable, From 82144f44d485afb57c804e1fd2ffbbc4ead2010f Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 11 Nov 2021 17:33:24 +0000 Subject: [PATCH 159/915] [mlir][Linalg] Add 1-d depthwise conv with opdsl Differential Revision: https://reviews.llvm.org/D113686 --- .../linalg/opdsl/ops/core_named_ops.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 9f5b27ea0..2e6b04118 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -309,6 +309,25 @@ def conv_3d_ndhwc_dhwcf( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c ]) * 
cast(U, K[D.kd, D.kh, D.kw, D.c, D.f]) +@linalg_structured_op +def depthwise_conv1D_nw( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KW, S.IC), + O=TensorDef(U, S.N, S.OW, S.IC, output=True), + strides=AttributeDef(S.SW), + dilations=AttributeDef(S.DW)): + """Performs depth-wise 1-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most dpethwise convolutions. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.ic, D.kw) + O[D.n, D.ow, D.ic] += \ + cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \ + cast(U, K[D.kw, D.ic]) + @linalg_structured_op def depthwise_conv2D_nhw( I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), From b35fc5600e3688a80122a2c077c8cccb097bbd06 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Thu, 11 Nov 2021 17:31:39 -0800 Subject: [PATCH 160/915] [mlir] Allow out-of-tree python building from installed MLIR. * Depends on D111504, which provides the boilerplate for building aggregate shared libraries from installed MLIR. * Adds a full-fledged Python example dialect and tests to the Standalone example (need to do a bit of tweaking in the top level CMake and lit tests to adapt better to if not building with Python enabled). * Rips out remnants of custom extension building in favor of `pybind11_add_module` which does the right thing. * Makes python and extension sources installable (outputs to src/python/${name} in the install tree): Both Python and C++ extension sources get installed as downstreams need all of this in order to build a derived version of the API. * Exports sources targets (with our properties that make everything work) by converting them to INTERFACE libraries (which have export support), as recommended for the forseeable future by CMake devs. 
Renames custom properties to start with lower-case letter, as also recommended/required (groan). * Adds a ROOT_DIR argument to `declare_mlir_python_extension` since now all C++ sources for an extension must be under the same directory (to line up at install time). * Need to validate against a downstream or two and adjust, prior to submitting. Downstreams will need to adapt by: * Remove absolute paths from any SOURCES for `declare_mlir_python_extension` (I believe all downstreams are just using `${CMAKE_CURRENT_SOURCE_DIR}` here, which can just be ommitted). May need to set `ROOT_DIR` if not relative to the current source directory. * To allow further downstreams to install/build, will need to make sure that all C++ extension headers are also listed under SOURCES for `declare_mlir_python_extension`. Reviewed By: stephenneuendorffer, mikeurbach Differential Revision: https://reviews.llvm.org/D111513 --- mlir/python/CMakeLists.txt | 54 ++++++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 19 deletions(-) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 6a49c773e..644530f0e 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -174,18 +174,26 @@ set(PYTHON_SOURCE_DIR "${MLIR_SOURCE_DIR}/lib/Bindings/Python") declare_mlir_python_extension(MLIRPythonExtension.Core MODULE_NAME _mlir ADD_TO_PARENT MLIRPythonSources.Core + ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - ${PYTHON_SOURCE_DIR}/DialectLinalg.cpp # TODO: Break this out. - ${PYTHON_SOURCE_DIR}/DialectSparseTensor.cpp # TODO: Break this out. - ${PYTHON_SOURCE_DIR}/MainModule.cpp - ${PYTHON_SOURCE_DIR}/IRAffine.cpp - ${PYTHON_SOURCE_DIR}/IRAttributes.cpp - ${PYTHON_SOURCE_DIR}/IRCore.cpp - ${PYTHON_SOURCE_DIR}/IRInterfaces.cpp - ${PYTHON_SOURCE_DIR}/IRModule.cpp - ${PYTHON_SOURCE_DIR}/IRTypes.cpp - ${PYTHON_SOURCE_DIR}/PybindUtils.cpp - ${PYTHON_SOURCE_DIR}/Pass.cpp + DialectLinalg.cpp # TODO: Break this out. 
+ DialectSparseTensor.cpp # TODO: Break this out. + MainModule.cpp + IRAffine.cpp + IRAttributes.cpp + IRCore.cpp + IRInterfaces.cpp + IRModule.cpp + IRTypes.cpp + PybindUtils.cpp + Pass.cpp + + # Headers must be included explicitly so they are installed. + Dialects.h + Globals.h + IRModule.h + Pass.h + PybindUtils.h PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -202,8 +210,9 @@ declare_mlir_python_extension(MLIRPythonExtension.Core declare_mlir_python_extension(MLIRPythonExtension.AllPassesRegistration MODULE_NAME _mlirAllPassesRegistration + ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - ${PYTHON_SOURCE_DIR}/AllPassesRegistration.cpp + AllPassesRegistration.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -214,8 +223,9 @@ declare_mlir_python_extension(MLIRPythonExtension.AllPassesRegistration declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses MODULE_NAME _mlirAsyncPasses ADD_TO_PARENT MLIRPythonSources.Dialects.async_dialect + ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - ${PYTHON_SOURCE_DIR}/AsyncPasses.cpp + AsyncPasses.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -225,8 +235,9 @@ declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses declare_mlir_python_extension(MLIRPythonExtension.Conversions MODULE_NAME _mlirConversions ADD_TO_PARENT MLIRPythonSources.Passes + ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - ${PYTHON_SOURCE_DIR}/Conversions/Conversions.cpp + Conversions/Conversions.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -236,8 +247,9 @@ declare_mlir_python_extension(MLIRPythonExtension.Conversions declare_mlir_python_extension(MLIRPythonExtension.ExecutionEngine MODULE_NAME _mlirExecutionEngine ADD_TO_PARENT MLIRPythonSources.ExecutionEngine + ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - ${PYTHON_SOURCE_DIR}/ExecutionEngineModule.cpp + ExecutionEngineModule.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -247,8 +259,9 @@ declare_mlir_python_extension(MLIRPythonExtension.ExecutionEngine 
declare_mlir_python_extension(MLIRPythonExtension.GPUDialectPasses MODULE_NAME _mlirGPUPasses ADD_TO_PARENT MLIRPythonSources.Dialects.gpu + ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - ${PYTHON_SOURCE_DIR}/GPUPasses.cpp + GPUPasses.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -258,8 +271,9 @@ declare_mlir_python_extension(MLIRPythonExtension.GPUDialectPasses declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses MODULE_NAME _mlirLinalgPasses ADD_TO_PARENT MLIRPythonSources.Dialects.linalg + ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - ${PYTHON_SOURCE_DIR}/LinalgPasses.cpp + LinalgPasses.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -269,8 +283,9 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses MODULE_NAME _mlirSparseTensorPasses ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor + ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - ${PYTHON_SOURCE_DIR}/SparseTensorPasses.cpp + SparseTensorPasses.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -280,8 +295,9 @@ declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses declare_mlir_python_extension(MLIRPythonExtension.Transforms MODULE_NAME _mlirTransforms ADD_TO_PARENT MLIRPythonSources.Passes + ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - ${PYTHON_SOURCE_DIR}/Transforms/Transforms.cpp + Transforms/Transforms.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS From c6daf51b95f570bc783e4537ca8d1b6da5d3e631 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Fri, 12 Nov 2021 02:30:53 +0000 Subject: [PATCH 161/915] Revert "[mlir] Allow out-of-tree python building from installed MLIR." This reverts commit b35fc5600e3688a80122a2c077c8cccb097bbd06. 
Build is broken (multiple buildbots) --- mlir/python/CMakeLists.txt | 54 ++++++++++++++------------------------ 1 file changed, 19 insertions(+), 35 deletions(-) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 644530f0e..6a49c773e 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -174,26 +174,18 @@ set(PYTHON_SOURCE_DIR "${MLIR_SOURCE_DIR}/lib/Bindings/Python") declare_mlir_python_extension(MLIRPythonExtension.Core MODULE_NAME _mlir ADD_TO_PARENT MLIRPythonSources.Core - ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - DialectLinalg.cpp # TODO: Break this out. - DialectSparseTensor.cpp # TODO: Break this out. - MainModule.cpp - IRAffine.cpp - IRAttributes.cpp - IRCore.cpp - IRInterfaces.cpp - IRModule.cpp - IRTypes.cpp - PybindUtils.cpp - Pass.cpp - - # Headers must be included explicitly so they are installed. - Dialects.h - Globals.h - IRModule.h - Pass.h - PybindUtils.h + ${PYTHON_SOURCE_DIR}/DialectLinalg.cpp # TODO: Break this out. + ${PYTHON_SOURCE_DIR}/DialectSparseTensor.cpp # TODO: Break this out. 
+ ${PYTHON_SOURCE_DIR}/MainModule.cpp + ${PYTHON_SOURCE_DIR}/IRAffine.cpp + ${PYTHON_SOURCE_DIR}/IRAttributes.cpp + ${PYTHON_SOURCE_DIR}/IRCore.cpp + ${PYTHON_SOURCE_DIR}/IRInterfaces.cpp + ${PYTHON_SOURCE_DIR}/IRModule.cpp + ${PYTHON_SOURCE_DIR}/IRTypes.cpp + ${PYTHON_SOURCE_DIR}/PybindUtils.cpp + ${PYTHON_SOURCE_DIR}/Pass.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -210,9 +202,8 @@ declare_mlir_python_extension(MLIRPythonExtension.Core declare_mlir_python_extension(MLIRPythonExtension.AllPassesRegistration MODULE_NAME _mlirAllPassesRegistration - ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - AllPassesRegistration.cpp + ${PYTHON_SOURCE_DIR}/AllPassesRegistration.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -223,9 +214,8 @@ declare_mlir_python_extension(MLIRPythonExtension.AllPassesRegistration declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses MODULE_NAME _mlirAsyncPasses ADD_TO_PARENT MLIRPythonSources.Dialects.async_dialect - ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - AsyncPasses.cpp + ${PYTHON_SOURCE_DIR}/AsyncPasses.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -235,9 +225,8 @@ declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses declare_mlir_python_extension(MLIRPythonExtension.Conversions MODULE_NAME _mlirConversions ADD_TO_PARENT MLIRPythonSources.Passes - ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - Conversions/Conversions.cpp + ${PYTHON_SOURCE_DIR}/Conversions/Conversions.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -247,9 +236,8 @@ declare_mlir_python_extension(MLIRPythonExtension.Conversions declare_mlir_python_extension(MLIRPythonExtension.ExecutionEngine MODULE_NAME _mlirExecutionEngine ADD_TO_PARENT MLIRPythonSources.ExecutionEngine - ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - ExecutionEngineModule.cpp + ${PYTHON_SOURCE_DIR}/ExecutionEngineModule.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -259,9 +247,8 @@ 
declare_mlir_python_extension(MLIRPythonExtension.ExecutionEngine declare_mlir_python_extension(MLIRPythonExtension.GPUDialectPasses MODULE_NAME _mlirGPUPasses ADD_TO_PARENT MLIRPythonSources.Dialects.gpu - ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - GPUPasses.cpp + ${PYTHON_SOURCE_DIR}/GPUPasses.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -271,9 +258,8 @@ declare_mlir_python_extension(MLIRPythonExtension.GPUDialectPasses declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses MODULE_NAME _mlirLinalgPasses ADD_TO_PARENT MLIRPythonSources.Dialects.linalg - ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - LinalgPasses.cpp + ${PYTHON_SOURCE_DIR}/LinalgPasses.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -283,9 +269,8 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses MODULE_NAME _mlirSparseTensorPasses ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor - ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - SparseTensorPasses.cpp + ${PYTHON_SOURCE_DIR}/SparseTensorPasses.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -295,9 +280,8 @@ declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses declare_mlir_python_extension(MLIRPythonExtension.Transforms MODULE_NAME _mlirTransforms ADD_TO_PARENT MLIRPythonSources.Passes - ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - Transforms/Transforms.cpp + ${PYTHON_SOURCE_DIR}/Transforms/Transforms.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS From d83ce3537a70534185924048fdec3c17cbffad55 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Thu, 11 Nov 2021 21:18:16 -0800 Subject: [PATCH 162/915] [mlir] Add MLIR-C dylib. Per discussion on discord and various feature requests across bindings (Haskell and Rust bindings authors have asked me directly), we should be building a link-ready MLIR-C dylib which exports the C API and can be used without linking to anything else. 
This patch: * Adds a new MLIR-C aggregate shared library (libMLIR-C.so), which is similar in name and function to libLLVM-C.so. * It is guarded by the new CMake option MLIR_BUILD_MLIR_C_DYLIB, which has a similar purpose/name to the LLVM_BUILD_LLVM_C_DYLIB option. * On all platforms, this will work with both static, BUILD_SHARED_LIBS, and libMLIR builds, if supported: * In static builds: libMLIR-C.so will export the CAPI symbols and statically link all dependencies into itself. * In BUILD_SHARED_LIBS: libMLIR-C.so will export the CAPI symbols and have dynamic dependencies on implementation shared libraries. * In libMLIR.so mode: same as static. libMLIR.so was not finished for actual linking use within the project. An eventual relayering so that libMLIR-C.so depends on libMLIR.so is possible but requires first re-engineering the latter to use the aggregate facility. * On Linux, exported symbols are filtered to only the CAPI. On others (MacOS, Windows), all symbols are exported. A CMake status is printed unless if global visibility is hidden indicating that this has not yet been implemented. The library should still work, but it will be larger and more likely to conflict until fixed. Someone should look at lifting the corresponding support from libLLVM-C.so and adapting. Or, for special uses, just build with `-DCMAKE_CXX_VISIBILITY_PRESET=hidden -DCMAKE_C_VISIBILITY_PRESET=hidden`. * Includes fixes to execution engine symbol export macros to enable default visibility. Without this, the advice to use hidden visibility would have resulted in test failures and unusable execution engine support libraries. 
Differential Revision: https://reviews.llvm.org/D113731 --- mlir/lib/CAPI/CMakeLists.txt | 27 ++++++++++++++++++++ mlir/lib/CAPI/Conversion/CMakeLists.txt | 2 +- mlir/lib/CAPI/Debug/CMakeLists.txt | 2 +- mlir/lib/CAPI/Dialect/CMakeLists.txt | 18 ++++++------- mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt | 2 +- mlir/lib/CAPI/IR/CMakeLists.txt | 2 +- mlir/lib/CAPI/Interfaces/CMakeLists.txt | 2 +- mlir/lib/CAPI/Registration/CMakeLists.txt | 2 +- mlir/lib/CAPI/Transforms/CMakeLists.txt | 2 +- 9 files changed, 43 insertions(+), 16 deletions(-) diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt index 30ccbe94a..5545de691 100644 --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -1,3 +1,12 @@ +# For upstream, we accumulate all libraries into the MLIR_CAPI_LIBRARIES +# property via a custom wrapper function. This is then used to create an +# aggregate below. +set_property(GLOBAL APPEND PROPERTY MLIR_CAPI_LIBRARIES) +function(add_mlir_upstream_c_api_library name) + add_mlir_public_c_api_library(${name} ${ARGN}) + set_property(GLOBAL APPEND PROPERTY MLIR_CAPI_LIBRARIES ${name}) +endfunction() + add_subdirectory(Debug) add_subdirectory(Dialect) add_subdirectory(Conversion) @@ -6,3 +15,21 @@ add_subdirectory(Interfaces) add_subdirectory(IR) add_subdirectory(Registration) add_subdirectory(Transforms) + +# Build the optional CAPI dylib. +if(MLIR_BUILD_MLIR_C_DYLIB) + message(STATUS "Building MLIR-C dylib") + get_property(_capi_libraries GLOBAL PROPERTY MLIR_CAPI_LIBRARIES) + add_mlir_aggregate(MLIR-C + SHARED + EMBED_LIBS + ${_capi_libraries} + ) + if(CMAKE_SYSTEM_NAME STREQUAL "Linux") + target_link_options(MLIR-C PRIVATE "-Wl,-exclude-libs,ALL") + else() + if(NOT CMAKE_C_VISIBILITY_PRESET STREQUAL "hidden" OR NOT CMAKE_CXX_VISIBILITY_PRESET STREQUAL "hidden") + message(STATUS "MLIR-C on this platform exports all symbols. 
Recommend building with CMAKE_(C|CXX)_VISIBILITY_PRESET=hidden or implement filtering support.") + endif() + endif() +endif() diff --git a/mlir/lib/CAPI/Conversion/CMakeLists.txt b/mlir/lib/CAPI/Conversion/CMakeLists.txt index 83435cd19..166e79916 100644 --- a/mlir/lib/CAPI/Conversion/CMakeLists.txt +++ b/mlir/lib/CAPI/Conversion/CMakeLists.txt @@ -1,5 +1,5 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -add_mlir_public_c_api_library(MLIRCAPIConversion +add_mlir_upstream_c_api_library(MLIRCAPIConversion Passes.cpp LINK_LIBS PUBLIC diff --git a/mlir/lib/CAPI/Debug/CMakeLists.txt b/mlir/lib/CAPI/Debug/CMakeLists.txt index fdffe304d..7b32f3ae0 100644 --- a/mlir/lib/CAPI/Debug/CMakeLists.txt +++ b/mlir/lib/CAPI/Debug/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_public_c_api_library(MLIRCAPIDebug +add_mlir_upstream_c_api_library(MLIRCAPIDebug Debug.cpp LINK_LIBS PUBLIC diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index 801b0f77a..4f11bc52c 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_public_c_api_library(MLIRCAPIAsync +add_mlir_upstream_c_api_library(MLIRCAPIAsync Async.cpp AsyncPasses.cpp @@ -13,7 +13,7 @@ add_mlir_public_c_api_library(MLIRCAPIAsync MLIRPass ) -add_mlir_public_c_api_library(MLIRCAPIGPU +add_mlir_upstream_c_api_library(MLIRCAPIGPU GPU.cpp GPUPasses.cpp @@ -27,7 +27,7 @@ add_mlir_public_c_api_library(MLIRCAPIGPU MLIRPass ) -add_mlir_public_c_api_library(MLIRCAPILLVM +add_mlir_upstream_c_api_library(MLIRCAPILLVM LLVM.cpp PARTIAL_SOURCES_INTENDED @@ -36,7 +36,7 @@ add_mlir_public_c_api_library(MLIRCAPILLVM MLIRLLVMIR ) -add_mlir_public_c_api_library(MLIRCAPILinalg +add_mlir_upstream_c_api_library(MLIRCAPILinalg Linalg.cpp LinalgPasses.cpp @@ -51,7 +51,7 @@ add_mlir_public_c_api_library(MLIRCAPILinalg MLIRLinalgTransforms ) -add_mlir_public_c_api_library(MLIRCAPISCF +add_mlir_upstream_c_api_library(MLIRCAPISCF SCF.cpp 
PARTIAL_SOURCES_INTENDED @@ -60,7 +60,7 @@ add_mlir_public_c_api_library(MLIRCAPISCF MLIRSCF ) -add_mlir_public_c_api_library(MLIRCAPIShape +add_mlir_upstream_c_api_library(MLIRCAPIShape Shape.cpp PARTIAL_SOURCES_INTENDED @@ -69,7 +69,7 @@ add_mlir_public_c_api_library(MLIRCAPIShape MLIRShape ) -add_mlir_public_c_api_library(MLIRCAPISparseTensor +add_mlir_upstream_c_api_library(MLIRCAPISparseTensor SparseTensor.cpp SparseTensorPasses.cpp @@ -80,7 +80,7 @@ add_mlir_public_c_api_library(MLIRCAPISparseTensor MLIRSparseTensorTransforms ) -add_mlir_public_c_api_library(MLIRCAPIStandard +add_mlir_upstream_c_api_library(MLIRCAPIStandard Standard.cpp PARTIAL_SOURCES_INTENDED @@ -89,7 +89,7 @@ add_mlir_public_c_api_library(MLIRCAPIStandard MLIRStandard ) -add_mlir_public_c_api_library(MLIRCAPITensor +add_mlir_upstream_c_api_library(MLIRCAPITensor Tensor.cpp PARTIAL_SOURCES_INTENDED diff --git a/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt b/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt index 09dcb6143..0c28e7ef8 100644 --- a/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt +++ b/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt @@ -1,5 +1,5 @@ # Main API shared library. -add_mlir_public_c_api_library(MLIRCEXECUTIONENGINE +add_mlir_upstream_c_api_library(MLIRCEXECUTIONENGINE ExecutionEngine.cpp LINK_LIBS PUBLIC diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt index 486ba6e0f..320ed0718 100644 --- a/mlir/lib/CAPI/IR/CMakeLists.txt +++ b/mlir/lib/CAPI/IR/CMakeLists.txt @@ -1,5 +1,5 @@ # Main API shared library. 
-add_mlir_public_c_api_library(MLIRCAPIIR +add_mlir_upstream_c_api_library(MLIRCAPIIR AffineExpr.cpp AffineMap.cpp BuiltinAttributes.cpp diff --git a/mlir/lib/CAPI/Interfaces/CMakeLists.txt b/mlir/lib/CAPI/Interfaces/CMakeLists.txt index 1de5f21d8..7aefb56d9 100644 --- a/mlir/lib/CAPI/Interfaces/CMakeLists.txt +++ b/mlir/lib/CAPI/Interfaces/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_public_c_api_library(MLIRCAPIInterfaces +add_mlir_upstream_c_api_library(MLIRCAPIInterfaces Interfaces.cpp LINK_LIBS PUBLIC diff --git a/mlir/lib/CAPI/Registration/CMakeLists.txt b/mlir/lib/CAPI/Registration/CMakeLists.txt index b4a8650b3..67e26a50f 100644 --- a/mlir/lib/CAPI/Registration/CMakeLists.txt +++ b/mlir/lib/CAPI/Registration/CMakeLists.txt @@ -2,7 +2,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -add_mlir_public_c_api_library(MLIRCAPIRegistration +add_mlir_upstream_c_api_library(MLIRCAPIRegistration Registration.cpp LINK_LIBS PUBLIC diff --git a/mlir/lib/CAPI/Transforms/CMakeLists.txt b/mlir/lib/CAPI/Transforms/CMakeLists.txt index e5e1677ec..2638025a8 100644 --- a/mlir/lib/CAPI/Transforms/CMakeLists.txt +++ b/mlir/lib/CAPI/Transforms/CMakeLists.txt @@ -1,4 +1,4 @@ -add_mlir_public_c_api_library(MLIRCAPITransforms +add_mlir_upstream_c_api_library(MLIRCAPITransforms Passes.cpp LINK_LIBS PUBLIC From e53968cfae8e10577c7c5206cba579650b94d1f0 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sun, 14 Nov 2021 14:44:25 -0800 Subject: [PATCH 163/915] Re-apply "[mlir] Allow out-of-tree python building from installed MLIR." Re-applies D111513: * Adds a full-fledged Python example dialect and tests to the Standalone example (need to do a bit of tweaking in the top level CMake and lit tests to adapt better to if not building with Python enabled). 
* Rips out remnants of custom extension building in favor of pybind11_add_module which does the right thing. * Makes python and extension sources installable (outputs to src/python/${name} in the install tree): Both Python and C++ extension sources get installed as downstreams need all of this in order to build a derived version of the API. * Exports sources targets (with our properties that make everything work) by converting them to INTERFACE libraries (which have export support), as recommended for the forseeable future by CMake devs. Renames custom properties to start with lower-case letter, as also recommended/required (groan). * Adds a ROOT_DIR argument to declare_mlir_python_extension since now all C++ sources for an extension must be under the same directory (to line up at install time). * Downstreams will need to adapt by: * Remove absolute paths from any SOURCES for declare_mlir_python_extension (I believe all downstreams are just using ${CMAKE_CURRENT_SOURCE_DIR} here, which can just be ommitted). May need to set ROOT_DIR if not relative to the current source directory. * To allow further downstreams to install/build, will need to make sure that all C++ extension headers are also listed under SOURCES for declare_mlir_python_extension. This reverts commit c6daf51b95f570bc783e4537ca8d1b6da5d3e631. 
Reviewed By: stephenneuendorffer Differential Revision: https://reviews.llvm.org/D113732 --- mlir/python/CMakeLists.txt | 57 +++++++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 20 deletions(-) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 6a49c773e..c2b9d753a 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -174,18 +174,26 @@ set(PYTHON_SOURCE_DIR "${MLIR_SOURCE_DIR}/lib/Bindings/Python") declare_mlir_python_extension(MLIRPythonExtension.Core MODULE_NAME _mlir ADD_TO_PARENT MLIRPythonSources.Core + ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - ${PYTHON_SOURCE_DIR}/DialectLinalg.cpp # TODO: Break this out. - ${PYTHON_SOURCE_DIR}/DialectSparseTensor.cpp # TODO: Break this out. - ${PYTHON_SOURCE_DIR}/MainModule.cpp - ${PYTHON_SOURCE_DIR}/IRAffine.cpp - ${PYTHON_SOURCE_DIR}/IRAttributes.cpp - ${PYTHON_SOURCE_DIR}/IRCore.cpp - ${PYTHON_SOURCE_DIR}/IRInterfaces.cpp - ${PYTHON_SOURCE_DIR}/IRModule.cpp - ${PYTHON_SOURCE_DIR}/IRTypes.cpp - ${PYTHON_SOURCE_DIR}/PybindUtils.cpp - ${PYTHON_SOURCE_DIR}/Pass.cpp + DialectLinalg.cpp # TODO: Break this out. + DialectSparseTensor.cpp # TODO: Break this out. + MainModule.cpp + IRAffine.cpp + IRAttributes.cpp + IRCore.cpp + IRInterfaces.cpp + IRModule.cpp + IRTypes.cpp + PybindUtils.cpp + Pass.cpp + + # Headers must be included explicitly so they are installed. 
+ Dialects.h + Globals.h + IRModule.h + Pass.h + PybindUtils.h PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -202,8 +210,9 @@ declare_mlir_python_extension(MLIRPythonExtension.Core declare_mlir_python_extension(MLIRPythonExtension.AllPassesRegistration MODULE_NAME _mlirAllPassesRegistration + ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - ${PYTHON_SOURCE_DIR}/AllPassesRegistration.cpp + AllPassesRegistration.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -214,8 +223,9 @@ declare_mlir_python_extension(MLIRPythonExtension.AllPassesRegistration declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses MODULE_NAME _mlirAsyncPasses ADD_TO_PARENT MLIRPythonSources.Dialects.async_dialect + ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - ${PYTHON_SOURCE_DIR}/AsyncPasses.cpp + AsyncPasses.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -225,8 +235,9 @@ declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses declare_mlir_python_extension(MLIRPythonExtension.Conversions MODULE_NAME _mlirConversions ADD_TO_PARENT MLIRPythonSources.Passes + ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - ${PYTHON_SOURCE_DIR}/Conversions/Conversions.cpp + Conversions/Conversions.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -236,8 +247,9 @@ declare_mlir_python_extension(MLIRPythonExtension.Conversions declare_mlir_python_extension(MLIRPythonExtension.ExecutionEngine MODULE_NAME _mlirExecutionEngine ADD_TO_PARENT MLIRPythonSources.ExecutionEngine + ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - ${PYTHON_SOURCE_DIR}/ExecutionEngineModule.cpp + ExecutionEngineModule.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -247,8 +259,9 @@ declare_mlir_python_extension(MLIRPythonExtension.ExecutionEngine declare_mlir_python_extension(MLIRPythonExtension.GPUDialectPasses MODULE_NAME _mlirGPUPasses ADD_TO_PARENT MLIRPythonSources.Dialects.gpu + ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - ${PYTHON_SOURCE_DIR}/GPUPasses.cpp + GPUPasses.cpp PRIVATE_LINK_LIBS 
LLVMSupport EMBED_CAPI_LINK_LIBS @@ -258,8 +271,9 @@ declare_mlir_python_extension(MLIRPythonExtension.GPUDialectPasses declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses MODULE_NAME _mlirLinalgPasses ADD_TO_PARENT MLIRPythonSources.Dialects.linalg + ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - ${PYTHON_SOURCE_DIR}/LinalgPasses.cpp + LinalgPasses.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -269,8 +283,9 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses MODULE_NAME _mlirSparseTensorPasses ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor + ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - ${PYTHON_SOURCE_DIR}/SparseTensorPasses.cpp + SparseTensorPasses.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -280,8 +295,9 @@ declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses declare_mlir_python_extension(MLIRPythonExtension.Transforms MODULE_NAME _mlirTransforms ADD_TO_PARENT MLIRPythonSources.Passes + ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - ${PYTHON_SOURCE_DIR}/Transforms/Transforms.cpp + Transforms/Transforms.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS @@ -322,8 +338,9 @@ if(MLIR_INCLUDE_TESTS) declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtension MODULE_NAME _mlirPythonTest ADD_TO_PARENT MLIRPythonTestSources.Dialects + ROOT_DIR "${MLIR_SOURCE_DIR}/test/python/lib" SOURCES - ${MLIR_SOURCE_DIR}/test/python/lib/PythonTestModule.cpp + PythonTestModule.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS From ac4a032c63f8cdc3b8910216462da5399418e143 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Mon, 15 Nov 2021 07:54:07 +0000 Subject: [PATCH 164/915] [mlir][Linalg] Make depthwise convolution naming scheme consistent. Names should be consistent across all operations otherwise painful bugs will surface. 
Reviewed By: rsuderman Differential Revision: https://reviews.llvm.org/D113762 --- .../dialects/linalg/opdsl/ops/core_named_ops.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 2e6b04118..85bed25fe 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -310,7 +310,7 @@ def conv_3d_ndhwc_dhwcf( ]) * cast(U, K[D.kd, D.kh, D.kw, D.c, D.f]) @linalg_structured_op -def depthwise_conv1D_nw( +def depthwise_conv_1d_nwc_wc( I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC), K=TensorDef(T2, S.KW, S.IC), O=TensorDef(U, S.N, S.OW, S.IC, output=True), @@ -320,7 +320,7 @@ def depthwise_conv1D_nw( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. Multiplier is set to 1 - which is a special case for most dpethwise convolutions. + which is a special case for most depthwise convolutions. """ implements(ConvolutionOpInterface) domain(D.n, D.ow, D.ic, D.kw) @@ -329,7 +329,7 @@ def depthwise_conv1D_nw( cast(U, K[D.kw, D.ic]) @linalg_structured_op -def depthwise_conv2D_nhw( +def depthwise_conv_2d_nhwc_hwc( I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), K=TensorDef(T2, S.KH, S.KW, S.IC), O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True), @@ -339,7 +339,7 @@ def depthwise_conv2D_nhw( Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. Multiplier is set to 1 - which is a special case for most dpethwise convolutions. + which is a special case for most depthwise convolutions. 
""" implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) @@ -348,7 +348,7 @@ def depthwise_conv2D_nhw( D.ic]) * cast(U, K[D.kh, D.kw, D.ic]) @linalg_structured_op -def depthwise_conv2D_nhw_q( +def depthwise_conv_2d_nhwc_hwc_q( I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), K=TensorDef(T2, S.KH, S.KW, S.IC), IZp=ScalarDef(I32), @@ -369,7 +369,7 @@ def depthwise_conv2D_nhw_q( (cast(U, K[D.kh, D.kw, D.ic]) - cast(U, KZp))) @linalg_structured_op -def depthwise_conv2D_nhwc( +def depthwise_conv_2d_nhwc_hwcm( I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), @@ -387,7 +387,7 @@ def depthwise_conv2D_nhwc( D.ic]) * cast(U, K[D.kh, D.kw, D.ic, D.cm]) @linalg_structured_op -def depthwise_conv2D_nhwc_q( +def depthwise_conv_2d_nhwc_hwcm_q( I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), IZp=ScalarDef(I32), From c2af7d32b1b26257bbeae7741b49c6641b8b41ab Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Mon, 15 Nov 2021 12:52:37 +0100 Subject: [PATCH 165/915] [mlir] Move min/max ops from Std to Arith. 
Differential Revision: https://reviews.llvm.org/D113881 --- .../mlir/dialects/linalg/opdsl/lang/emitter.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 2ece9eb92..933c26ad9 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -334,30 +334,30 @@ def _eval_mul(self, lhs: Value, rhs: Value) -> Value: def _eval_max(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return std.MaxFOp(lhs, rhs).result + return arith.MaxFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.MaxSIOp(lhs, rhs).result + return arith.MaxSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max' operand: {lhs}") def _eval_max_unsigned(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return std.MaxFOp(lhs, rhs).result + return arith.MaxFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.MaxUIOp(lhs, rhs).result + return arith.MaxUIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max_unsigned' operand: {lhs}") def _eval_min(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return std.MinFOp(lhs, rhs).result + return arith.MinFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.MinSIOp(lhs, rhs).result + return arith.MinSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'min' operand: {lhs}") def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return std.MinFOp(lhs, rhs).result + return arith.MinFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return std.MinUIOp(lhs, rhs).result + return arith.MinUIOp(lhs, rhs).result raise 
NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}") From 1dd87758bd0df02dbfefa9d57f886a1e06d3ff21 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 16 Nov 2021 17:21:15 +0000 Subject: [PATCH 166/915] [mlir][NFC] Replace references to Identifier with StringAttr This is part of the replacement of Identifier with StringAttr. Differential Revision: https://reviews.llvm.org/D113953 --- mlir/include/mlir/CAPI/IR.h | 1 - mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 2 +- mlir/lib/CAPI/IR/IR.cpp | 6 +++--- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h index 7fd47504a..af7ae8977 100644 --- a/mlir/include/mlir/CAPI/IR.h +++ b/mlir/include/mlir/CAPI/IR.h @@ -17,7 +17,6 @@ #include "mlir/CAPI/Wrap.h" #include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Identifier.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 7ce428360..5ec7e3468 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -165,7 +165,7 @@ MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace, intptr_t dataLength, const char *data, MlirType type) { return wrap( - OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)), + OpaqueAttr::get(StringAttr::get(unwrap(ctx), unwrap(dialectNamespace)), StringRef(data, dataLength), unwrap(type))); } diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 7bbbc4a1d..a339aef90 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -145,9 +145,9 @@ MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, MlirLocation childLoc) { if (mlirLocationIsNull(childLoc)) return wrap( - Location(NameLoc::get(Identifier::get(unwrap(name), unwrap(context))))); + Location(NameLoc::get(StringAttr::get(unwrap(context), unwrap(name))))); return 
wrap(Location(NameLoc::get( - Identifier::get(unwrap(name), unwrap(context)), unwrap(childLoc)))); + StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc)))); } MlirLocation mlirLocationUnknownGet(MlirContext context) { @@ -753,7 +753,7 @@ MlirNamedAttribute mlirNamedAttributeGet(MlirIdentifier name, //===----------------------------------------------------------------------===// MlirIdentifier mlirIdentifierGet(MlirContext context, MlirStringRef str) { - return wrap(Identifier::get(unwrap(str), unwrap(context))); + return wrap(StringAttr::get(unwrap(context), unwrap(str))); } MlirContext mlirIdentifierGetContext(MlirIdentifier ident) { From 295b9b800911e67ea8c0d3c8c106037a1ccd4026 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Tue, 16 Nov 2021 17:09:08 +0100 Subject: [PATCH 167/915] [mlir] Fix wrong variable name in Linalg OpDSL The name seems to have been left over from a renaming effort on an unexercised codepaths that are difficult to catch in Python. Fix it and add a test that exercises the codepath. Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D114004 --- mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py index 1acae7a7a..a65350ccd 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -112,7 +112,7 @@ def linalg_structured_op(dsl_func=None, if dsl_func is None: # Curry the keyword args in for delayed application. return functools.partial( - tc_def_op, op_name=op_name, op_class_name=op_class_name) + linalg_structured_op, op_name=op_name, op_class_name=op_class_name) # Determine default names by introspecting the function. 
if op_name is None: op_name = dsl_func.__name__ @@ -131,9 +131,10 @@ def linalg_structured_op(dsl_func=None, if isinstance(param_default, (TensorDef, ScalarDef, AttributeDef)): tc_model.add_operand(param_name, param_default.operand_def) else: - raise ValueError(f"@tc_def_op function parameters must be defaulted as " - f"TensorDef(...), ScalarDef(...), or AttributeDef(...): " - f"Found {param_name}: {param_default}") + raise ValueError( + f"@linalg_structured_op function parameters must be defaulted as " + f"TensorDef(...), ScalarDef(...), or AttributeDef(...): " + f"Found {param_name}: {param_default}") dsl_func_args.append(param_default) # Invoke the DSL func to finish populating the model. From 0d256f05fc9ff3f6071583bd34471f21cd7dd21a Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 17 Nov 2021 21:50:28 +0000 Subject: [PATCH 168/915] [mlir] Refactor AbstractOperation and OperationName The current implementation is quite clunky; OperationName stores either an Identifier or an AbstractOperation that corresponds to an operation. This has several problems: * OperationNames created before and after an operation are registered are different * Accessing the identifier name/dialect/etc. from an OperationName are overly branchy - they need to dyn_cast a PointerUnion to check the state This commit refactors this such that we create a single information struct for every operation name, even operations that aren't registered yet. When an OperationName is created for an unregistered operation, we only populate the name field. When the operation is registered, we populate the remaining fields. With this we now have two new classes: OperationName and RegisteredOperationName. These both point to the same underlying operation information struct, but only RegisteredOperationName can assume that the operation is actually registered. This leads to a much cleaner API, and we can also move some AbstractOperation functionality directly to OperationName. 
Differential Revision: https://reviews.llvm.org/D114049 --- mlir/lib/CAPI/IR/IR.cpp | 12 +++++------- mlir/lib/CAPI/Interfaces/Interfaces.cpp | 18 +++++++++--------- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index a339aef90..11b0157ae 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -264,9 +264,8 @@ void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) { static LogicalResult inferOperationTypes(OperationState &state) { MLIRContext *context = state.getContext(); - const AbstractOperation *abstractOp = - AbstractOperation::lookup(state.name.getStringRef(), context); - if (!abstractOp) { + Optional info = state.name.getRegisteredInfo(); + if (!info) { emitError(state.location) << "type inference was requested for the operation " << state.name << ", but the operation was not registered. Ensure that the dialect " @@ -276,7 +275,7 @@ static LogicalResult inferOperationTypes(OperationState &state) { } // Fallback to inference via an op interface. 
- auto *inferInterface = abstractOp->getInterface(); + auto *inferInterface = info->getInterface(); if (!inferInterface) { emitError(state.location) << "type inference was requested for the operation " << state.name @@ -353,9 +352,8 @@ MlirLocation mlirOperationGetLocation(MlirOperation op) { } MlirTypeID mlirOperationGetTypeID(MlirOperation op) { - if (const auto *abstractOp = unwrap(op)->getAbstractOperation()) { - return wrap(abstractOp->typeID); - } + if (auto info = unwrap(op)->getRegisteredInfo()) + return wrap(info->getTypeID()); return {nullptr}; } diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp index 315adb5fb..f752a57b5 100644 --- a/mlir/lib/CAPI/Interfaces/Interfaces.cpp +++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp @@ -17,17 +17,17 @@ using namespace mlir; bool mlirOperationImplementsInterface(MlirOperation operation, MlirTypeID interfaceTypeID) { - const AbstractOperation *abstractOp = - unwrap(operation)->getAbstractOperation(); - return abstractOp && abstractOp->hasInterface(unwrap(interfaceTypeID)); + Optional info = + unwrap(operation)->getRegisteredInfo(); + return info && info->hasInterface(unwrap(interfaceTypeID)); } bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, MlirContext context, MlirTypeID interfaceTypeID) { - const AbstractOperation *abstractOp = AbstractOperation::lookup( + Optional info = RegisteredOperationName::lookup( StringRef(operationName.data, operationName.length), unwrap(context)); - return abstractOp && abstractOp->hasInterface(unwrap(interfaceTypeID)); + return info && info->hasInterface(unwrap(interfaceTypeID)); } MlirTypeID mlirInferTypeOpInterfaceTypeID() { @@ -40,9 +40,9 @@ MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, void *userData) { StringRef name(opName.data, opName.length); - const AbstractOperation *abstractOp = - AbstractOperation::lookup(name, unwrap(context)); 
- if (!abstractOp) + Optional info = + RegisteredOperationName::lookup(name, unwrap(context)); + if (!info) return mlirLogicalResultFailure(); llvm::Optional maybeLocation = llvm::None; @@ -68,7 +68,7 @@ MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( }); SmallVector inferredTypes; - if (failed(abstractOp->getInterface()->inferReturnTypes( + if (failed(info->getInterface()->inferReturnTypes( unwrap(context), maybeLocation, unwrappedOperands, attributeDict, unwrappedRegions, inferredTypes))) return mlirLogicalResultFailure(); From d17e7690084d2872fd1dc99d2663531b27651a88 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 18 Nov 2021 05:23:32 +0000 Subject: [PATCH 169/915] [mlir] Convert NamedAttribute to be a class NamedAttribute is currently represented as an std::pair, but this creates an extremely clunky .first/.second API. This commit converts it to a class, with better accessors (getName/getValue) and also opens the door for more convenient API in the future. Differential Revision: https://reviews.llvm.org/D113956 --- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 2 +- mlir/lib/CAPI/IR/IR.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 5ec7e3468..c20548bd4 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -83,7 +83,7 @@ MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr, intptr_t pos) { NamedAttribute attribute = unwrap(attr).cast().getValue()[pos]; - return {wrap(attribute.first), wrap(attribute.second)}; + return {wrap(attribute.getName()), wrap(attribute.getValue())}; } MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 11b0157ae..35a059275 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -432,7 +432,7 @@ intptr_t mlirOperationGetNumAttributes(MlirOperation op) { 
MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos) { NamedAttribute attr = unwrap(op)->getAttrs()[pos]; - return MlirNamedAttribute{wrap(attr.first), wrap(attr.second)}; + return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())}; } MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, From 0e72881bb2621f70d4a14d831b5766806b60b531 Mon Sep 17 00:00:00 2001 From: Michal Terepeta Date: Thu, 18 Nov 2021 09:41:57 +0100 Subject: [PATCH 170/915] [mlir][Python] Fix generation of accessors for Optional Previously, in case there was only one `Optional` operand/result within the list, we would always return `None` from the accessor, e.g., for a single optional result we would generate: ``` return self.operation.results[0] if len(self.operation.results) > 1 else None ``` But what we really want is to return `None` only if the length of `results` is smaller than the total number of element groups (i.e., the optional operand/result is in fact missing). This commit also renames a few local variables in the generator to make the distinction between `isVariadic()` and `isVariableLength()` a bit more clear. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D113855 --- mlir/python/mlir/dialects/_linalg_ops_ext.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py index b7641c0a4..d6c57547e 100644 --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -36,15 +36,6 @@ def __init__(self, output: Value, value: Value, *, loc=None, ip=None): OpView.__init__(self, op) linalgDialect = Context.current.get_dialect_descriptor("linalg") fill_builtin_region(linalgDialect, self.operation) - # TODO: self.result is None. When len(results) == 1 we expect it to be - # results[0] as per _linalg_ops_gen.py. 
This seems like an orthogonal bug - # in the generator of _linalg_ops_gen.py where we have: - # ``` - # def result(self): - # return self.operation.results[0] \ - # if len(self.operation.results) > 1 else None - # ``` - class InitTensorOp: """Extends the linalg.init_tensor op.""" From 160c865bc3f9ce7bc0a90b7f6c655fe235e00f10 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Mon, 22 Nov 2021 10:57:33 +0000 Subject: [PATCH 171/915] [mlir] Add InitializeNativeTargetAsmParser to ExecutionEngine. This is required to allow python to work with lowerings that use inline_asm. Differential Revision: https://reviews.llvm.org/D114338 --- mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index 42bacd967..a9bb09e61 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -22,6 +22,7 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, const MlirStringRef *sharedLibPaths) { static bool initOnce = [] { llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmParser(); // needed for inline_asm llvm::InitializeNativeTargetAsmPrinter(); return true; }(); From 7dcc64c05aba91131e642e0084f90a2276249045 Mon Sep 17 00:00:00 2001 From: Tres Popp Date: Mon, 22 Nov 2021 10:37:42 +0100 Subject: [PATCH 172/915] Rename MlirExecutionEngine lookup to lookupPacked The purpose of the change is to make clear whether the user is retrieving the original function or the wrapper function, in line with the invoke commands. This new functionality is useful for users that already have defined their own packed interface, so they do not want the extra layer of indirection, or for users wanting to look at the resulting primary function rather than the wrapper function. 
All locations, except the python bindings now have a `lookupPacked` method that matches the original `lookup` functionality. `lookup` still exists, but with new semantics. - `lookup` returns the function with a given name. If `bool f(int,int)` is compiled, `lookup` will return a reference to `bool(*f)(int,int)`. - `lookupPacked` returns the packed wrapper of the function with the given name. If `bool f(int,int)` is compiled, `lookupPacked` will return `void(*mlir_f)(void**)`. Differential Revision: https://reviews.llvm.org/D114352 --- mlir/include/mlir-c/ExecutionEngine.h | 5 +++++ mlir/lib/Bindings/Python/ExecutionEngineModule.cpp | 2 +- mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp | 8 ++++++++ 3 files changed, 14 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h index bb454529b..cd3df8ebf 100644 --- a/mlir/include/mlir-c/ExecutionEngine.h +++ b/mlir/include/mlir-c/ExecutionEngine.h @@ -62,6 +62,11 @@ static inline bool mlirExecutionEngineIsNull(MlirExecutionEngine jit) { MLIR_CAPI_EXPORTED MlirLogicalResult mlirExecutionEngineInvokePacked( MlirExecutionEngine jit, MlirStringRef name, void **arguments); +/// Lookup the wrapper of the native function in the execution engine with the +/// given name, returns nullptr if the function can't be looked-up. +MLIR_CAPI_EXPORTED void * +mlirExecutionEngineLookupPacked(MlirExecutionEngine jit, MlirStringRef name); + /// Lookup a native function in the execution engine by name, returns nullptr /// if the name can't be looked-up. 
MLIR_CAPI_EXPORTED void *mlirExecutionEngineLookup(MlirExecutionEngine jit, diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp index 07c35163c..c49d9903f 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -100,7 +100,7 @@ PYBIND11_MODULE(_mlirExecutionEngine, m) { .def( "raw_lookup", [](PyExecutionEngine &executionEngine, const std::string &func) { - auto *res = mlirExecutionEngineLookup( + auto *res = mlirExecutionEngineLookupPacked( executionEngine.get(), mlirStringRefCreate(func.c_str(), func.size())); return reinterpret_cast(res); diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index a9bb09e61..604cc4522 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -75,6 +75,14 @@ mlirExecutionEngineInvokePacked(MlirExecutionEngine jit, MlirStringRef name, return wrap(success()); } +extern "C" void *mlirExecutionEngineLookupPacked(MlirExecutionEngine jit, + MlirStringRef name) { + auto expectedFPtr = unwrap(jit)->lookupPacked(unwrap(name)); + if (!expectedFPtr) + return nullptr; + return reinterpret_cast(*expectedFPtr); +} + extern "C" void *mlirExecutionEngineLookup(MlirExecutionEngine jit, MlirStringRef name) { auto expectedFPtr = unwrap(jit)->lookup(unwrap(name)); From a29a399c22eb461e7887c32e643101f1abdd9e11 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Thu, 25 Nov 2021 20:38:32 +0530 Subject: [PATCH 173/915] [MLIR] NFC. Rename MLIR CAPI ExecutionEngine target for consistency Rename MLIR CAPI ExecutionEngine target for consistency: MLIRCEXECUTIONENGINE -> MLIRCAPIExecutionEngine in line with other targets. 
Differential Revision: https://reviews.llvm.org/D114596 --- mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt | 2 +- mlir/python/CMakeLists.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt b/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt index 0c28e7ef8..48f45f8c0 100644 --- a/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt +++ b/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt @@ -1,5 +1,5 @@ # Main API shared library. -add_mlir_upstream_c_api_library(MLIRCEXECUTIONENGINE +add_mlir_upstream_c_api_library(MLIRCAPIExecutionEngine ExecutionEngine.cpp LINK_LIBS PUBLIC diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index c2b9d753a..cdb2e5a3c 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -253,7 +253,7 @@ declare_mlir_python_extension(MLIRPythonExtension.ExecutionEngine PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS - MLIRCEXECUTIONENGINE + MLIRCAPIExecutionEngine ) declare_mlir_python_extension(MLIRPythonExtension.GPUDialectPasses From a58a0a14c65725f8783732180dcfd374d86a0053 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sun, 28 Nov 2021 15:33:03 -0800 Subject: [PATCH 174/915] [mlir][python] Normalize asm-printing IR behavior. While working on an integration, I found a lot of inconsistencies on IR printing and verification. It turns out that we were: * Only doing "soft fail" verification on IR printing of Operation, not of a Module. * Failed verification was interacting badly with binary=True IR printing (causing a TypeError trying to pass an `str` to a `bytes` based handle). * For systematic integrations, it is often desirable to control verification yourself so that you can explicitly handle errors. This patch: * Trues up the "soft fail" semantics by having `Module.__str__` delegate to `Operation.__str__` vs having a shortcut implementation. 
* Fixes soft fail in the presence of binary=True (and adds an additional happy path test case to make sure the binary functionality works). * Adds an `assume_verified` boolean flag to the `print`/`get_asm` methods which disables internal verification, presupposing that the caller has taken care of it. It turns out that we had a number of tests which were generating illegal IR but it wasn't being caught because they were doing a print on the `Module` vs operation. All except two were trivially fixed: * linalg/ops.py : Had two tests for direct constructing a Matmul incorrectly. Fixing them made them just like the next two tests so just deleted (no need to test the verifier only at this level). * linalg/opdsl/emit_structured_generic.py : Hand coded conv and pooling tests appear to be using illegal shaped inputs/outputs, causing a verification failure. I just used the `assume_verified=` flag to restore the original behavior and left a TODO. Will get someone who owns that to fix it properly in a followup (would also be nice to break this file up into multiple test modules as it is hard to tell exactly what is failing). Notes to downstreams: * If, like some of our tests, you get verification failures after this patch, it is likely that your IR was always invalid and you will need to fix the root cause. To temporarily revert to prior (broken) behavior, replace calls like `print(module)` with `print(module.operation.get_asm(assume_verified=True))`. 
Differential Revision: https://reviews.llvm.org/D114680 --- mlir/lib/Bindings/Python/IRCore.cpp | 45 +++++++++++++++++++---------- mlir/lib/Bindings/Python/IRModule.h | 6 ++-- 2 files changed, 34 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 4c25fd450..c70cfc565 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -93,6 +93,13 @@ static const char kOperationPrintDocstring[] = use_local_Scope: Whether to print in a way that is more optimized for multi-threaded access but may not be consistent with how the overall module prints. + assume_verified: By default, if not printing generic form, the verifier + will be run and if it fails, generic form will be printed with a comment + about failed verification. While a reasonable default for interactive use, + for systematic use, it is often better for the caller to verify explicitly + and report failures in a more robust fashion. Set this to True if doing this + in order to avoid running a redundant verification. If the IR is actually + invalid, behavior is undefined. 
)"; static const char kOperationGetAsmDocstring[] = @@ -828,14 +835,21 @@ void PyOperation::checkValid() const { void PyOperationBase::print(py::object fileObject, bool binary, llvm::Optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, - bool printGenericOpForm, bool useLocalScope) { + bool printGenericOpForm, bool useLocalScope, + bool assumeVerified) { PyOperation &operation = getOperation(); operation.checkValid(); if (fileObject.is_none()) fileObject = py::module::import("sys").attr("stdout"); - if (!printGenericOpForm && !mlirOperationVerify(operation)) { - fileObject.attr("write")("// Verification failed, printing generic form\n"); + if (!assumeVerified && !printGenericOpForm && + !mlirOperationVerify(operation)) { + std::string message("// Verification failed, printing generic form\n"); + if (binary) { + fileObject.attr("write")(py::bytes(message)); + } else { + fileObject.attr("write")(py::str(message)); + } printGenericOpForm = true; } @@ -857,8 +871,8 @@ void PyOperationBase::print(py::object fileObject, bool binary, py::object PyOperationBase::getAsm(bool binary, llvm::Optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, - bool printGenericOpForm, - bool useLocalScope) { + bool printGenericOpForm, bool useLocalScope, + bool assumeVerified) { py::object fileObject; if (binary) { fileObject = py::module::import("io").attr("BytesIO")(); @@ -870,7 +884,8 @@ py::object PyOperationBase::getAsm(bool binary, /*enableDebugInfo=*/enableDebugInfo, /*prettyDebugInfo=*/prettyDebugInfo, /*printGenericOpForm=*/printGenericOpForm, - /*useLocalScope=*/useLocalScope); + /*useLocalScope=*/useLocalScope, + /*assumeVerified=*/assumeVerified); return fileObject.attr("getvalue")(); } @@ -2149,12 +2164,9 @@ void mlir::python::populateIRCore(py::module &m) { kDumpDocstring) .def( "__str__", - [](PyModule &self) { - MlirOperation operation = mlirModuleGetOperation(self.get()); - PyPrintAccumulator printAccum; - mlirOperationPrint(operation, 
printAccum.getCallback(), - printAccum.getUserData()); - return printAccum.join(); + [](py::object self) { + // Defer to the operation's __str__. + return self.attr("operation").attr("__str__")(); }, kOperationStrDunderDocstring); @@ -2234,7 +2246,8 @@ void mlir::python::populateIRCore(py::module &m) { /*enableDebugInfo=*/false, /*prettyDebugInfo=*/false, /*printGenericOpForm=*/false, - /*useLocalScope=*/false); + /*useLocalScope=*/false, + /*assumeVerified=*/false); }, "Returns the assembly form of the operation.") .def("print", &PyOperationBase::print, @@ -2244,7 +2257,8 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("enable_debug_info") = false, py::arg("pretty_debug_info") = false, py::arg("print_generic_op_form") = false, - py::arg("use_local_scope") = false, kOperationPrintDocstring) + py::arg("use_local_scope") = false, + py::arg("assume_verified") = false, kOperationPrintDocstring) .def("get_asm", &PyOperationBase::getAsm, // Careful: Lots of arguments must match up with get_asm method. py::arg("binary") = false, @@ -2252,7 +2266,8 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("enable_debug_info") = false, py::arg("pretty_debug_info") = false, py::arg("print_generic_op_form") = false, - py::arg("use_local_scope") = false, kOperationGetAsmDocstring) + py::arg("use_local_scope") = false, + py::arg("assume_verified") = false, kOperationGetAsmDocstring) .def( "verify", [](PyOperationBase &self) { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index eb5c2385a..dc024a247 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -394,11 +394,13 @@ class PyOperationBase { /// Implements the bound 'print' method and helps with others. 
void print(pybind11::object fileObject, bool binary, llvm::Optional largeElementsLimit, bool enableDebugInfo, - bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope); + bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, + bool assumeVerified); pybind11::object getAsm(bool binary, llvm::Optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, - bool printGenericOpForm, bool useLocalScope); + bool printGenericOpForm, bool useLocalScope, + bool assumeVerified); /// Moves the operation before or after the other operation. void moveAfter(PyOperationBase &other); From ce4ff32efd8359ee1ca9dc0f2bb6c6ab68d32d84 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sun, 28 Nov 2021 14:08:06 -0800 Subject: [PATCH 175/915] [mlir][python] Add pyi stub files to enable auto completion. There is no completely automated facility for generating stubs that are both accurate and comprehensive for native modules. After some experimentation, I found that MyPy's stubgen does the best at generating correct stubs with a few caveats that are relatively easy to fix: * Some types resolve to cross module symbols incorrectly. * staticmethod and classmethod signatures seem to always be completely generic and need to be manually provided. * It does not generate an __all__ which, from testing, causes namespace pollution to be visible to IDE code completion. As a first step, I did the following: * Ran `stubgen` for `_mlir.ir`, `_mlir.passmanager`, and `_mlirExecutionEngine`. * Manually looked for all instances where unnamed arguments were being emitted (i.e. as 'arg0', etc) and updated the C++ side to include names (and re-ran stubgen to get a good initial state). * Made/noted a few structural changes to each `pyi` file to make it minimally functional. * Added the `pyi` files to the CMake rules so they are installed and visible. To test, I added a `.env` file to the root of the project with `PYTHONPATH=...` set as per instructions. 
Then reload the developer window (in VsCode) and verify that completion works for various changes to test cases. There are still a number of overly generic signatures, but I want to check in this low-touch baseline before iterating on more ambiguous changes. This is already a big improvement. Differential Revision: https://reviews.llvm.org/D114679 --- .../Bindings/Python/ExecutionEngineModule.cpp | 3 +- mlir/lib/Bindings/Python/IRAffine.cpp | 145 +-- mlir/lib/Bindings/Python/IRAttributes.cpp | 2 +- mlir/lib/Bindings/Python/IRCore.cpp | 54 +- mlir/lib/Bindings/Python/IRModule.h | 30 +- mlir/lib/Bindings/Python/IRTypes.cpp | 6 +- mlir/lib/Bindings/Python/MainModule.cpp | 20 +- mlir/lib/Bindings/Python/Pass.cpp | 3 +- mlir/python/CMakeLists.txt | 6 + .../python/mlir/_mlir_libs/_mlir/__init__.pyi | 13 + mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 863 ++++++++++++++++++ .../mlir/_mlir_libs/_mlir/passmanager.pyi | 24 + .../mlir/_mlir_libs/_mlirExecutionEngine.pyi | 23 + 13 files changed, 1084 insertions(+), 108 deletions(-) create mode 100644 mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi create mode 100644 mlir/python/mlir/_mlir_libs/_mlir/ir.pyi create mode 100644 mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi create mode 100644 mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp index c49d9903f..b5a0f84d4 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -105,6 +105,7 @@ PYBIND11_MODULE(_mlirExecutionEngine, m) { mlirStringRefCreate(func.c_str(), func.size())); return reinterpret_cast(res); }, + py::arg("func_name"), "Lookup function `func` in the ExecutionEngine.") .def( "raw_register_runtime", @@ -127,5 +128,5 @@ PYBIND11_MODULE(_mlirExecutionEngine, m) { executionEngine.get(), mlirStringRefCreate(fileName.c_str(), fileName.size())); }, - "Dump ExecutionEngine to an object 
file."); + py::arg("file_name"), "Dump ExecutionEngine to an object file."); } diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index da80cda9c..272de0d7a 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -98,10 +98,13 @@ class PyConcreteAffineExpr : public BaseTy { static void bind(py::module &m) { auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); - cls.def(py::init()); - cls.def_static("isinstance", [](PyAffineExpr &otherAffineExpr) -> bool { - return DerivedTy::isaFunction(otherAffineExpr); - }); + cls.def(py::init(), py::arg("expr")); + cls.def_static( + "isinstance", + [](PyAffineExpr &otherAffineExpr) -> bool { + return DerivedTy::isaFunction(otherAffineExpr); + }, + py::arg("other")); DerivedTy::bindDerived(cls); } @@ -748,41 +751,50 @@ void mlir::python::populateIRAffine(py::module &m) { }, py::arg("permutation"), py::arg("context") = py::none(), "Gets an affine map that permutes its inputs.") - .def("get_submap", - [](PyAffineMap &self, std::vector &resultPos) { - intptr_t numResults = mlirAffineMapGetNumResults(self); - for (intptr_t pos : resultPos) { - if (pos < 0 || pos >= numResults) - throw py::value_error("result position out of bounds"); - } - MlirAffineMap affineMap = mlirAffineMapGetSubMap( - self, resultPos.size(), resultPos.data()); - return PyAffineMap(self.getContext(), affineMap); - }) - .def("get_major_submap", - [](PyAffineMap &self, intptr_t nResults) { - if (nResults >= mlirAffineMapGetNumResults(self)) - throw py::value_error("number of results out of bounds"); - MlirAffineMap affineMap = - mlirAffineMapGetMajorSubMap(self, nResults); - return PyAffineMap(self.getContext(), affineMap); - }) - .def("get_minor_submap", - [](PyAffineMap &self, intptr_t nResults) { - if (nResults >= mlirAffineMapGetNumResults(self)) - throw py::value_error("number of results out of bounds"); - MlirAffineMap affineMap = - 
mlirAffineMapGetMinorSubMap(self, nResults); - return PyAffineMap(self.getContext(), affineMap); - }) - .def("replace", - [](PyAffineMap &self, PyAffineExpr &expression, - PyAffineExpr &replacement, intptr_t numResultDims, - intptr_t numResultSyms) { - MlirAffineMap affineMap = mlirAffineMapReplace( - self, expression, replacement, numResultDims, numResultSyms); - return PyAffineMap(self.getContext(), affineMap); - }) + .def( + "get_submap", + [](PyAffineMap &self, std::vector &resultPos) { + intptr_t numResults = mlirAffineMapGetNumResults(self); + for (intptr_t pos : resultPos) { + if (pos < 0 || pos >= numResults) + throw py::value_error("result position out of bounds"); + } + MlirAffineMap affineMap = mlirAffineMapGetSubMap( + self, resultPos.size(), resultPos.data()); + return PyAffineMap(self.getContext(), affineMap); + }, + py::arg("result_positions")) + .def( + "get_major_submap", + [](PyAffineMap &self, intptr_t nResults) { + if (nResults >= mlirAffineMapGetNumResults(self)) + throw py::value_error("number of results out of bounds"); + MlirAffineMap affineMap = + mlirAffineMapGetMajorSubMap(self, nResults); + return PyAffineMap(self.getContext(), affineMap); + }, + py::arg("n_results")) + .def( + "get_minor_submap", + [](PyAffineMap &self, intptr_t nResults) { + if (nResults >= mlirAffineMapGetNumResults(self)) + throw py::value_error("number of results out of bounds"); + MlirAffineMap affineMap = + mlirAffineMapGetMinorSubMap(self, nResults); + return PyAffineMap(self.getContext(), affineMap); + }, + py::arg("n_results")) + .def( + "replace", + [](PyAffineMap &self, PyAffineExpr &expression, + PyAffineExpr &replacement, intptr_t numResultDims, + intptr_t numResultSyms) { + MlirAffineMap affineMap = mlirAffineMapReplace( + self, expression, replacement, numResultDims, numResultSyms); + return PyAffineMap(self.getContext(), affineMap); + }, + py::arg("expr"), py::arg("replacement"), py::arg("n_result_dims"), + py::arg("n_result_syms")) 
.def_property_readonly( "is_permutation", [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) @@ -876,32 +888,35 @@ void mlir::python::populateIRAffine(py::module &m) { }, py::arg("num_dims"), py::arg("num_symbols"), py::arg("context") = py::none()) - .def("get_replaced", - [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, - intptr_t numResultDims, intptr_t numResultSymbols) { - if (static_cast(dimExprs.size()) != - mlirIntegerSetGetNumDims(self)) - throw py::value_error( - "Expected the number of dimension replacement expressions " - "to match that of dimensions"); - if (static_cast(symbolExprs.size()) != - mlirIntegerSetGetNumSymbols(self)) - throw py::value_error( - "Expected the number of symbol replacement expressions " - "to match that of symbols"); - - SmallVector dimAffineExprs, symbolAffineExprs; - pyListToVector( - dimExprs, dimAffineExprs, - "attempting to create an IntegerSet by replacing dimensions"); - pyListToVector( - symbolExprs, symbolAffineExprs, - "attempting to create an IntegerSet by replacing symbols"); - MlirIntegerSet set = mlirIntegerSetReplaceGet( - self, dimAffineExprs.data(), symbolAffineExprs.data(), - numResultDims, numResultSymbols); - return PyIntegerSet(self.getContext(), set); - }) + .def( + "get_replaced", + [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, + intptr_t numResultDims, intptr_t numResultSymbols) { + if (static_cast(dimExprs.size()) != + mlirIntegerSetGetNumDims(self)) + throw py::value_error( + "Expected the number of dimension replacement expressions " + "to match that of dimensions"); + if (static_cast(symbolExprs.size()) != + mlirIntegerSetGetNumSymbols(self)) + throw py::value_error( + "Expected the number of symbol replacement expressions " + "to match that of symbols"); + + SmallVector dimAffineExprs, symbolAffineExprs; + pyListToVector( + dimExprs, dimAffineExprs, + "attempting to create an IntegerSet by replacing dimensions"); + pyListToVector( + symbolExprs, 
symbolAffineExprs, + "attempting to create an IntegerSet by replacing symbols"); + MlirIntegerSet set = mlirIntegerSetReplaceGet( + self, dimAffineExprs.data(), symbolAffineExprs.data(), + numResultDims, numResultSymbols); + return PyIntegerSet(self.getContext(), set); + }, + py::arg("dim_exprs"), py::arg("symbol_exprs"), + py::arg("num_result_dims"), py::arg("num_result_symbols")) .def_property_readonly("is_canonical_empty", [](PyIntegerSet &self) { return mlirIntegerSetIsCanonicalEmpty(self); diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 7db90ec8d..17b3b34a2 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -337,7 +337,7 @@ class PyStringAttribute : public PyConcreteAttribute { mlirStringAttrTypedGet(type, toMlirStringRef(value)); return PyStringAttribute(type.getContext(), attr); }, - + py::arg("type"), py::arg("value"), "Gets a uniqued string attribute associated to a type"); c.def_property_readonly( "value", diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index c70cfc565..8a110fcc4 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1631,10 +1631,13 @@ class PyConcreteValue : public PyValue { /// Binds the Python module objects to functions of this class. 
static void bind(py::module &m) { auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); - cls.def(py::init(), py::keep_alive<0, 1>()); - cls.def_static("isinstance", [](PyValue &otherValue) -> bool { - return DerivedTy::isaFunction(otherValue); - }); + cls.def(py::init(), py::keep_alive<0, 1>(), py::arg("value")); + cls.def_static( + "isinstance", + [](PyValue &otherValue) -> bool { + return DerivedTy::isaFunction(otherValue); + }, + py::arg("other_value")); DerivedTy::bindDerived(cls); } @@ -1657,9 +1660,12 @@ class PyBlockArgument : public PyConcreteValue { c.def_property_readonly("arg_number", [](PyBlockArgument &self) { return mlirBlockArgumentGetArgNumber(self.get()); }); - c.def("set_type", [](PyBlockArgument &self, PyType type) { - return mlirBlockArgumentSetType(self.get(), type); - }); + c.def( + "set_type", + [](PyBlockArgument &self, PyType type) { + return mlirBlockArgumentSetType(self.get(), type); + }, + py::arg("type")); } }; @@ -1952,6 +1958,7 @@ void mlir::python::populateIRCore(py::module &m) { } return PyDialectDescriptor(self.getRef(), dialect); }, + py::arg("dialect_name"), "Gets or loads a dialect by name, returning its descriptor object") .def_property( "allow_unregistered_dialects", @@ -1961,15 +1968,19 @@ void mlir::python::populateIRCore(py::module &m) { [](PyMlirContext &self, bool value) { mlirContextSetAllowUnregisteredDialects(self.get(), value); }) - .def("enable_multithreading", - [](PyMlirContext &self, bool enable) { - mlirContextEnableMultithreading(self.get(), enable); - }) - .def("is_registered_operation", - [](PyMlirContext &self, std::string &name) { - return mlirContextIsRegisteredOperation( - self.get(), MlirStringRef{name.data(), name.size()}); - }); + .def( + "enable_multithreading", + [](PyMlirContext &self, bool enable) { + mlirContextEnableMultithreading(self.get(), enable); + }, + py::arg("enable")) + .def( + "is_registered_operation", + [](PyMlirContext &self, std::string &name) { + return 
mlirContextIsRegisteredOperation( + self.get(), MlirStringRef{name.data(), name.size()}); + }, + py::arg("operation_name")); //---------------------------------------------------------------------------- // Mapping of PyDialectDescriptor @@ -2013,7 +2024,7 @@ void mlir::python::populateIRCore(py::module &m) { // Mapping of PyDialect //---------------------------------------------------------------------------- py::class_(m, "Dialect", py::module_local()) - .def(py::init(), "descriptor") + .def(py::init(), py::arg("descriptor")) .def_property_readonly( "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) .def("__repr__", [](py::object self) { @@ -2332,7 +2343,7 @@ void mlir::python::populateIRCore(py::module &m) { auto opViewClass = py::class_(m, "OpView", py::module_local()) - .def(py::init()) + .def(py::init(), py::arg("operation")) .def_property_readonly("operation", &PyOpView::getOperationObject) .def_property_readonly( "context", @@ -2426,7 +2437,7 @@ void mlir::python::populateIRCore(py::module &m) { mlirRegionInsertOwnedBlock(parent, 0, block); return PyBlock(parent.getParentOperation(), block); }, - py::arg("parent"), py::arg("pyArgTypes") = py::list(), + py::arg("parent"), py::arg("arg_types") = py::list(), "Creates and returns a new Block at the beginning of the given " "region (with given argument types).") .def( @@ -2499,6 +2510,7 @@ void mlir::python::populateIRCore(py::module &m) { operation.getOperation().setAttached( self.getParentOperation().getObject()); }, + py::arg("operation"), "Appends an operation to this block. 
If the operation is currently " "in another block, it will be moved."); @@ -2758,8 +2770,8 @@ void mlir::python::populateIRCore(py::module &m) { py::class_(m, "SymbolTable", py::module_local()) .def(py::init()) .def("__getitem__", &PySymbolTable::dunderGetItem) - .def("insert", &PySymbolTable::insert) - .def("erase", &PySymbolTable::erase) + .def("insert", &PySymbolTable::insert, py::arg("operation")) + .def("erase", &PySymbolTable::erase, py::arg("operation")) .def("__delitem__", &PySymbolTable::dunderDel) .def("__contains__", [](PySymbolTable &table, const std::string &name) { return !mlirOperationIsNull(mlirSymbolTableLookup( diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index dc024a247..f0d0cc654 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -244,8 +244,7 @@ class DefaultingPyMlirContext : public Defaulting { public: using Defaulting::Defaulting; - static constexpr const char kTypeDescription[] = - "[ThreadContextAware] mlir.ir.Context"; + static constexpr const char kTypeDescription[] = "mlir.ir.Context"; static PyMlirContext &resolve(); }; @@ -339,8 +338,7 @@ class DefaultingPyLocation : public Defaulting { public: using Defaulting::Defaulting; - static constexpr const char kTypeDescription[] = - "[ThreadContextAware] mlir.ir.Location"; + static constexpr const char kTypeDescription[] = "mlir.ir.Location"; static PyLocation &resolve(); operator MlirLocation() const { return *get(); } @@ -678,10 +676,14 @@ class PyConcreteType : public BaseTy { static void bind(pybind11::module &m) { auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local()); - cls.def(pybind11::init(), pybind11::keep_alive<0, 1>()); - cls.def_static("isinstance", [](PyType &otherType) -> bool { - return DerivedTy::isaFunction(otherType); - }); + cls.def(pybind11::init(), pybind11::keep_alive<0, 1>(), + pybind11::arg("cast_from_type")); + cls.def_static( + "isinstance", + [](PyType &otherType) 
-> bool { + return DerivedTy::isaFunction(otherType); + }, + pybind11::arg("other")); DerivedTy::bindDerived(cls); } @@ -768,10 +770,14 @@ class PyConcreteAttribute : public BaseTy { static void bind(pybind11::module &m) { auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(), pybind11::module_local()); - cls.def(pybind11::init(), pybind11::keep_alive<0, 1>()); - cls.def_static("isinstance", [](PyAttribute &otherAttr) -> bool { - return DerivedTy::isaFunction(otherAttr); - }); + cls.def(pybind11::init(), pybind11::keep_alive<0, 1>(), + pybind11::arg("cast_from_attr")); + cls.def_static( + "isinstance", + [](PyAttribute &otherAttr) -> bool { + return DerivedTy::isaFunction(otherAttr); + }, + pybind11::arg("other")); cls.def_property_readonly("type", [](PyAttribute &attr) { return PyType(attr.getContext(), mlirAttributeGetType(attr)); }); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 89fdb1f06..380aa36d7 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -262,6 +262,7 @@ class PyShapedType : public PyConcreteType { self.requireHasRank(); return mlirShapedTypeIsDynamicDim(self, dim); }, + py::arg("dim"), "Returns whether the dim-th dimension of the given shaped type is " "dynamic."); c.def( @@ -270,10 +271,12 @@ class PyShapedType : public PyConcreteType { self.requireHasRank(); return mlirShapedTypeGetDimSize(self, dim); }, + py::arg("dim"), "Returns the dim-th dimension of the given ranked shaped type."); c.def_static( "is_dynamic_size", [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, + py::arg("dim_size"), "Returns whether the given dimension size indicates a dynamic " "dimension."); c.def( @@ -282,6 +285,7 @@ class PyShapedType : public PyConcreteType { self.requireHasRank(); return mlirShapedTypeIsDynamicStrideOrOffset(val); }, + py::arg("dim_size"), "Returns whether the given value is used as a placeholder for dynamic " 
"strides and offsets in shaped types."); c.def_property_readonly( @@ -544,7 +548,7 @@ class PyTupleType : public PyConcreteType { MlirType t = mlirTupleTypeGetType(self, pos); return PyType(self.getContext(), t); }, - "Returns the pos-th type in the tuple type."); + py::arg("pos"), "Returns the pos-th type in the tuple type."); c.def_property_readonly( "num_types", [](PyTupleType &self) -> intptr_t { diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 5489a4d3e..896ee432e 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -30,14 +30,19 @@ PYBIND11_MODULE(_mlir, m) { .def_property("dialect_search_modules", &PyGlobals::getDialectSearchPrefixes, &PyGlobals::setDialectSearchPrefixes) - .def("append_dialect_search_prefix", - [](PyGlobals &self, std::string moduleName) { - self.getDialectSearchPrefixes().push_back(std::move(moduleName)); - self.clearImportCache(); - }) + .def( + "append_dialect_search_prefix", + [](PyGlobals &self, std::string moduleName) { + self.getDialectSearchPrefixes().push_back(std::move(moduleName)); + self.clearImportCache(); + }, + py::arg("module_name")) .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, + py::arg("dialect_namespace"), py::arg("dialect_class"), "Testing hook for directly registering a dialect") .def("_register_operation_impl", &PyGlobals::registerOperationImpl, + py::arg("operation_name"), py::arg("operation_class"), + py::arg("raw_opview_class"), "Testing hook for directly registering an operation"); // Aside from making the globals accessible to python, having python manage @@ -55,6 +60,7 @@ PYBIND11_MODULE(_mlir, m) { PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); return pyClass; }, + py::arg("dialect_class"), "Class decorator for registering a custom Dialect wrapper"); m.def( "register_operation", @@ -78,7 +84,9 @@ PYBIND11_MODULE(_mlir, m) { return opClass; }); }, - "Class decorator for 
registering a custom Operation wrapper"); + py::arg("dialect_class"), + "Produce a class decorator for registering an Operation class as part of " + "a dialect"); // Define and populate IR submodule. auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 6aa1c651c..2c38a3a25 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -79,7 +79,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { [](PyPassManager &passManager, bool enable) { mlirPassManagerEnableVerifier(passManager.get(), enable); }, - "Enable / disable verify-each.") + py::arg("enable"), "Enable / disable verify-each.") .def_static( "parse", [](const std::string pipeline, DefaultingPyMlirContext context) { @@ -106,6 +106,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { throw SetPyError(PyExc_RuntimeError, "Failure while executing pass pipeline."); }, + py::arg("module"), "Run the pass manager on the provided module, throw a RuntimeError " "on failure.") .def( diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index cdb2e5a3c..f7b84b033 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -20,6 +20,11 @@ declare_mlir_python_sources(MLIRPythonSources.Core ir.py passmanager.py dialects/_ods_common.py + + # The main _mlir module has submodules: include stubs from each. 
+ _mlir_libs/_mlir/__init__.pyi + _mlir_libs/_mlir/ir.pyi + _mlir_libs/_mlir/passmanager.pyi ) declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine @@ -27,6 +32,7 @@ declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine ADD_TO_PARENT MLIRPythonSources SOURCES execution_engine.py + _mlir_libs/_mlirExecutionEngine.pyi SOURCES_GLOB runtime/*.py ) diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi new file mode 100644 index 000000000..d4aab6806 --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi @@ -0,0 +1,13 @@ +from typing import List + +globals: _Globals + +class _Globals: + dialect_search_modules: List[str] + def __init__(self, *args, **kwargs) -> None: ... + def _register_dialect_impl(self, dialect_namespace: str, dialect_class: object) -> None: ... + def _register_operation_impl(self, operation_name: str, operation_class: object, raw_opview_class: object) -> None: ... + def append_dialect_search_prefix(self, module_name: str) -> None: ... + +def register_dialect(dialect_class: object) -> object: ... +def register_operation(dialect_class: object) -> object: ... diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi new file mode 100644 index 000000000..47ebeb291 --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -0,0 +1,863 @@ +# Originally imported via: +# stubgen {...} -m mlir._mlir_libs._mlir.ir +# Local modifications: +# * Rewrite references to 'mlir.ir.' to local types +# * Add __all__ with the following incantation: +# egrep '^class ' ir.pyi | awk -F ' |:|\\(' '{print " \"" $2 "\","}' +# * Local edits to signatures and types that MyPy did not auto detect (or +# detected incorrectly). 
+ +from typing import Any, ClassVar, List, Optional + +from typing import overload + +__all__ = [ + "AffineAddExpr", + "AffineBinaryExpr", + "AffineCeilDivExpr", + "AffineConstantExpr", + "AffineDimExpr", + "AffineExpr", + "AffineExprList", + "AffineFloorDivExpr", + "AffineMap", + "AffineMapAttr", + "AffineModExpr", + "AffineMulExpr", + "AffineSymbolExpr", + "ArrayAttr", + "ArrayAttributeIterator", + "Attribute", + "BF16Type", + "Block", + "BlockArgument", + "BlockArgumentList", + "BlockIterator", + "BlockList", + "BoolAttr", + "ComplexType", + "Context", + "DenseElementsAttr", + "DenseFPElementsAttr", + "DenseIntElementsAttr", + "Dialect", + "DialectDescriptor", + "Dialects", + "DictAttr", + "F16Type", + "F32Type", + "F64Type", + "FlatSymbolRefAttr", + "FloatAttr", + "FunctionType", + "IndexType", + "InferTypeOpInterface", + "InsertionPoint", + "IntegerAttr", + "IntegerSet", + "IntegerSetConstraint", + "IntegerSetConstraintList", + "IntegerType", + "Location", + "MemRefType", + "Module", + "NamedAttribute", + "NoneType", + "OpAttributeMap", + "OpOperandList", + "OpResult", + "OpResultList", + "OpView", + "Operation", + "OperationIterator", + "OperationList", + "RankedTensorType", + "Region", + "RegionIterator", + "RegionSequence", + "ShapedType", + "StringAttr", + "SymbolTable", + "TupleType", + "Type", + "TypeAttr", + "UnitAttr", + "UnrankedMemRefType", + "UnrankedTensorType", + "Value", + "VectorType", + "_GlobalDebug", + "_OperationBase", +] + + +class AffineAddExpr(AffineBinaryExpr): + def __init__(self, expr: AffineExpr) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + +class AffineBinaryExpr(AffineExpr): + def __init__(self, expr: AffineExpr) -> None: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def lhs(self) -> AffineExpr: ... + @property + def rhs(self) -> AffineExpr: ... + +class AffineCeilDivExpr(AffineBinaryExpr): + def __init__(self, expr: AffineExpr) -> None: ... 
+ def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + +class AffineConstantExpr(AffineExpr): + def __init__(self, expr: AffineExpr) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def value(self) -> int: ... + +class AffineDimExpr(AffineExpr): + def __init__(self, expr: AffineExpr) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def position(self) -> int: ... + +class AffineExpr: + def __init__(self, *args, **kwargs) -> None: ... + def _CAPICreate(self) -> AffineExpr: ... + def compose(self, arg0) -> AffineExpr: ... + def dump(self) -> None: ... + def get_add(self, *args, **kwargs) -> Any: ... + def get_ceil_div(self, *args, **kwargs) -> Any: ... + def get_constant(self, *args, **kwargs) -> Any: ... + def get_dim(self, *args, **kwargs) -> Any: ... + def get_floor_div(self, *args, **kwargs) -> Any: ... + def get_mod(self, *args, **kwargs) -> Any: ... + def get_mul(self, *args, **kwargs) -> Any: ... + def get_symbol(self, *args, **kwargs) -> Any: ... + def __add__(self, other) -> Any: ... + @overload + def __eq__(self, arg0: AffineExpr) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __hash__(self) -> int: ... + def __mod__(self, other) -> Any: ... + def __mul__(self, other) -> Any: ... + def __radd__(self, other) -> Any: ... + def __rmod__(self, other) -> Any: ... + def __rmul__(self, other) -> Any: ... + def __rsub__(self, other) -> Any: ... + def __sub__(self, other) -> Any: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def context(self) -> object: ... + +class AffineExprList: + def __init__(self, *args, **kwargs) -> None: ... + def __add__(self, arg0: AffineExprList) -> List[AffineExpr]: ... + @overload + def __getitem__(self, arg0: int) -> AffineExpr: ... 
+ @overload + def __getitem__(self, arg0: slice) -> AffineExprList: ... + def __len__(self) -> int: ... + +class AffineFloorDivExpr(AffineBinaryExpr): + def __init__(self, expr: AffineExpr) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + +class AffineMap: + def __init__(self, *args, **kwargs) -> None: ... + def _CAPICreate(self) -> AffineMap: ... + def compress_unused_symbols(self, *args, **kwargs) -> Any: ... + def dump(self) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def get_constant(self, *args, **kwargs) -> Any: ... + def get_empty(self, *args, **kwargs) -> Any: ... + def get_identity(self, *args, **kwargs) -> Any: ... + def get_major_submap(self, n_results: int) -> AffineMap: ... + def get_minor_identity(self, *args, **kwargs) -> Any: ... + def get_minor_submap(self, n_results: int) -> AffineMap: ... + def get_permutation(self, *args, **kwargs) -> Any: ... + def get_submap(self, result_positions: List[int]) -> AffineMap: ... + def replace(self, expr: AffineExpr, replacement: AffineExpr, n_result_dims: int, n_result_syms: int) -> AffineMap: ... + @overload + def __eq__(self, arg0: AffineMap) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __hash__(self) -> int: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def context(self) -> object: ... + @property + def is_permutation(self) -> bool: ... + @property + def is_projected_permutation(self) -> bool: ... + @property + def n_dims(self) -> int: ... + @property + def n_inputs(self) -> int: ... + @property + def n_symbols(self) -> int: ... + @property + def results(self) -> Any: ... + +class AffineMapAttr(Attribute): + def __init__(self, cast_from_attr: Attribute) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def type(self) -> Type: ... 
+ +class AffineModExpr(AffineBinaryExpr): + def __init__(self, expr: AffineExpr) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + +class AffineMulExpr(AffineBinaryExpr): + def __init__(self, expr: AffineExpr) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + +class AffineSymbolExpr(AffineExpr): + def __init__(self, expr: AffineExpr) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def position(self) -> int: ... + +class ArrayAttr(Attribute): + def __init__(self, cast_from_attr: Attribute) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + def __add__(self, arg0: list) -> ArrayAttr: ... + def __getitem__(self, arg0: int) -> Attribute: ... + def __iter__(self) -> Any: ... + def __len__(self) -> int: ... + @property + def type(self) -> Type: ... + +class ArrayAttributeIterator: + def __init__(self, *args, **kwargs) -> None: ... + def __iter__(self) -> ArrayAttributeIterator: ... + def __next__(self) -> Attribute: ... + +class Attribute: + def __init__(self, cast_from_type: Attribute) -> None: ... + def _CAPICreate(self) -> Attribute: ... + def dump(self) -> None: ... + def get_named(self, *args, **kwargs) -> Any: ... + def parse(self, *args, **kwargs) -> Any: ... + @overload + def __eq__(self, arg0: Attribute) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __hash__(self) -> int: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def context(self) -> object: ... + @property + def type(self) -> Any: ... + +class BF16Type(Type): + def __init__(self, cast_from_type: Type) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + +class Block: + __hash__: ClassVar[None] = ... 
+ def __init__(self, *args, **kwargs) -> None: ... + def append(self, operation: _OperationBase) -> None: ... + def create_after(self, *args) -> Block: ... + def create_at_start(self, *args, **kwargs) -> Any: ... + def create_before(self, *args) -> Block: ... + @overload + def __eq__(self, arg0: Block) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __iter__(self) -> Any: ... + @property + def arguments(self) -> Any: ... + @property + def operations(self) -> Any: ... + @property + def owner(self) -> object: ... + @property + def region(self) -> Region: ... + +class BlockArgument(Value): + def __init__(self, value: Value) -> None: ... + def isinstance(self, *args, **kwargs) -> Any: ... + def set_type(self, type: Type) -> None: ... + @property + def arg_number(self) -> int: ... + @property + def owner(self) -> Block: ... + +class BlockArgumentList: + def __init__(self, *args, **kwargs) -> None: ... + def __add__(self, arg0: BlockArgumentList) -> List[BlockArgument]: ... + @overload + def __getitem__(self, arg0: int) -> BlockArgument: ... + @overload + def __getitem__(self, arg0: slice) -> BlockArgumentList: ... + def __len__(self) -> int: ... + @property + def types(self) -> List[Type]: ... + +class BlockIterator: + def __init__(self, *args, **kwargs) -> None: ... + def __iter__(self) -> BlockIterator: ... + def __next__(self) -> Block: ... + +class BlockList: + def __init__(self, *args, **kwargs) -> None: ... + def append(self, *args) -> Block: ... + def __getitem__(self, arg0: int) -> Block: ... + def __iter__(self) -> BlockIterator: ... + def __len__(self) -> int: ... + +class BoolAttr(Attribute): + def __init__(self, cast_from_attr: Attribute) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def type(self) -> Type: ... + @property + def value(self) -> bool: ... + +class ComplexType(Type): + def __init__(self, cast_from_type: Type) -> None: ... 
+ def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def element_type(self) -> Type: ... + +class Context: + current: ClassVar[Context] = ... # read-only + allow_unregistered_dialects: bool + def __init__(self) -> None: ... + def _CAPICreate(self) -> object: ... + def _get_context_again(self) -> object: ... + def _get_live_count(self, *args, **kwargs) -> Any: ... + def _get_live_module_count(self) -> int: ... + def _get_live_operation_count(self) -> int: ... + def enable_multithreading(self, enable: bool) -> None: ... + def get_dialect_descriptor(self, *args, **kwargs) -> Any: ... + def is_registered_operation(self, operation_name: str) -> bool: ... + def __enter__(self) -> object: ... + def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def d(self) -> Any: ... + @property + def dialects(self) -> Any: ... + +class DenseElementsAttr(Attribute): + def __init__(self, cast_from_attr: Attribute) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def get_splat(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + def __len__(self) -> int: ... + @property + def is_splat(self) -> bool: ... + @property + def type(self) -> Type: ... + +class DenseFPElementsAttr(DenseElementsAttr): + def __init__(self, cast_from_attr: Attribute) -> None: ... + def isinstance(self, *args, **kwargs) -> Any: ... + def __getitem__(self, arg0: int) -> float: ... + @property + def type(self) -> Type: ... + +class DenseIntElementsAttr(DenseElementsAttr): + def __init__(self, cast_from_attr: Attribute) -> None: ... + def isinstance(self, *args, **kwargs) -> Any: ... + def __getitem__(self, arg0: int) -> int: ... + @property + def type(self) -> Type: ... + +class Dialect: + def __init__(self, descriptor: object) -> None: ... + @property + def descriptor(self) -> object: ... 
+ +class DialectDescriptor: + def __init__(self, *args, **kwargs) -> None: ... + @property + def namespace(self) -> str: ... + +class Dialects: + def __init__(self, *args, **kwargs) -> None: ... + def __getattr__(self, arg0: str) -> object: ... + def __getitem__(self, arg0: str) -> object: ... + +class DictAttr(Attribute): + def __init__(self, cast_from_attr: Attribute) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + def __contains__(self, arg0: str) -> bool: ... + @overload + def __getitem__(self, arg0: str) -> Attribute: ... + @overload + def __getitem__(self, arg0: int) -> NamedAttribute: ... + def __len__(self) -> int: ... + @property + def type(self) -> Type: ... + +class F16Type(Type): + def __init__(self, cast_from_type: Type) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + +class F32Type(Type): + def __init__(self, cast_from_type: Type) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + +class F64Type(Type): + def __init__(self, cast_from_type: Type) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + +class FlatSymbolRefAttr(Attribute): + def __init__(self, cast_from_attr: Attribute) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def type(self) -> Type: ... + @property + def value(self) -> str: ... + +class FloatAttr(Attribute): + def __init__(self, cast_from_attr: Attribute) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def get_f32(self, *args, **kwargs) -> Any: ... + def get_f64(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def type(self) -> Type: ... + @property + def value(self) -> float: ... 
+ +class FunctionType(Type): + def __init__(self, cast_from_type: Type) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def inputs(self) -> list: ... + @property + def results(self) -> list: ... + +class IndexType(Type): + def __init__(self, cast_from_type: Type) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + +class InferTypeOpInterface: + def __init__(self, object: object, context: Context = ...) -> None: ... + def inferReturnTypes(self, operands: Optional[List[Value]] = ..., attributes: Optional[Attribute] = ..., regions: Optional[List[Region]] = ..., context: Context = ..., loc: Location = ...) -> List[Type]: ... + @property + def operation(self) -> object: ... + @property + def opview(self) -> object: ... + +class InsertionPoint: + current: ClassVar[InsertionPoint] = ... # read-only + @overload + def __init__(self, block: Block) -> None: ... + @overload + def __init__(self, beforeOperation: _OperationBase) -> None: ... + def at_block_begin(self, *args, **kwargs) -> Any: ... + def at_block_terminator(self, *args, **kwargs) -> Any: ... + def insert(self, operation: _OperationBase) -> None: ... + def __enter__(self) -> object: ... + def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... + @property + def block(self) -> Block: ... + +class IntegerAttr(Attribute): + def __init__(self, cast_from_attr: Attribute) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def type(self) -> Type: ... + @property + def value(self) -> int: ... + +class IntegerSet: + def __init__(self, *args, **kwargs) -> None: ... + def _CAPICreate(self) -> IntegerSet: ... + def dump(self) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def get_empty(self, *args, **kwargs) -> Any: ... 
+ def get_replaced(self, dim_exprs: list, symbol_exprs: list, num_result_dims: int, num_result_symbols: int) -> IntegerSet: ... + @overload + def __eq__(self, arg0: IntegerSet) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __hash__(self) -> int: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def constraints(self) -> Any: ... + @property + def context(self) -> object: ... + @property + def is_canonical_empty(self) -> bool: ... + @property + def n_dims(self) -> int: ... + @property + def n_equalities(self) -> int: ... + @property + def n_inequalities(self) -> int: ... + @property + def n_inputs(self) -> int: ... + @property + def n_symbols(self) -> int: ... + +class IntegerSetConstraint: + def __init__(self, *args, **kwargs) -> None: ... + @property + def expr(self) -> AffineExpr: ... + @property + def is_eq(self) -> bool: ... + +class IntegerSetConstraintList: + def __init__(self, *args, **kwargs) -> None: ... + def __add__(self, arg0: IntegerSetConstraintList) -> List[IntegerSetConstraint]: ... + @overload + def __getitem__(self, arg0: int) -> IntegerSetConstraint: ... + @overload + def __getitem__(self, arg0: slice) -> IntegerSetConstraintList: ... + def __len__(self) -> int: ... + +class IntegerType(Type): + def __init__(self, cast_from_type: Type) -> None: ... + def get_signed(self, *args, **kwargs) -> Any: ... + def get_signless(self, *args, **kwargs) -> Any: ... + def get_unsigned(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def is_signed(self) -> bool: ... + @property + def is_signless(self) -> bool: ... + @property + def is_unsigned(self) -> bool: ... + @property + def width(self) -> int: ... + +class Location: + current: ClassVar[Location] = ... # read-only + __hash__: ClassVar[None] = ... + def __init__(self, *args, **kwargs) -> None: ... + def _CAPICreate(self) -> Location: ... + def callsite(self, *args, **kwargs) -> Any: ... 
+ def file(self, *args, **kwargs) -> Any: ... + def name(self, *args, **kwargs) -> Any: ... + def unknown(self, *args, **kwargs) -> Any: ... + def __enter__(self) -> object: ... + @overload + def __eq__(self, arg0: Location) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def context(self) -> object: ... + +class MemRefType(ShapedType): + def __init__(self, cast_from_type: Type) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def affine_map(self) -> AffineMap: ... + @property + def layout(self) -> Attribute: ... + @property + def memory_space(self) -> Attribute: ... + +class Module: + def __init__(self, *args, **kwargs) -> None: ... + def _CAPICreate(self) -> object: ... + def create(self, *args, **kwargs) -> Any: ... + def dump(self) -> None: ... + def parse(self, *args, **kwargs) -> Any: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def body(self) -> Any: ... + @property + def context(self) -> object: ... + @property + def operation(self) -> object: ... + +class NamedAttribute: + def __init__(self, *args, **kwargs) -> None: ... + @property + def attr(self) -> Attribute: ... + @property + def name(self) -> str: ... + +class NoneType(Type): + def __init__(self, cast_from_type: Type) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + +class OpAttributeMap: + def __init__(self, *args, **kwargs) -> None: ... + def __contains__(self, arg0: str) -> bool: ... + def __delitem__(self, arg0: str) -> None: ... + @overload + def __getitem__(self, arg0: str) -> Attribute: ... + @overload + def __getitem__(self, arg0: int) -> NamedAttribute: ... + def __len__(self) -> int: ... + def __setitem__(self, arg0: str, arg1: Attribute) -> None: ... 
+ +class OpOperandList: + def __init__(self, *args, **kwargs) -> None: ... + def __add__(self, arg0: OpOperandList) -> List[Value]: ... + @overload + def __getitem__(self, arg0: int) -> Value: ... + @overload + def __getitem__(self, arg0: slice) -> OpOperandList: ... + def __len__(self) -> int: ... + def __setitem__(self, arg0: int, arg1: Value) -> None: ... + +class OpResult(Value): + def __init__(self, value: Value) -> None: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def owner(self) -> object: ... + @property + def result_number(self) -> int: ... + +class OpResultList: + def __init__(self, *args, **kwargs) -> None: ... + def __add__(self, arg0: OpResultList) -> List[OpResult]: ... + @overload + def __getitem__(self, arg0: int) -> OpResult: ... + @overload + def __getitem__(self, arg0: slice) -> OpResultList: ... + def __len__(self) -> int: ... + @property + def types(self) -> List[Type]: ... + +class OpView(_OperationBase): + _ODS_OPERAND_SEGMENTS: ClassVar[None] = ... + _ODS_REGIONS: ClassVar[tuple] = ... + _ODS_RESULT_SEGMENTS: ClassVar[None] = ... + def __init__(self, operation: object) -> None: ... + @classmethod + def build_generic(self, *args, **kwargs) -> Any: ... + @property + def context(self) -> object: ... + @property + def operation(self) -> object: ... + +class Operation(_OperationBase): + def __init__(self, *args, **kwargs) -> None: ... + def _CAPICreate(self) -> object: ... + def create(self, *args, **kwargs) -> Any: ... + def erase(self) -> None: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def context(self) -> object: ... + @property + def name(self) -> str: ... + @property + def opview(self) -> object: ... + @property + def parent(self) -> object: ... + +class OperationIterator: + def __init__(self, *args, **kwargs) -> None: ... + def __iter__(self) -> OperationIterator: ... + def __next__(self) -> object: ... + +class OperationList: + def __init__(self, *args, **kwargs) -> None: ... 
+ def __getitem__(self, arg0: int) -> object: ... + def __iter__(self) -> OperationIterator: ... + def __len__(self) -> int: ... + +class RankedTensorType(ShapedType): + def __init__(self, cast_from_type: Type) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def encoding(self) -> Optional[Attribute]: ... + +class Region: + __hash__: ClassVar[None] = ... + def __init__(self, *args, **kwargs) -> None: ... + @overload + def __eq__(self, arg0: Region) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __iter__(self) -> Any: ... + @property + def blocks(self) -> Any: ... + @property + def owner(self) -> object: ... + +class RegionIterator: + def __init__(self, *args, **kwargs) -> None: ... + def __iter__(self) -> RegionIterator: ... + def __next__(self) -> Region: ... + +class RegionSequence: + def __init__(self, *args, **kwargs) -> None: ... + def __getitem__(self, arg0: int) -> Region: ... + def __len__(self) -> int: ... + +class ShapedType(Type): + def __init__(self, cast_from_type: Type) -> None: ... + def get_dim_size(self, dim: int) -> int: ... + def is_dynamic_dim(self, dim: int) -> bool: ... + def is_dynamic_size(self, *args, **kwargs) -> Any: ... + def is_dynamic_stride_or_offset(self, dim_size: int) -> bool: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def element_type(self) -> Type: ... + @property + def has_rank(self) -> bool: ... + @property + def has_static_shape(self) -> bool: ... + @property + def rank(self) -> int: ... + @property + def shape(self) -> List[int]: ... + +class StringAttr(Attribute): + def __init__(self, cast_from_attr: Attribute) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def get_typed(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def type(self) -> Type: ... + @property + def value(self) -> str: ... 
+ +class SymbolTable: + def __init__(self, arg0: _OperationBase) -> None: ... + def erase(self, operation: _OperationBase) -> None: ... + def insert(self, operation: _OperationBase) -> Attribute: ... + def __contains__(self, arg0: str) -> bool: ... + def __delitem__(self, arg0: str) -> None: ... + def __getitem__(self, arg0: str) -> object: ... + +class TupleType(Type): + def __init__(self, cast_from_type: Type) -> None: ... + def get_tuple(self, *args, **kwargs) -> Any: ... + def get_type(self, pos: int) -> Type: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def num_types(self) -> int: ... + +class Type: + def __init__(self, cast_from_type: Type) -> None: ... + def _CAPICreate(self) -> Type: ... + def dump(self) -> None: ... + def parse(self, *args, **kwargs) -> Any: ... + @overload + def __eq__(self, arg0: Type) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __hash__(self) -> int: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def context(self) -> object: ... + +class TypeAttr(Attribute): + def __init__(self, cast_from_attr: Attribute) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def type(self) -> Type: ... + @property + def value(self) -> Type: ... + +class UnitAttr(Attribute): + def __init__(self, cast_from_attr: Attribute) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def type(self) -> Type: ... + +class UnrankedMemRefType(ShapedType): + def __init__(self, cast_from_type: Type) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def memory_space(self) -> Attribute: ... + +class UnrankedTensorType(ShapedType): + def __init__(self, cast_from_type: Type) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... 
+ +class Value: + def __init__(self, *args, **kwargs) -> None: ... + def _CAPICreate(self) -> Value: ... + def dump(self) -> None: ... + @overload + def __eq__(self, arg0: Value) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __hash__(self) -> int: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def context(self) -> Any: ... + @property + def owner(self) -> object: ... + @property + def type(self) -> Type: ... + +class VectorType(ShapedType): + def __init__(self, cast_from_type: Type) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + +class _GlobalDebug: + flag: ClassVar[bool] = ... + def __init__(self, *args, **kwargs) -> None: ... + +class _OperationBase: + def __init__(self, *args, **kwargs) -> None: ... + def detach_from_parent(self) -> object: ... + def get_asm(self, binary: bool = ..., large_elements_limit: Optional[int] = ..., enable_debug_info: bool = ..., pretty_debug_info: bool = ..., print_generic_op_form: bool = ..., use_local_scope: bool = ...) -> object: ... + def move_after(self, other: _OperationBase) -> None: ... + def move_before(self, other: _OperationBase) -> None: ... + def print(self, file: object = ..., binary: bool = ..., large_elements_limit: Optional[int] = ..., enable_debug_info: bool = ..., pretty_debug_info: bool = ..., print_generic_op_form: bool = ..., use_local_scope: bool = ...) -> None: ... + def verify(self) -> bool: ... + @overload + def __eq__(self, arg0: _OperationBase) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __hash__(self) -> int: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def attributes(self) -> Any: ... + @property + def location(self) -> Location: ... + @property + def operands(self) -> Any: ... + @property + def regions(self) -> Any: ... + @property + def result(self) -> Any: ... + @property + def results(self) -> Any: ... 
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi new file mode 100644 index 000000000..7003bf0f0 --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi @@ -0,0 +1,24 @@ +# Originally imported via: +# stubgen {...} -m mlir._mlir_libs._mlir.passmanager +# Local modifications: +# * Relative imports for cross-module references. +# * Add __all__ + +from typing import Any + +from . import ir as _ir + +__all__ = [ + "PassManager", +] + +class PassManager: + def __init__(self, context: _ir.Context = ...) -> None: ... + def _CAPICreate(self) -> object: ... + def _testing_release(self) -> None: ... + def enable_ir_printing(self) -> None: ... + def enable_verifier(self, enable: bool) -> None: ... + def parse(self, *args, **kwargs) -> Any: ... + def run(self, module: _ir.Module) -> None: ... + @property + def _CAPIPtr(self) -> object: ... diff --git a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi new file mode 100644 index 000000000..50ff6c5b1 --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi @@ -0,0 +1,23 @@ +# Originally imported via: +# stubgen {...} -m mlir._mlir_libs._mlirExecutionEngine +# Local modifications: +# * Relative imports for cross-module references. +# * Add __all__ + +from typing import List + +from ._mlir import ir as _ir + +__all__ = [ + "ExecutionEngine", +] + +class ExecutionEngine: + def __init__(self, module: _ir.Module, opt_level: int = ..., shared_libs: List[str] = ...) -> None: ... + def _CAPICreate(self) -> object: ... + def _testing_release(self) -> None: ... + def dump_to_object_file(self, file_name: str) -> None: ... + def raw_lookup(self, func_name: str) -> int: ... + def raw_register_runtime(self, name: str, callback: object) -> None: ... + @property + def _CAPIPtr(self) -> object: ... 
From 81bcaf59d97c367f3da1bcb024b989eff57eda92 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sun, 28 Nov 2021 20:30:18 -0800 Subject: [PATCH 176/915] [mlir][python] Implement more SymbolTable methods. * set_symbol_name, get_symbol_name, set_visibility, get_visibility, replace_all_symbol_uses, walk_symbol_tables * In integrations I've been doing, I've been reaching for all of these to do both general IR manipulation and module merging. * I don't love the replace_all_symbol_uses underlying APIs since they necessitate SYMBOL_COUNT walks and have various sharp edges. I'm hoping that whatever emerges eventually for this can still retain this simple API as a one-shot. Differential Revision: https://reviews.llvm.org/D114687 --- mlir/include/mlir-c/IR.h | 20 ++ mlir/lib/Bindings/Python/IRCore.cpp | 130 ++++++++++++- mlir/lib/Bindings/Python/IRModule.h | 19 ++ mlir/lib/CAPI/IR/IR.cpp | 26 +++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 230 ++++++++++++----------- 5 files changed, 313 insertions(+), 112 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 6c1a92cea..1d884e634 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -754,6 +754,9 @@ MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID); /// symbol tables. MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(); +/// Returns the name of the attribute used to store symbol visibility. +MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetVisibilityAttributeName(); + /// Creates a symbol table for the given operation. If the operation does not /// have the SymbolTable trait, returns a null symbol table. 
MLIR_CAPI_EXPORTED MlirSymbolTable @@ -787,6 +790,23 @@ mlirSymbolTableInsert(MlirSymbolTable symbolTable, MlirOperation operation); MLIR_CAPI_EXPORTED void mlirSymbolTableErase(MlirSymbolTable symbolTable, MlirOperation operation); +/// Attempt to replace all uses that are nested within the given operation +/// of the given symbol 'oldSymbol' with the provided 'newSymbol'. This does +/// not traverse into nested symbol tables. Will fail atomically if there are +/// any unknown operations that may be potential symbol tables. +MLIR_CAPI_EXPORTED MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses( + MlirStringRef oldSymbol, MlirStringRef newSymbol, MlirOperation from); + +/// Walks all symbol table operations nested within, and including, `op`. For +/// each symbol table operation, the provided callback is invoked with the op +/// and a boolean signifying if the symbols within that symbol table can be +/// treated as if all uses within the IR are visible to the caller. +/// `allSymUsesVisible` identifies whether all of the symbol uses of symbols +/// within `op` are visible. +MLIR_CAPI_EXPORTED void mlirSymbolTableWalkSymbolTables( + MlirOperation from, bool allSymUsesVisible, + void (*callback)(MlirOperation, bool, void *userData), void *userData); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 8a110fcc4..0d3491433 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1596,6 +1596,112 @@ PyAttribute PySymbolTable::insert(PyOperationBase &symbol) { mlirSymbolTableInsert(symbolTable, symbol.getOperation().get())); } +PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { + // Op must already be a symbol. 
+ PyOperation &operation = symbol.getOperation(); + operation.checkValid(); + MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); + MlirAttribute existingNameAttr = + mlirOperationGetAttributeByName(operation.get(), attrName); + if (mlirAttributeIsNull(existingNameAttr)) + throw py::value_error("Expected operation to have a symbol name."); + return PyAttribute(symbol.getOperation().getContext(), existingNameAttr); +} + +void PySymbolTable::setSymbolName(PyOperationBase &symbol, + const std::string &name) { + // Op must already be a symbol. + PyOperation &operation = symbol.getOperation(); + operation.checkValid(); + MlirStringRef attrName = mlirSymbolTableGetSymbolAttributeName(); + MlirAttribute existingNameAttr = + mlirOperationGetAttributeByName(operation.get(), attrName); + if (mlirAttributeIsNull(existingNameAttr)) + throw py::value_error("Expected operation to have a symbol name."); + MlirAttribute newNameAttr = + mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name)); + mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); +} + +PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { + PyOperation &operation = symbol.getOperation(); + operation.checkValid(); + MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); + MlirAttribute existingVisAttr = + mlirOperationGetAttributeByName(operation.get(), attrName); + if (mlirAttributeIsNull(existingVisAttr)) + throw py::value_error("Expected operation to have a symbol visibility."); + return PyAttribute(symbol.getOperation().getContext(), existingVisAttr); +} + +void PySymbolTable::setVisibility(PyOperationBase &symbol, + const std::string &visibility) { + if (visibility != "public" && visibility != "private" && + visibility != "nested") + throw py::value_error( + "Expected visibility to be 'public', 'private' or 'nested'"); + PyOperation &operation = symbol.getOperation(); + operation.checkValid(); + MlirStringRef attrName = 
mlirSymbolTableGetVisibilityAttributeName(); + MlirAttribute existingVisAttr = + mlirOperationGetAttributeByName(operation.get(), attrName); + if (mlirAttributeIsNull(existingVisAttr)) + throw py::value_error("Expected operation to have a symbol visibility."); + MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(), + toMlirStringRef(visibility)); + mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr); +} + +void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol, + const std::string &newSymbol, + PyOperationBase &from) { + PyOperation &fromOperation = from.getOperation(); + fromOperation.checkValid(); + if (mlirLogicalResultIsFailure(mlirSymbolTableReplaceAllSymbolUses( + toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol), + from.getOperation()))) + + throw py::value_error("Symbol rename failed"); +} + +void PySymbolTable::walkSymbolTables(PyOperationBase &from, + bool allSymUsesVisible, + py::object callback) { + PyOperation &fromOperation = from.getOperation(); + fromOperation.checkValid(); + struct UserData { + PyMlirContextRef context; + py::object callback; + bool gotException; + std::string exceptionWhat; + py::object exceptionType; + }; + UserData userData{ + fromOperation.getContext(), std::move(callback), false, {}, {}}; + mlirSymbolTableWalkSymbolTables( + fromOperation.get(), allSymUsesVisible, + [](MlirOperation foundOp, bool isVisible, void *calleeUserDataVoid) { + UserData *calleeUserData = static_cast<UserData *>(calleeUserDataVoid); + auto pyFoundOp = + PyOperation::forOperation(calleeUserData->context, foundOp); + if (calleeUserData->gotException) + return; + try { + calleeUserData->callback(pyFoundOp.getObject(), isVisible); + } catch (py::error_already_set &e) { + calleeUserData->gotException = true; + calleeUserData->exceptionWhat = e.what(); + calleeUserData->exceptionType = e.type(); + } + }, + static_cast<void *>(&userData)); + if (userData.gotException) { + std::string message("Exception raised in
callback: "); + message.append(userData.exceptionWhat); + throw std::runtime_error(std::move(message)); + } +} + namespace { /// CRTP base class for Python MLIR values that subclass Value and should be /// castable from it. The value hierarchy is one level deep and is not supposed @@ -2773,10 +2879,26 @@ void mlir::python::populateIRCore(py::module &m) { .def("insert", &PySymbolTable::insert, py::arg("operation")) .def("erase", &PySymbolTable::erase, py::arg("operation")) .def("__delitem__", &PySymbolTable::dunderDel) - .def("__contains__", [](PySymbolTable &table, const std::string &name) { - return !mlirOperationIsNull(mlirSymbolTableLookup( - table, mlirStringRefCreate(name.data(), name.length()))); - }); + .def("__contains__", + [](PySymbolTable &table, const std::string &name) { + return !mlirOperationIsNull(mlirSymbolTableLookup( + table, mlirStringRefCreate(name.data(), name.length()))); + }) + // Static helpers. + .def_static("set_symbol_name", &PySymbolTable::setSymbolName, + py::arg("symbol"), py::arg("name")) + .def_static("get_symbol_name", &PySymbolTable::getSymbolName, + py::arg("symbol")) + .def_static("get_visibility", &PySymbolTable::getVisibility, + py::arg("symbol")) + .def_static("set_visibility", &PySymbolTable::setVisibility, + py::arg("symbol"), py::arg("visibility")) + .def_static("replace_all_symbol_uses", + &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"), + py::arg("new_symbol"), py::arg("from_op")) + .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, + py::arg("from_op"), py::arg("all_sym_uses_visible"), + py::arg("callback")); // Container bindings. PyBlockArgumentList::bind(m); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index f0d0cc654..d5e8eb4ae 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -910,6 +910,25 @@ class PySymbolTable { /// the symbol trait. 
PyAttribute insert(PyOperationBase &symbol); + /// Gets and sets the name of a symbol op. + static PyAttribute getSymbolName(PyOperationBase &symbol); + static void setSymbolName(PyOperationBase &symbol, const std::string &name); + + /// Gets and sets the visibility of a symbol op. + static PyAttribute getVisibility(PyOperationBase &symbol); + static void setVisibility(PyOperationBase &symbol, + const std::string &visibility); + + /// Replaces all symbol uses within an operation. See the API + /// mlirSymbolTableReplaceAllSymbolUses for all caveats. + static void replaceAllSymbolUses(const std::string &oldSymbol, + const std::string &newSymbol, + PyOperationBase &from); + + /// Walks all symbol tables under and including 'from'. + static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, + pybind11::object callback); + /// Casts the bindings class into the C API structure. operator MlirSymbolTable() { return symbolTable; } diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 35a059275..424bbae17 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -786,6 +786,10 @@ MlirStringRef mlirSymbolTableGetSymbolAttributeName() { return wrap(SymbolTable::getSymbolAttrName()); } +MlirStringRef mlirSymbolTableGetVisibilityAttributeName() { + return wrap(SymbolTable::getVisibilityAttrName()); +} + MlirSymbolTable mlirSymbolTableCreate(MlirOperation operation) { if (!unwrap(operation)->hasTrait<OpTrait::SymbolTable>()) return wrap(static_cast<SymbolTable *>(nullptr)); @@ -810,3 +814,25 @@ void mlirSymbolTableErase(MlirSymbolTable symbolTable, MlirOperation operation) { unwrap(symbolTable)->erase(unwrap(operation)); } + +MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, + MlirStringRef newSymbol, + MlirOperation from) { + auto cppFrom = unwrap(from); + auto *context = cppFrom->getContext(); + auto oldSymbolAttr = StringAttr::get(unwrap(oldSymbol), context); + auto newSymbolAttr = StringAttr::get(unwrap(newSymbol), context); + return
wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr, + unwrap(from))); +} + +void mlirSymbolTableWalkSymbolTables(MlirOperation from, bool allSymUsesVisible, + void (*callback)(MlirOperation, bool, + void *userData), + void *userData) { + SymbolTable::walkSymbolTables(unwrap(from), allSymUsesVisible, + [&](Operation *foundOpCpp, bool isVisible) { + callback(wrap(foundOpCpp), isVisible, + userData); + }); +} diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 47ebeb291..3c7653feb 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -7,7 +7,7 @@ # * Local edits to signatures and types that MyPy did not auto detect (or # detected incorrectly). -from typing import Any, ClassVar, List, Optional +from typing import Any, Callable, ClassVar, List, Optional from typing import overload @@ -90,38 +90,34 @@ __all__ = [ "_OperationBase", ] - -class AffineAddExpr(AffineBinaryExpr): - def __init__(self, expr: AffineExpr) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... - -class AffineBinaryExpr(AffineExpr): - def __init__(self, expr: AffineExpr) -> None: ... - def isinstance(self, *args, **kwargs) -> Any: ... +# Base classes: declared first to simplify declarations below. +class _OperationBase: + def __init__(self, *args, **kwargs) -> None: ... + def detach_from_parent(self) -> object: ... + def get_asm(self, binary: bool = ..., large_elements_limit: Optional[int] = ..., enable_debug_info: bool = ..., pretty_debug_info: bool = ..., print_generic_op_form: bool = ..., use_local_scope: bool = ...) -> object: ... + def move_after(self, other: _OperationBase) -> None: ... + def move_before(self, other: _OperationBase) -> None: ... 
+ def print(self, file: object = ..., binary: bool = ..., large_elements_limit: Optional[int] = ..., enable_debug_info: bool = ..., pretty_debug_info: bool = ..., print_generic_op_form: bool = ..., use_local_scope: bool = ...) -> None: ... + def verify(self) -> bool: ... + @overload + def __eq__(self, arg0: _OperationBase) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __hash__(self) -> int: ... @property - def lhs(self) -> AffineExpr: ... + def _CAPIPtr(self) -> object: ... @property - def rhs(self) -> AffineExpr: ... - -class AffineCeilDivExpr(AffineBinaryExpr): - def __init__(self, expr: AffineExpr) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... - -class AffineConstantExpr(AffineExpr): - def __init__(self, expr: AffineExpr) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + def attributes(self) -> Any: ... @property - def value(self) -> int: ... - -class AffineDimExpr(AffineExpr): - def __init__(self, expr: AffineExpr) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + def location(self) -> Location: ... @property - def position(self) -> int: ... + def operands(self) -> Any: ... + @property + def regions(self) -> Any: ... + @property + def result(self) -> Any: ... + @property + def results(self) -> Any: ... class AffineExpr: def __init__(self, *args, **kwargs) -> None: ... @@ -154,6 +150,91 @@ class AffineExpr: @property def context(self) -> object: ... +class Attribute: + def __init__(self, cast_from_type: Attribute) -> None: ... + def _CAPICreate(self) -> Attribute: ... + def dump(self) -> None: ... + def get_named(self, *args, **kwargs) -> Any: ... + def parse(self, *args, **kwargs) -> Any: ... + @overload + def __eq__(self, arg0: Attribute) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __hash__(self) -> int: ... 
+ @property + def _CAPIPtr(self) -> object: ... + @property + def context(self) -> object: ... + @property + def type(self) -> Any: ... + +class Type: + def __init__(self, cast_from_type: Type) -> None: ... + def _CAPICreate(self) -> Type: ... + def dump(self) -> None: ... + def parse(self, *args, **kwargs) -> Any: ... + @overload + def __eq__(self, arg0: Type) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __hash__(self) -> int: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def context(self) -> object: ... + +class Value: + def __init__(self, *args, **kwargs) -> None: ... + def _CAPICreate(self) -> Value: ... + def dump(self) -> None: ... + @overload + def __eq__(self, arg0: Value) -> bool: ... + @overload + def __eq__(self, arg0: object) -> bool: ... + def __hash__(self) -> int: ... + @property + def _CAPIPtr(self) -> object: ... + @property + def context(self) -> Any: ... + @property + def owner(self) -> object: ... + @property + def type(self) -> Type: ... + + +# Classes with no particular order sensitivity in alpha order. +class AffineAddExpr(AffineBinaryExpr): + def __init__(self, expr: AffineExpr) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + +class AffineBinaryExpr(AffineExpr): + def __init__(self, expr: AffineExpr) -> None: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def lhs(self) -> AffineExpr: ... + @property + def rhs(self) -> AffineExpr: ... + +class AffineCeilDivExpr(AffineBinaryExpr): + def __init__(self, expr: AffineExpr) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + +class AffineConstantExpr(AffineExpr): + def __init__(self, expr: AffineExpr) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def value(self) -> int: ... 
+ +class AffineDimExpr(AffineExpr): + def __init__(self, expr: AffineExpr) -> None: ... + def get(self, *args, **kwargs) -> Any: ... + def isinstance(self, *args, **kwargs) -> Any: ... + @property + def position(self) -> int: ... + class AffineExprList: def __init__(self, *args, **kwargs) -> None: ... def __add__(self, arg0: AffineExprList) -> List[AffineExpr]: ... @@ -245,24 +326,6 @@ class ArrayAttributeIterator: def __iter__(self) -> ArrayAttributeIterator: ... def __next__(self) -> Attribute: ... -class Attribute: - def __init__(self, cast_from_type: Attribute) -> None: ... - def _CAPICreate(self) -> Attribute: ... - def dump(self) -> None: ... - def get_named(self, *args, **kwargs) -> Any: ... - def parse(self, *args, **kwargs) -> Any: ... - @overload - def __eq__(self, arg0: Attribute) -> bool: ... - @overload - def __eq__(self, arg0: object) -> bool: ... - def __hash__(self) -> int: ... - @property - def _CAPIPtr(self) -> object: ... - @property - def context(self) -> object: ... - @property - def type(self) -> Any: ... - class BF16Type(Type): def __init__(self, cast_from_type: Type) -> None: ... def get(self, *args, **kwargs) -> Any: ... @@ -751,7 +814,19 @@ class StringAttr(Attribute): class SymbolTable: def __init__(self, arg0: _OperationBase) -> None: ... def erase(self, operation: _OperationBase) -> None: ... + @staticmethod + def get_symbol_name(symbol: _OperationBase) -> Attribute: ... + @staticmethod + def get_visibility(symbol: _OperationBase) -> Attribute: ... def insert(self, operation: _OperationBase) -> Attribute: ... + @staticmethod + def replace_all_symbol_uses(old_symbol: str, new_symbol: str, from_op: _OperationBase) -> None: ... + @staticmethod + def set_symbol_name(symbol: _OperationBase, name: str) -> None: ... + @staticmethod + def set_visibility(symbol: _OperationBase, visibility: str) -> None: ... 
+ @staticmethod + def walk_symbol_tables(from_op: _OperationBase, all_sym_uses_visible: bool, callback: Callable[[_OperationBase, bool], None]) -> None: ... def __contains__(self, arg0: str) -> bool: ... def __delitem__(self, arg0: str) -> None: ... def __getitem__(self, arg0: str) -> object: ... @@ -764,21 +839,6 @@ class TupleType(Type): @property def num_types(self) -> int: ... -class Type: - def __init__(self, cast_from_type: Type) -> None: ... - def _CAPICreate(self) -> Type: ... - def dump(self) -> None: ... - def parse(self, *args, **kwargs) -> Any: ... - @overload - def __eq__(self, arg0: Type) -> bool: ... - @overload - def __eq__(self, arg0: object) -> bool: ... - def __hash__(self) -> int: ... - @property - def _CAPIPtr(self) -> object: ... - @property - def context(self) -> object: ... - class TypeAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... def get(self, *args, **kwargs) -> Any: ... def isinstance(self, *args, **kwargs) -> Any: ... @property def type(self) -> Type: ... @@ -807,24 +867,6 @@ class UnrankedTensorType(ShapedType): def get(self, *args, **kwargs) -> Any: ... def isinstance(self, *args, **kwargs) -> Any: ... -class Value: - def __init__(self, *args, **kwargs) -> None: ... - def _CAPICreate(self) -> Value: ... - def dump(self) -> None: ... - @overload - def __eq__(self, arg0: Value) -> bool: ... - @overload - def __eq__(self, arg0: object) -> bool: ... - def __hash__(self) -> int: ... - @property - def _CAPIPtr(self) -> object: ... - @property - def context(self) -> Any: ... - @property - def owner(self) -> object: ... - @property - def type(self) -> Type: ... - class VectorType(ShapedType): def __init__(self, cast_from_type: Type) -> None: ... def get(self, *args, **kwargs) -> Any: ... @@ -833,31 +875,3 @@ class _GlobalDebug: flag: ClassVar[bool] = ... def __init__(self, *args, **kwargs) -> None: ... - -class _OperationBase: - def __init__(self, *args, **kwargs) -> None: ... - def detach_from_parent(self) -> object: ...
- def get_asm(self, binary: bool = ..., large_elements_limit: Optional[int] = ..., enable_debug_info: bool = ..., pretty_debug_info: bool = ..., print_generic_op_form: bool = ..., use_local_scope: bool = ...) -> object: ... - def move_after(self, other: _OperationBase) -> None: ... - def move_before(self, other: _OperationBase) -> None: ... - def print(self, file: object = ..., binary: bool = ..., large_elements_limit: Optional[int] = ..., enable_debug_info: bool = ..., pretty_debug_info: bool = ..., print_generic_op_form: bool = ..., use_local_scope: bool = ...) -> None: ... - def verify(self) -> bool: ... - @overload - def __eq__(self, arg0: _OperationBase) -> bool: ... - @overload - def __eq__(self, arg0: object) -> bool: ... - def __hash__(self) -> int: ... - @property - def _CAPIPtr(self) -> object: ... - @property - def attributes(self) -> Any: ... - @property - def location(self) -> Location: ... - @property - def operands(self) -> Any: ... - @property - def regions(self) -> Any: ... - @property - def result(self) -> Any: ... - @property - def results(self) -> Any: ... From b5a553552978cec567e3669d744036cd7baeb20b Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 29 Nov 2021 21:39:03 -0800 Subject: [PATCH 177/915] [mlir][python] Audit and fix a lot of the Python pyi stubs. * Classes that are still todo are marked with "# TODO: Auto-generated. Audit and fix." * Those without this note have been cross-checked with C++ sources and most have been spot checked by hovering in VsCode. 
Differential Revision: https://reviews.llvm.org/D114767 --- .../python/mlir/_mlir_libs/_mlir/__init__.pyi | 11 +- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 317 ++++++++++-------- .../mlir/_mlir_libs/_mlir/passmanager.pyi | 7 +- .../mlir/_mlir_libs/_mlirExecutionEngine.pyi | 4 +- 4 files changed, 194 insertions(+), 145 deletions(-) diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi index d4aab6806..c8734cfde 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi @@ -1,13 +1,12 @@ from typing import List -globals: _Globals +globals: "_Globals" class _Globals: dialect_search_modules: List[str] - def __init__(self, *args, **kwargs) -> None: ... - def _register_dialect_impl(self, dialect_namespace: str, dialect_class: object) -> None: ... - def _register_operation_impl(self, operation_name: str, operation_class: object, raw_opview_class: object) -> None: ... + def _register_dialect_impl(self, dialect_namespace: str, dialect_class: type) -> None: ... + def _register_operation_impl(self, operation_name: str, operation_class: type, raw_opview_class: type) -> None: ... def append_dialect_search_prefix(self, module_name: str) -> None: ... -def register_dialect(dialect_class: object) -> object: ... -def register_operation(dialect_class: object) -> object: ... +def register_dialect(dialect_class: type) -> object: ... +def register_operation(dialect_class: type) -> object: ... diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 3c7653feb..e1a84ddfd 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -7,7 +7,7 @@ # * Local edits to signatures and types that MyPy did not auto detect (or # detected incorrectly). 
-from typing import Any, Callable, ClassVar, List, Optional +from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence from typing import overload @@ -92,33 +92,33 @@ __all__ = [ # Base classes: declared first to simplify declarations below. class _OperationBase: - def __init__(self, *args, **kwargs) -> None: ... - def detach_from_parent(self) -> object: ... - def get_asm(self, binary: bool = ..., large_elements_limit: Optional[int] = ..., enable_debug_info: bool = ..., pretty_debug_info: bool = ..., print_generic_op_form: bool = ..., use_local_scope: bool = ...) -> object: ... - def move_after(self, other: _OperationBase) -> None: ... - def move_before(self, other: _OperationBase) -> None: ... - def print(self, file: object = ..., binary: bool = ..., large_elements_limit: Optional[int] = ..., enable_debug_info: bool = ..., pretty_debug_info: bool = ..., print_generic_op_form: bool = ..., use_local_scope: bool = ...) -> None: ... + def detach_from_parent(self) -> "OpView": ... + def get_asm(self, binary: bool = False, large_elements_limit: Optional[int] = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, use_local_scope: bool = False, assume_verified: bool = False) -> object: ... + def move_after(self, other: "_OperationBase") -> None: ... + def move_before(self, other: "_OperationBase") -> None: ... + def print(self, file: Optional[Any] = None, binary: bool = False, large_elements_limit: Optional[int] = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, use_local_scope: bool = False, assume_verified: bool = False) -> None: ... def verify(self) -> bool: ... @overload - def __eq__(self, arg0: _OperationBase) -> bool: ... + def __eq__(self, arg0: "_OperationBase") -> bool: ... @overload def __eq__(self, arg0: object) -> bool: ... def __hash__(self) -> int: ... @property def _CAPIPtr(self) -> object: ... 
@property - def attributes(self) -> Any: ... + def attributes(self) -> "OpAttributeMap": ... @property - def location(self) -> Location: ... + def location(self) -> "Location": ... @property - def operands(self) -> Any: ... + def operands(self) -> "OpOperandList": ... @property - def regions(self) -> Any: ... + def regions(self) -> "RegionSequence": ... @property - def result(self) -> Any: ... + def result(self) -> "OpResult": ... @property - def results(self) -> Any: ... + def results(self) -> "OpResultList": ... +# TODO: Auto-generated. Audit and fix. class AffineExpr: def __init__(self, *args, **kwargs) -> None: ... def _CAPICreate(self) -> AffineExpr: ... @@ -151,63 +151,66 @@ class AffineExpr: def context(self) -> object: ... class Attribute: - def __init__(self, cast_from_type: Attribute) -> None: ... - def _CAPICreate(self) -> Attribute: ... + def __init__(self, cast_from_type: "Attribute") -> None: ... + def _CAPICreate(self) -> "Attribute": ... def dump(self) -> None: ... def get_named(self, *args, **kwargs) -> Any: ... - def parse(self, *args, **kwargs) -> Any: ... + @staticmethod + def parse(asm: str, context: Optional["Context"] = None) -> Any: ... @overload - def __eq__(self, arg0: Attribute) -> bool: ... + def __eq__(self, arg0: "Attribute") -> bool: ... @overload def __eq__(self, arg0: object) -> bool: ... def __hash__(self) -> int: ... @property def _CAPIPtr(self) -> object: ... @property - def context(self) -> object: ... + def context(self) -> "Context": ... @property - def type(self) -> Any: ... + def type(self) -> "Type": ... class Type: - def __init__(self, cast_from_type: Type) -> None: ... + def __init__(self, cast_from_type: "Type") -> None: ... def _CAPICreate(self) -> Type: ... def dump(self) -> None: ... - def parse(self, *args, **kwargs) -> Any: ... + @staticmethod + def parse(asm: str, context: Optional["Context"] = None) -> "Type": ... @overload - def __eq__(self, arg0: Type) -> bool: ... + def __eq__(self, arg0: "Type") -> bool: ...
@overload def __eq__(self, arg0: object) -> bool: ... def __hash__(self) -> int: ... @property def _CAPIPtr(self) -> object: ... @property - def context(self) -> object: ... + def context(self) -> "Context": ... class Value: - def __init__(self, *args, **kwargs) -> None: ... - def _CAPICreate(self) -> Value: ... + def _CAPICreate(self) -> "Value": ... def dump(self) -> None: ... @overload - def __eq__(self, arg0: Value) -> bool: ... + def __eq__(self, arg0: "Value") -> bool: ... @overload def __eq__(self, arg0: object) -> bool: ... def __hash__(self) -> int: ... @property def _CAPIPtr(self) -> object: ... @property - def context(self) -> Any: ... + def context(self) -> "Context": ... @property - def owner(self) -> object: ... + def owner(self) -> "_OperationBase": ... @property - def type(self) -> Type: ... + def type(self) -> "Type": ... # Classes with no particular order sensitivity in alpha order. +# TODO: Auto-generated. Audit and fix. class AffineAddExpr(AffineBinaryExpr): def __init__(self, expr: AffineExpr) -> None: ... def get(self, *args, **kwargs) -> Any: ... def isinstance(self, *args, **kwargs) -> Any: ... +# TODO: Auto-generated. Audit and fix. class AffineBinaryExpr(AffineExpr): def __init__(self, expr: AffineExpr) -> None: ... def isinstance(self, *args, **kwargs) -> Any: ... @@ -216,11 +219,13 @@ class AffineBinaryExpr(AffineExpr): @property def rhs(self) -> AffineExpr: ... +# TODO: Auto-generated. Audit and fix. class AffineCeilDivExpr(AffineBinaryExpr): def __init__(self, expr: AffineExpr) -> None: ... def get(self, *args, **kwargs) -> Any: ... def isinstance(self, *args, **kwargs) -> Any: ... +# TODO: Auto-generated. Audit and fix. class AffineConstantExpr(AffineExpr): def __init__(self, expr: AffineExpr) -> None: ... def get(self, *args, **kwargs) -> Any: ... @@ -228,6 +233,7 @@ class AffineConstantExpr(AffineExpr): @property def value(self) -> int: ... +# TODO: Auto-generated. Audit and fix. 
class AffineDimExpr(AffineExpr): def __init__(self, expr: AffineExpr) -> None: ... def get(self, *args, **kwargs) -> Any: ... @@ -235,6 +241,7 @@ class AffineDimExpr(AffineExpr): @property def position(self) -> int: ... +# TODO: Auto-generated. Audit and fix. class AffineExprList: def __init__(self, *args, **kwargs) -> None: ... def __add__(self, arg0: AffineExprList) -> List[AffineExpr]: ... @@ -244,11 +251,13 @@ class AffineExprList: def __getitem__(self, arg0: slice) -> AffineExprList: ... def __len__(self) -> int: ... +# TODO: Auto-generated. Audit and fix. class AffineFloorDivExpr(AffineBinaryExpr): def __init__(self, expr: AffineExpr) -> None: ... def get(self, *args, **kwargs) -> Any: ... def isinstance(self, *args, **kwargs) -> Any: ... +# TODO: Auto-generated. Audit and fix. class AffineMap: def __init__(self, *args, **kwargs) -> None: ... def _CAPICreate(self) -> AffineMap: ... @@ -286,6 +295,7 @@ class AffineMap: @property def results(self) -> Any: ... +# TODO: Auto-generated. Audit and fix. class AffineMapAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... def get(self, *args, **kwargs) -> Any: ... @@ -293,16 +303,19 @@ class AffineMapAttr(Attribute): @property def type(self) -> Type: ... +# TODO: Auto-generated. Audit and fix. class AffineModExpr(AffineBinaryExpr): def __init__(self, expr: AffineExpr) -> None: ... def get(self, *args, **kwargs) -> Any: ... def isinstance(self, *args, **kwargs) -> Any: ... +# TODO: Auto-generated. Audit and fix. class AffineMulExpr(AffineBinaryExpr): def __init__(self, expr: AffineExpr) -> None: ... def get(self, *args, **kwargs) -> Any: ... def isinstance(self, *args, **kwargs) -> Any: ... +# TODO: Auto-generated. Audit and fix. class AffineSymbolExpr(AffineExpr): def __init__(self, expr: AffineExpr) -> None: ... def get(self, *args, **kwargs) -> Any: ... @@ -310,6 +323,7 @@ class AffineSymbolExpr(AffineExpr): @property def position(self) -> int: ... +# TODO: Auto-generated. Audit and fix. 
class ArrayAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... def get(self, *args, **kwargs) -> Any: ... @@ -321,11 +335,13 @@ class ArrayAttr(Attribute): @property def type(self) -> Type: ... +# TODO: Auto-generated. Audit and fix. class ArrayAttributeIterator: def __init__(self, *args, **kwargs) -> None: ... def __iter__(self) -> ArrayAttributeIterator: ... def __next__(self) -> Attribute: ... +# TODO: Auto-generated. Audit and fix. class BF16Type(Type): def __init__(self, cast_from_type: Type) -> None: ... def get(self, *args, **kwargs) -> Any: ... @@ -333,57 +349,55 @@ class BF16Type(Type): class Block: __hash__: ClassVar[None] = ... - def __init__(self, *args, **kwargs) -> None: ... def append(self, operation: _OperationBase) -> None: ... - def create_after(self, *args) -> Block: ... - def create_at_start(self, *args, **kwargs) -> Any: ... - def create_before(self, *args) -> Block: ... + def create_after(self, *args: "Type") -> "Block": ... + @staticmethod + def create_at_start(parent: "Region", arg_types: List["Type"]) -> "Block": ... + def create_before(self, *args: "Type") -> "Block": ... @overload - def __eq__(self, arg0: Block) -> bool: ... + def __eq__(self, arg0: "Block") -> bool: ... @overload def __eq__(self, arg0: object) -> bool: ... def __iter__(self) -> Any: ... @property - def arguments(self) -> Any: ... + def arguments(self) -> "BlockArgumentList": ... @property - def operations(self) -> Any: ... + def operations(self) -> "OperationList": ... @property - def owner(self) -> object: ... + def owner(self) -> "OpView": ... @property - def region(self) -> Region: ... + def region(self) -> "Region": ... class BlockArgument(Value): - def __init__(self, value: Value) -> None: ... def isinstance(self, *args, **kwargs) -> Any: ... - def set_type(self, type: Type) -> None: ... + def set_type(self, type: "Type") -> None: ... @property def arg_number(self) -> int: ... @property - def owner(self) -> Block: ... 
+ def owner(self) -> "Block": ... class BlockArgumentList: - def __init__(self, *args, **kwargs) -> None: ... - def __add__(self, arg0: BlockArgumentList) -> List[BlockArgument]: ... + def __add__(self, arg0: "BlockArgumentList") -> List["BlockArgument"]: ... @overload - def __getitem__(self, arg0: int) -> BlockArgument: ... + def __getitem__(self, arg0: int) -> "BlockArgument": ... @overload - def __getitem__(self, arg0: slice) -> BlockArgumentList: ... + def __getitem__(self, arg0: slice) -> "BlockArgumentList": ... def __len__(self) -> int: ... @property - def types(self) -> List[Type]: ... + def types(self) -> List["Type"]: ... class BlockIterator: def __init__(self, *args, **kwargs) -> None: ... - def __iter__(self) -> BlockIterator: ... - def __next__(self) -> Block: ... + def __iter__(self) -> "BlockIterator": ... + def __next__(self) -> "Block": ... class BlockList: - def __init__(self, *args, **kwargs) -> None: ... - def append(self, *args) -> Block: ... - def __getitem__(self, arg0: int) -> Block: ... - def __iter__(self) -> BlockIterator: ... + def append(self, *args) -> "Block": ... + def __getitem__(self, arg0: int) -> "Block": ... + def __iter__(self) -> "BlockIterator": ... def __len__(self) -> int: ... +# TODO: Auto-generated. Audit and fix. class BoolAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... def get(self, *args, **kwargs) -> Any: ... @@ -393,6 +407,7 @@ class BoolAttr(Attribute): @property def value(self) -> bool: ... +# TODO: Auto-generated. Audit and fix. class ComplexType(Type): def __init__(self, cast_from_type: Type) -> None: ... def get(self, *args, **kwargs) -> Any: ... @@ -401,26 +416,28 @@ class ComplexType(Type): def element_type(self) -> Type: ... class Context: - current: ClassVar[Context] = ... # read-only + current: ClassVar["Context"] = ... # read-only allow_unregistered_dialects: bool def __init__(self) -> None: ... def _CAPICreate(self) -> object: ... - def _get_context_again(self) -> object: ... 
- def _get_live_count(self, *args, **kwargs) -> Any: ... + def _get_context_again(self) -> "Context": ... + @staticmethod + def _get_live_count() -> int: ... def _get_live_module_count(self) -> int: ... def _get_live_operation_count(self) -> int: ... def enable_multithreading(self, enable: bool) -> None: ... - def get_dialect_descriptor(self, *args, **kwargs) -> Any: ... + def get_dialect_descriptor(self, dialect_name: str) -> "DialectDescriptor": ... def is_registered_operation(self, operation_name: str) -> bool: ... - def __enter__(self) -> object: ... + def __enter__(self) -> "Context": ... def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... @property def _CAPIPtr(self) -> object: ... @property - def d(self) -> Any: ... + def d(self) -> "Dialects": ... @property - def dialects(self) -> Any: ... + def dialects(self) -> "Dialects": ... +# TODO: Auto-generated. Audit and fix. class DenseElementsAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... def get(self, *args, **kwargs) -> Any: ... @@ -432,6 +449,7 @@ class DenseElementsAttr(Attribute): @property def type(self) -> Type: ... +# TODO: Auto-generated. Audit and fix. class DenseFPElementsAttr(DenseElementsAttr): def __init__(self, cast_from_attr: Attribute) -> None: ... def isinstance(self, *args, **kwargs) -> Any: ... @@ -439,6 +457,7 @@ class DenseFPElementsAttr(DenseElementsAttr): @property def type(self) -> Type: ... +# TODO: Auto-generated. Audit and fix. class DenseIntElementsAttr(DenseElementsAttr): def __init__(self, cast_from_attr: Attribute) -> None: ... def isinstance(self, *args, **kwargs) -> Any: ... @@ -447,20 +466,20 @@ class DenseIntElementsAttr(DenseElementsAttr): def type(self) -> Type: ... class Dialect: - def __init__(self, descriptor: object) -> None: ... + def __init__(self, descriptor: "DialectDescriptor") -> None: ... @property - def descriptor(self) -> object: ... + def descriptor(self) -> "DialectDescriptor": ...
class DialectDescriptor: - def __init__(self, *args, **kwargs) -> None: ... @property def namespace(self) -> str: ... class Dialects: def __init__(self, *args, **kwargs) -> None: ... - def __getattr__(self, arg0: str) -> object: ... - def __getitem__(self, arg0: str) -> object: ... + def __getattr__(self, arg0: str) -> "Dialect": ... + def __getitem__(self, arg0: str) -> "Dialect": ... +# TODO: Auto-generated. Audit and fix. class DictAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... def get(self, *args, **kwargs) -> Any: ... @@ -474,21 +493,25 @@ class DictAttr(Attribute): @property def type(self) -> Type: ... +# TODO: Auto-generated. Audit and fix. class F16Type(Type): def __init__(self, cast_from_type: Type) -> None: ... def get(self, *args, **kwargs) -> Any: ... def isinstance(self, *args, **kwargs) -> Any: ... +# TODO: Auto-generated. Audit and fix. class F32Type(Type): def __init__(self, cast_from_type: Type) -> None: ... def get(self, *args, **kwargs) -> Any: ... def isinstance(self, *args, **kwargs) -> Any: ... +# TODO: Auto-generated. Audit and fix. class F64Type(Type): def __init__(self, cast_from_type: Type) -> None: ... def get(self, *args, **kwargs) -> Any: ... def isinstance(self, *args, **kwargs) -> Any: ... +# TODO: Auto-generated. Audit and fix. class FlatSymbolRefAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... def get(self, *args, **kwargs) -> Any: ... @@ -498,6 +521,7 @@ class FlatSymbolRefAttr(Attribute): @property def value(self) -> str: ... +# TODO: Auto-generated. Audit and fix. class FloatAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... def get(self, *args, **kwargs) -> Any: ... @@ -509,6 +533,7 @@ class FloatAttr(Attribute): @property def value(self) -> float: ... +# TODO: Auto-generated. Audit and fix. class FunctionType(Type): def __init__(self, cast_from_type: Type) -> None: ... def get(self, *args, **kwargs) -> Any: ... 
@@ -518,33 +543,37 @@ class FunctionType(Type): @property def results(self) -> list: ... +# TODO: Auto-generated. Audit and fix. class IndexType(Type): def __init__(self, cast_from_type: Type) -> None: ... def get(self, *args, **kwargs) -> Any: ... def isinstance(self, *args, **kwargs) -> Any: ... class InferTypeOpInterface: - def __init__(self, object: object, context: Context = ...) -> None: ... - def inferReturnTypes(self, operands: Optional[List[Value]] = ..., attributes: Optional[Attribute] = ..., regions: Optional[List[Region]] = ..., context: Context = ..., loc: Location = ...) -> List[Type]: ... + def __init__(self, object: object, context: Optional["Context"] = None) -> None: ... + def inferReturnTypes(self, operands: Optional[List["Value"]] = None, attributes: Optional["Attribute"] = None, regions: Optional[List["Region"]] = None, context: Optional["Context"] = None, loc: Optional["Location"] = None) -> List[Type]: ... @property - def operation(self) -> object: ... + def operation(self) -> "_OperationBase": ... @property - def opview(self) -> object: ... + def opview(self) -> "OpView": ... class InsertionPoint: - current: ClassVar[InsertionPoint] = ... # read-only + current: ClassVar["InsertionPoint"] = ... # read-only @overload - def __init__(self, block: Block) -> None: ... + def __init__(self, block: "Block") -> None: ... @overload - def __init__(self, beforeOperation: _OperationBase) -> None: ... - def at_block_begin(self, *args, **kwargs) -> Any: ... - def at_block_terminator(self, *args, **kwargs) -> Any: ... - def insert(self, operation: _OperationBase) -> None: ... - def __enter__(self) -> object: ... + def __init__(self, beforeOperation: "_OperationBase") -> None: ... + @staticmethod + def at_block_begin(block: "Block") -> "InsertionPoint": ... + @staticmethod + def at_block_terminator(block: "Block") -> "InsertionPoint": ... + def insert(self, operation: "_OperationBase") -> None: ... + def __enter__(self) -> "InsertionPoint": ... 
def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... @property - def block(self) -> Block: ... + def block(self) -> "Block": ... +# TODO: Auto-generated. Audit and fix. class IntegerAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... def get(self, *args, **kwargs) -> Any: ... @@ -554,6 +583,7 @@ class IntegerAttr(Attribute): @property def value(self) -> int: ... +# TODO: Auto-generated. Audit and fix. class IntegerSet: def __init__(self, *args, **kwargs) -> None: ... def _CAPICreate(self) -> IntegerSet: ... @@ -585,6 +615,7 @@ class IntegerSet: @property def n_symbols(self) -> int: ... +# TODO: Auto-generated. Audit and fix. class IntegerSetConstraint: def __init__(self, *args, **kwargs) -> None: ... @property @@ -592,6 +623,7 @@ class IntegerSetConstraint: @property def is_eq(self) -> bool: ... +# TODO: Auto-generated. Audit and fix. class IntegerSetConstraintList: def __init__(self, *args, **kwargs) -> None: ... def __add__(self, arg0: IntegerSetConstraintList) -> List[IntegerSetConstraint]: ... @@ -601,6 +633,7 @@ class IntegerSetConstraintList: def __getitem__(self, arg0: slice) -> IntegerSetConstraintList: ... def __len__(self) -> int: ... +# TODO: Auto-generated. Audit and fix. class IntegerType(Type): def __init__(self, cast_from_type: Type) -> None: ... def get_signed(self, *args, **kwargs) -> Any: ... @@ -617,25 +650,29 @@ class IntegerType(Type): def width(self) -> int: ... class Location: - current: ClassVar[Location] = ... # read-only + current: ClassVar["Location"] = ... # read-only __hash__: ClassVar[None] = ... - def __init__(self, *args, **kwargs) -> None: ... - def _CAPICreate(self) -> Location: ... - def callsite(self, *args, **kwargs) -> Any: ... - def file(self, *args, **kwargs) -> Any: ... - def name(self, *args, **kwargs) -> Any: ... - def unknown(self, *args, **kwargs) -> Any: ... - def __enter__(self) -> object: ... + def _CAPICreate(self) -> "Location": ... 
+ @staticmethod + def callsite(callee: "Location", frames: Sequence["Location"], context: Optional["Context"] = None) -> "Location": ... + @staticmethod + def file(filename: str, line: int, col: int, context: Optional["Context"] = None) -> "Location": ... + @staticmethod + def name(name: str, childLoc: Optional["Location"] = None, context: Optional["Context"] = None) -> "Location": ... + @staticmethod + def unknown(context: Optional["Context"] = None) -> Any: ... + def __enter__(self) -> "Location": ... @overload - def __eq__(self, arg0: Location) -> bool: ... + def __eq__(self, arg0: "Location") -> bool: ... @overload def __eq__(self, arg0: object) -> bool: ... def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... @property def _CAPIPtr(self) -> object: ... @property - def context(self) -> object: ... + def context(self) -> "Context": ... +# TODO: Auto-generated. Audit and fix. class MemRefType(ShapedType): def __init__(self, cast_from_type: Type) -> None: ... def get(self, *args, **kwargs) -> Any: ... @@ -648,50 +685,48 @@ class MemRefType(ShapedType): def memory_space(self) -> Attribute: ... class Module: - def __init__(self, *args, **kwargs) -> None: ... def _CAPICreate(self) -> object: ... - def create(self, *args, **kwargs) -> Any: ... + def create(loc: Optional["Location"] = None) -> "Module": ... def dump(self) -> None: ... - def parse(self, *args, **kwargs) -> Any: ... + @staticmethod + def parse(asm: str, context: Optional[Context] = None) -> "Module": ... @property def _CAPIPtr(self) -> object: ... @property - def body(self) -> Any: ... + def body(self) -> "Block": ... @property def context(self) -> object: ... @property - def operation(self) -> object: ... + def operation(self) -> "_OperationBase": ... class NamedAttribute: - def __init__(self, *args, **kwargs) -> None: ... @property - def attr(self) -> Attribute: ... + def attr(self) -> "Attribute": ... @property def name(self) -> str: ... +# TODO: Auto-generated. Audit and fix. 
class NoneType(Type): def __init__(self, cast_from_type: Type) -> None: ... def get(self, *args, **kwargs) -> Any: ... def isinstance(self, *args, **kwargs) -> Any: ... class OpAttributeMap: - def __init__(self, *args, **kwargs) -> None: ... def __contains__(self, arg0: str) -> bool: ... def __delitem__(self, arg0: str) -> None: ... @overload - def __getitem__(self, arg0: str) -> Attribute: ... + def __getitem__(self, arg0: str) -> "Attribute": ... @overload - def __getitem__(self, arg0: int) -> NamedAttribute: ... + def __getitem__(self, arg0: int) -> "NamedAttribute": ... def __len__(self) -> int: ... - def __setitem__(self, arg0: str, arg1: Attribute) -> None: ... + def __setitem__(self, arg0: str, arg1: "Attribute") -> None: ... class OpOperandList: - def __init__(self, *args, **kwargs) -> None: ... - def __add__(self, arg0: OpOperandList) -> List[Value]: ... + def __add__(self, arg0: "OpOperandList") -> List[Value]: ... @overload def __getitem__(self, arg0: int) -> Value: ... @overload - def __getitem__(self, arg0: slice) -> OpOperandList: ... + def __getitem__(self, arg0: slice) -> "OpOperandList": ... def __len__(self) -> int: ... def __setitem__(self, arg0: int, arg1: Value) -> None: ... @@ -699,60 +734,70 @@ class OpResult(Value): def __init__(self, value: Value) -> None: ... def isinstance(self, *args, **kwargs) -> Any: ... @property - def owner(self) -> object: ... + def owner(self) -> "_OperationBase": ... @property def result_number(self) -> int: ... class OpResultList: - def __init__(self, *args, **kwargs) -> None: ... - def __add__(self, arg0: OpResultList) -> List[OpResult]: ... + def __add__(self, arg0: "OpResultList") -> List["OpResult"]: ... @overload - def __getitem__(self, arg0: int) -> OpResult: ... + def __getitem__(self, arg0: int) -> "OpResult": ... @overload - def __getitem__(self, arg0: slice) -> OpResultList: ... + def __getitem__(self, arg0: slice) -> "OpResultList": ... def __len__(self) -> int: ... 
@property - def types(self) -> List[Type]: ... + def types(self) -> List["Type"]: ... class OpView(_OperationBase): _ODS_OPERAND_SEGMENTS: ClassVar[None] = ... _ODS_REGIONS: ClassVar[tuple] = ... _ODS_RESULT_SEGMENTS: ClassVar[None] = ... - def __init__(self, operation: object) -> None: ... + def __init__(self, operation: "_OperationBase") -> None: ... @classmethod - def build_generic(self, *args, **kwargs) -> Any: ... + def build_generic(cls, results: Optional[Sequence["Type"]] = None, + operands: Optional[Sequence["Value"]] = None, + attributes: Optional[Dict[str, "Attribute"]] = None, + successors: Optional[Sequence["Block"]] = None, + regions: Optional[int] = None, + loc: Optional["Location"] = None, + ip: Optional["InsertionPoint"] = None) -> "_OperationBase": ... @property - def context(self) -> object: ... + def context(self) -> "Context": ... @property - def operation(self) -> object: ... + def operation(self) -> "_OperationBase": ... class Operation(_OperationBase): - def __init__(self, *args, **kwargs) -> None: ... def _CAPICreate(self) -> object: ... - def create(self, *args, **kwargs) -> Any: ... + @staticmethod + def create(name: str, results: Optional[Sequence["Type"]] = None, + operands: Optional[Sequence["Value"]] = None, + attributes: Optional[Dict[str, "Attribute"]] = None, + successors: Optional[Sequence["Block"]] = None, + regions: int = 0, + loc: Optional["Location"] = None, + ip: Optional["InsertionPoint"] = None) -> "_OperationBase": ... def erase(self) -> None: ... @property def _CAPIPtr(self) -> object: ... @property - def context(self) -> object: ... + def context(self) -> "Context": ... @property def name(self) -> str: ... @property - def opview(self) -> object: ... + def opview(self) -> "OpView": ... @property - def parent(self) -> object: ... + def parent(self) -> Optional["_OperationBase"]: ... class OperationIterator: - def __init__(self, *args, **kwargs) -> None: ... - def __iter__(self) -> OperationIterator: ... 
- def __next__(self) -> object: ... + def __iter__(self) -> "OperationIterator": ... + def __next__(self) -> "OpView": ... class OperationList: - def __init__(self, *args, **kwargs) -> None: ... - def __getitem__(self, arg0: int) -> object: ... - def __iter__(self) -> OperationIterator: ... + def __getitem__(self, arg0: int) -> "OpView": ... + def __iter__(self) -> "OperationIterator": ... def __len__(self) -> int: ... +# TODO: Auto-generated. Audit and fix. class RankedTensorType(ShapedType): def __init__(self, cast_from_type: Type) -> None: ... def get(self, *args, **kwargs) -> Any: ... @@ -762,27 +807,25 @@ class RankedTensorType(ShapedType): class Region: __hash__: ClassVar[None] = ... - def __init__(self, *args, **kwargs) -> None: ... @overload - def __eq__(self, arg0: Region) -> bool: ... + def __eq__(self, arg0: "Region") -> bool: ... @overload def __eq__(self, arg0: object) -> bool: ... - def __iter__(self) -> Any: ... + def __iter__(self) -> "BlockIterator": ... @property - def blocks(self) -> Any: ... + def blocks(self) -> "BlockList": ... @property - def owner(self) -> object: ... + def owner(self) -> "OpView": ... class RegionIterator: - def __init__(self, *args, **kwargs) -> None: ... - def __iter__(self) -> RegionIterator: ... - def __next__(self) -> Region: ... + def __iter__(self) -> "RegionIterator": ... + def __next__(self) -> "Region": ... class RegionSequence: - def __init__(self, *args, **kwargs) -> None: ... - def __getitem__(self, arg0: int) -> Region: ... + def __getitem__(self, arg0: int) -> "Region": ... def __len__(self) -> int: ... +# TODO: Auto-generated. Audit and fix. class ShapedType(Type): def __init__(self, cast_from_type: Type) -> None: ... def get_dim_size(self, dim: int) -> int: ... @@ -801,6 +844,7 @@ class ShapedType(Type): @property def shape(self) -> List[int]: ... +# TODO: Auto-generated. Audit and fix. class StringAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... 
def get(self, *args, **kwargs) -> Any: ... @@ -829,8 +873,9 @@ class SymbolTable: def walk_symbol_tables(from_op: _OperationBase, all_sym_uses_visible: bool, callback: Callable[[_OperationBase, bool], None]) -> None: ... def __contains__(self, arg0: str) -> bool: ... def __delitem__(self, arg0: str) -> None: ... - def __getitem__(self, arg0: str) -> object: ... + def __getitem__(self, arg0: str) -> "OpView": ... +# TODO: Auto-generated. Audit and fix. class TupleType(Type): def __init__(self, cast_from_type: Type) -> None: ... def get_tuple(self, *args, **kwargs) -> Any: ... @@ -839,6 +884,7 @@ class TupleType(Type): @property def num_types(self) -> int: ... +# TODO: Auto-generated. Audit and fix. class TypeAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... def get(self, *args, **kwargs) -> Any: ... @@ -848,6 +894,7 @@ class TypeAttr(Attribute): @property def value(self) -> Type: ... +# TODO: Auto-generated. Audit and fix. class UnitAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... def get(self, *args, **kwargs) -> Any: ... @@ -855,6 +902,7 @@ class UnitAttr(Attribute): @property def type(self) -> Type: ... +# TODO: Auto-generated. Audit and fix. class UnrankedMemRefType(ShapedType): def __init__(self, cast_from_type: Type) -> None: ... def get(self, *args, **kwargs) -> Any: ... @@ -862,11 +910,13 @@ class UnrankedMemRefType(ShapedType): @property def memory_space(self) -> Attribute: ... +# TODO: Auto-generated. Audit and fix. class UnrankedTensorType(ShapedType): def __init__(self, cast_from_type: Type) -> None: ... def get(self, *args, **kwargs) -> Any: ... def isinstance(self, *args, **kwargs) -> Any: ... +# TODO: Auto-generated. Audit and fix. class VectorType(ShapedType): def __init__(self, cast_from_type: Type) -> None: ... def get(self, *args, **kwargs) -> Any: ... @@ -874,4 +924,3 @@ class VectorType(ShapedType): class _GlobalDebug: flag: ClassVar[bool] = ...
- def __init__(self, *args, **kwargs) -> None: ... diff --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi index 7003bf0f0..728f46418 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi @@ -4,7 +4,7 @@ # * Relative imports for cross-module references. # * Add __all__ -from typing import Any +from typing import Any, Optional from . import ir as _ir @@ -13,12 +13,13 @@ __all__ = [ ] class PassManager: - def __init__(self, context: _ir.Context = ...) -> None: ... + def __init__(self, context: Optional[_ir.Context] = None) -> None: ... def _CAPICreate(self) -> object: ... def _testing_release(self) -> None: ... def enable_ir_printing(self) -> None: ... def enable_verifier(self, enable: bool) -> None: ... - def parse(self, *args, **kwargs) -> Any: ... + @staticmethod + def parse(pipeline: str, context: Optional[_ir.Context] = None) -> "PassManager": ... def run(self, module: _ir.Module) -> None: ... @property def _CAPIPtr(self) -> object: ... diff --git a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi index 50ff6c5b1..893dab8a4 100644 --- a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi @@ -4,7 +4,7 @@ # * Relative imports for cross-module references. # * Add __all__ -from typing import List +from typing import List, Sequence from ._mlir import ir as _ir @@ -13,7 +13,7 @@ __all__ = [ ] class ExecutionEngine: - def __init__(self, module: _ir.Module, opt_level: int = ..., shared_libs: List[str] = ...) -> None: ... + def __init__(self, module: _ir.Module, opt_level: int = 2, shared_libs: Sequence[str] = ...) -> None: ... def _CAPICreate(self) -> object: ... def _testing_release(self) -> None: ... def dump_to_object_file(self, file_name: str) -> None: ... 
From 77d78f8000405c4eec6cd6c85927663921afe38c Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Tue, 7 Dec 2021 18:27:58 +0000 Subject: [PATCH 178/915] Adjust "end namespace" comment in MLIR to match new agree'd coding style See D115115 and this mailing list discussion: https://lists.llvm.org/pipermail/llvm-dev/2021-December/154199.html Differential Revision: https://reviews.llvm.org/D115309 --- mlir/include/mlir/CAPI/Utils.h | 4 ++-- mlir/lib/Bindings/Python/ExecutionEngineModule.cpp | 2 +- mlir/lib/Bindings/Python/IRAffine.cpp | 2 +- mlir/lib/Bindings/Python/IRCore.cpp | 2 +- mlir/lib/Bindings/Python/Pass.cpp | 2 +- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/CAPI/Utils.h b/mlir/include/mlir/CAPI/Utils.h index c2e43850c..000516e7a 100644 --- a/mlir/include/mlir/CAPI/Utils.h +++ b/mlir/include/mlir/CAPI/Utils.h @@ -45,7 +45,7 @@ class CallbackOstream : public llvm::raw_ostream { void *opaqueData; uint64_t pos; }; -} // end namespace detail -} // end namespace mlir +} // namespace detail +} // namespace mlir #endif // MLIR_CAPI_UTILS_H diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp index b5a0f84d4..814209197 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -62,7 +62,7 @@ class PyExecutionEngine { std::vector referencedObjects; }; -} // anonymous namespace +} // namespace /// Create the `mlir.execution_engine` module here. 
PYBIND11_MODULE(_mlirExecutionEngine, m) { diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index 272de0d7a..faf01e5c5 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -397,7 +397,7 @@ class PyAffineMapExprList private: PyAffineMap affineMap; }; -} // end namespace +} // namespace bool PyAffineMap::operator==(const PyAffineMap &other) { return mlirAffineMapEqual(affineMap, other.affineMap); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 0d3491433..cd5755248 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2012,7 +2012,7 @@ class PyOpAttributeMap { PyOperationRef operation; }; -} // end namespace +} // namespace //------------------------------------------------------------------------------ // Populates the core exports of the 'ir' submodule. diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 2c38a3a25..dba2231a1 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -48,7 +48,7 @@ class PyPassManager { MlirPassManager passManager; }; -} // anonymous namespace +} // namespace /// Create the `mlir.passmanager` here. 
void mlir::python::populatePassManagerSubmodule(py::module &m) { From 1b8aa80a160dfd6896293a614cefc815e269ad72 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Sat, 11 Dec 2021 10:12:29 -0800 Subject: [PATCH 179/915] [mlir][python] Add fused location --- mlir/lib/Bindings/Python/IRCore.cpp | 20 ++++++++++++++++++++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 2 ++ 2 files changed, 22 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index cd5755248..3640a15e3 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -47,6 +47,9 @@ static const char kContextGetCallSiteLocationDocstring[] = static const char kContextGetFileLocationDocstring[] = R"(Gets a Location representing a file, line and column)"; +static const char kContextGetFusedLocationDocstring[] = + R"(Gets a Location representing a fused location with optional metadata)"; + static const char kContextGetNameLocationDocString[] = R"(Gets a Location representing a named location with optional child location)"; @@ -2197,6 +2200,23 @@ void mlir::python::populateIRCore(py::module &m) { }, py::arg("filename"), py::arg("line"), py::arg("col"), py::arg("context") = py::none(), kContextGetFileLocationDocstring) + .def_static( + "fused", + [](const std::vector &pyLocations, llvm::Optional metadata, + DefaultingPyMlirContext context) { + if (pyLocations.empty()) + throw py::value_error("No locations provided"); + llvm::SmallVector locations; + locations.reserve(pyLocations.size()); + for (auto &pyLocation : pyLocations) + locations.push_back(pyLocation.get()); + MlirLocation location = mlirLocationFusedGet( + context->get(), locations.size(), locations.data(), + metadata ? 
metadata->get() : MlirAttribute{0}); + return PyLocation(context->getRef(), location); + }, + py::arg("locations"), py::arg("metadata") = py::none(), + py::arg("context") = py::none(), kContextGetFusedLocationDocstring) .def_static( "name", [](std::string name, llvm::Optional childLoc, diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index e1a84ddfd..e61e34a17 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -658,6 +658,8 @@ class Location: @staticmethod def file(filename: str, line: int, col: int, context: Optional["Context"] = None) -> "Location": ... @staticmethod + def fused(locations: Sequence["Location"], metadata: Optional["Attribute"] = None, context: Optional["Context"] = None) -> "Location": ... + @staticmethod def name(name: str, childLoc: Optional["Location"] = None, context: Optional["Context"] = None) -> "Location": ... @staticmethod def unknown(context: Optional["Context"] = None) -> Any: ... From 35d3f69ca7862d294c4812e2c88509b48664a1e5 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Mon, 13 Dec 2021 21:01:42 +0000 Subject: [PATCH 180/915] [mlir][ExecutionEngine] Fix native dependencies for AsmParser and Printer This is a post-commit fix for https://reviews.llvm.org/D114338 which was landed as https://reviews.llvm.org/rG050cc1cd6e6882eadba6e5ea7b588ca0b8aa1b12 Differential Revision: https://reviews.llvm.org/D115666 --- mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt b/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt index 48f45f8c0..105ce24dd 100644 --- a/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt +++ b/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt @@ -1,3 +1,8 @@ +set(LLVM_LINK_COMPONENTS + nativecodegen + native +) + # Main API shared library. 
add_mlir_upstream_c_api_library(MLIRCAPIExecutionEngine ExecutionEngine.cpp From ba008ff93a65659a96b4775a64770c65167b4119 Mon Sep 17 00:00:00 2001 From: gysit Date: Wed, 15 Dec 2021 12:14:35 +0000 Subject: [PATCH 181/915] [mlir][linalg] Replace LinalgOps.h and LinalgTypes.h by a single header. After removing the range type, Linalg does not define any type. The revision thus consolidates the LinalgOps.h and LinalgTypes.h into a single Linalg.h header. Additionally, LinalgTypes.cpp is renamed to LinalgDialect.cpp to follow the convention adopted by other dialects such as the tensor dialect. Depends On D115727 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D115728 --- mlir/lib/CAPI/Dialect/Linalg.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 902599f3b..dd796167b 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -8,7 +8,7 @@ #include "mlir-c/Dialect/Linalg.h" #include "mlir/CAPI/Registration.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/Linalg/IR/Linalg.h" using namespace mlir; using namespace mlir::linalg; From 9ad3eefcde5c6457d203bb5fa99110cfa65e4536 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Mon, 20 Dec 2021 19:45:05 +0000 Subject: [PATCH 182/915] Fix clang-tidy issues in mlir/ (NFC) Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D115956 --- mlir/lib/Bindings/Python/IRAttributes.cpp | 10 +++++----- mlir/lib/Bindings/Python/IRCore.cpp | 6 +++--- mlir/lib/Bindings/Python/IRModule.cpp | 10 ++++------ mlir/lib/Bindings/Python/PybindUtils.cpp | 2 -- mlir/lib/Bindings/Python/Transforms/Transforms.cpp | 2 -- mlir/lib/CAPI/IR/IR.cpp | 2 +- 6 files changed, 13 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 17b3b34a2..eed6369ac 100644 --- 
a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -17,7 +17,6 @@ namespace py = pybind11; using namespace mlir; using namespace mlir::python; -using llvm::None; using llvm::Optional; using llvm::SmallVector; using llvm::Twine; @@ -510,7 +509,8 @@ class PyDenseElementsAttribute if (mlirTypeIsAF32(elementType)) { // f32 return bufferInfo(shapedType); - } else if (mlirTypeIsAF64(elementType)) { + } + if (mlirTypeIsAF64(elementType)) { // f64 return bufferInfo(shapedType); } else if (mlirTypeIsAF16(elementType)) { @@ -712,12 +712,12 @@ class PyDictAttribute : public PyConcreteAttribute { SmallVector mlirNamedAttributes; mlirNamedAttributes.reserve(attributes.size()); for (auto &it : attributes) { - auto &mlir_attr = it.second.cast(); + auto &mlirAttr = it.second.cast(); auto name = it.first.cast(); mlirNamedAttributes.push_back(mlirNamedAttributeGet( - mlirIdentifierGet(mlirAttributeGetContext(mlir_attr), + mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), toMlirStringRef(name)), - mlir_attr)); + mlirAttr)); } MlirAttribute attr = mlirDictionaryAttrGet(context->get(), mlirNamedAttributes.size(), diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 3640a15e3..864144226 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1267,7 +1267,7 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList, if (segmentSpec == 1 || segmentSpec == 0) { // Unpack unary element. 
try { - auto operandValue = py::cast(std::get<0>(it.value())); + auto *operandValue = py::cast(std::get<0>(it.value())); if (operandValue) { operands.push_back(operandValue); operandSegmentLengths.push_back(1); @@ -2286,10 +2286,10 @@ void mlir::python::populateIRCore(py::module &m) { .def_property_readonly( "body", [](PyModule &self) { - PyOperationRef module_op = PyOperation::forOperation( + PyOperationRef moduleOp = PyOperation::forOperation( self.getContext(), mlirModuleGetOperation(self.get()), self.getRef().releaseObject()); - PyBlock returnBlock(module_op, mlirModuleGetBody(self.get())); + PyBlock returnBlock(moduleOp, mlirModuleGetBody(self.get())); return returnBlock; }, "Return the block for this module") diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index 9f853eb92..7008e54bd 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -51,9 +51,8 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { } catch (py::error_already_set &e) { if (e.matches(PyExc_ModuleNotFoundError)) { continue; - } else { - throw; } + throw; } break; } @@ -136,11 +135,10 @@ PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { // Positive cache. rawOpViewClassMapCache[operationName] = foundIt->second; return foundIt->second; - } else { - // Negative cache. - rawOpViewClassMap[operationName] = py::none(); - return llvm::None; } + // Negative cache. 
+ rawOpViewClassMap[operationName] = py::none(); + return llvm::None; } } diff --git a/mlir/lib/Bindings/Python/PybindUtils.cpp b/mlir/lib/Bindings/Python/PybindUtils.cpp index bd80b8c14..d243307f1 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.cpp +++ b/mlir/lib/Bindings/Python/PybindUtils.cpp @@ -8,8 +8,6 @@ #include "PybindUtils.h" -namespace py = pybind11; - pybind11::error_already_set mlir::python::SetPyError(PyObject *excClass, const llvm::Twine &message) { auto messageStr = message.str(); diff --git a/mlir/lib/Bindings/Python/Transforms/Transforms.cpp b/mlir/lib/Bindings/Python/Transforms/Transforms.cpp index 46c469192..944b191bc 100644 --- a/mlir/lib/Bindings/Python/Transforms/Transforms.cpp +++ b/mlir/lib/Bindings/Python/Transforms/Transforms.cpp @@ -10,8 +10,6 @@ #include -namespace py = pybind11; - // ----------------------------------------------------------------------------- // Module initialization. // ----------------------------------------------------------------------------- diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 424bbae17..955f5e0c1 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -818,7 +818,7 @@ void mlirSymbolTableErase(MlirSymbolTable symbolTable, MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, MlirStringRef newSymbol, MlirOperation from) { - auto cppFrom = unwrap(from); + auto *cppFrom = unwrap(from); auto *context = cppFrom->getContext(); auto oldSymbolAttr = StringAttr::get(unwrap(oldSymbol), context); auto newSymbolAttr = StringAttr::get(unwrap(newSymbol), context); From 2dd8776f5566d19fe8b82811cc20a74e9912f328 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Fri, 17 Dec 2021 18:45:10 +0000 Subject: [PATCH 183/915] Fix clang-tidy issues in mlir/ (NFC) Differential Revision: https://reviews.llvm.org/D115956 --- mlir/lib/Bindings/Python/IRAttributes.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git 
a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index eed6369ac..f1206617d 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -513,11 +513,13 @@ class PyDenseElementsAttribute if (mlirTypeIsAF64(elementType)) { // f64 return bufferInfo(shapedType); - } else if (mlirTypeIsAF16(elementType)) { + } + if (mlirTypeIsAF16(elementType)) { // f16 return bufferInfo(shapedType, "e"); - } else if (mlirTypeIsAInteger(elementType) && - mlirIntegerTypeGetWidth(elementType) == 32) { + } + if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 32) { if (mlirIntegerTypeIsSignless(elementType) || mlirIntegerTypeIsSigned(elementType)) { // i32 From 299878c11186751f691ce0b1bcb8ab6868890906 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Wed, 22 Dec 2021 00:19:53 +0000 Subject: [PATCH 184/915] Fix more clang-tidy cleanups in mlir/ (NFC) --- mlir/lib/Bindings/Python/IRAffine.cpp | 2 +- mlir/lib/Bindings/Python/IRAttributes.cpp | 12 ++++++++---- mlir/lib/Bindings/Python/IRInterfaces.cpp | 3 +-- mlir/lib/Bindings/Python/IRModule.cpp | 2 +- mlir/lib/CAPI/IR/Diagnostics.cpp | 2 +- 5 files changed, 12 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index faf01e5c5..c7cdc8243 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -676,7 +676,7 @@ void mlir::python::populateIRAffine(py::module &m) { std::vector res; res.reserve(compressed.size()); for (auto m : compressed) - res.push_back(PyAffineMap(context->getRef(), m)); + res.emplace_back(context->getRef(), m); return res; }) .def_property_readonly( diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index f1206617d..56d16b337 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -524,7 +524,8 @@ class 
PyDenseElementsAttribute mlirIntegerTypeIsSigned(elementType)) { // i32 return bufferInfo(shapedType); - } else if (mlirIntegerTypeIsUnsigned(elementType)) { + } + if (mlirIntegerTypeIsUnsigned(elementType)) { // unsigned i32 return bufferInfo(shapedType); } @@ -534,7 +535,8 @@ class PyDenseElementsAttribute mlirIntegerTypeIsSigned(elementType)) { // i64 return bufferInfo(shapedType); - } else if (mlirIntegerTypeIsUnsigned(elementType)) { + } + if (mlirIntegerTypeIsUnsigned(elementType)) { // unsigned i64 return bufferInfo(shapedType); } @@ -544,7 +546,8 @@ class PyDenseElementsAttribute mlirIntegerTypeIsSigned(elementType)) { // i8 return bufferInfo(shapedType); - } else if (mlirIntegerTypeIsUnsigned(elementType)) { + } + if (mlirIntegerTypeIsUnsigned(elementType)) { // unsigned i8 return bufferInfo(shapedType); } @@ -554,7 +557,8 @@ class PyDenseElementsAttribute mlirIntegerTypeIsSigned(elementType)) { // i16 return bufferInfo(shapedType); - } else if (mlirIntegerTypeIsUnsigned(elementType)) { + } + if (mlirIntegerTypeIsUnsigned(elementType)) { // unsigned i16 return bufferInfo(shapedType); } diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index c3d41c4d8..564f36b9d 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -175,8 +175,7 @@ class PyInferTypeOpInterface auto *data = static_cast(userData); data->inferredTypes.reserve(data->inferredTypes.size() + nTypes); for (intptr_t i = 0; i < nTypes; ++i) { - data->inferredTypes.push_back( - PyType(data->pyMlirContext.getRef(), types[i])); + data->inferredTypes.emplace_back(data->pyMlirContext.getRef(), types[i]); } } diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index 7008e54bd..633ffe4e1 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -29,7 +29,7 @@ PyGlobals::PyGlobals() { instance = this; // The default search path 
include {mlir.}dialects, where {mlir.} is the // package prefix configured at compile time. - dialectSearchPrefixes.push_back(MAKE_MLIR_PYTHON_QUALNAME("dialects")); + dialectSearchPrefixes.emplace_back(MAKE_MLIR_PYTHON_QUALNAME("dialects")); } PyGlobals::~PyGlobals() { instance = nullptr; } diff --git a/mlir/lib/CAPI/IR/Diagnostics.cpp b/mlir/lib/CAPI/IR/Diagnostics.cpp index 2ed05a5a0..40639c7ba 100644 --- a/mlir/lib/CAPI/IR/Diagnostics.cpp +++ b/mlir/lib/CAPI/IR/Diagnostics.cpp @@ -57,7 +57,7 @@ MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler( MlirContext context, MlirDiagnosticHandler handler, void *userData, void (*deleteUserData)(void *)) { assert(handler && "unexpected null diagnostic handler"); - if (deleteUserData == NULL) + if (deleteUserData == nullptr) deleteUserData = deleteUserDataNoop; std::shared_ptr sharedUserData(userData, deleteUserData); DiagnosticEngine::HandlerID id = From c65389f9c64b91f3b45dcb09c7d192a9f9029502 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Sun, 2 Jan 2022 01:26:44 +0000 Subject: [PATCH 185/915] Apply clang-tidy fixes for performance-unnecessary-value-param to MLIR (NFC) Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D116250 --- .../Bindings/Python/DialectSparseTensor.cpp | 2 +- mlir/lib/Bindings/Python/Dialects.h | 2 +- .../Bindings/Python/ExecutionEngineModule.cpp | 2 +- mlir/lib/Bindings/Python/IRAffine.cpp | 25 ++++++----- mlir/lib/Bindings/Python/IRAttributes.cpp | 6 ++- mlir/lib/Bindings/Python/IRCore.cpp | 41 +++++++++++-------- mlir/lib/Bindings/Python/IRInterfaces.cpp | 4 +- mlir/lib/Bindings/Python/IRModule.h | 25 ++++++----- 8 files changed, 62 insertions(+), 45 deletions(-) diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 6afd0815d..7de0b8156 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -17,7 +17,7 @@ using namespace mlir; using 
namespace mlir::python::adaptors; void mlir::python::populateDialectSparseTensorSubmodule( - py::module m, const py::module &irModule) { + const py::module &m, const py::module &irModule) { auto attributeClass = irModule.attr("Attribute"); py::enum_(m, "DimLevelType", py::module_local()) diff --git a/mlir/lib/Bindings/Python/Dialects.h b/mlir/lib/Bindings/Python/Dialects.h index 301d53927..c1725074c 100644 --- a/mlir/lib/Bindings/Python/Dialects.h +++ b/mlir/lib/Bindings/Python/Dialects.h @@ -15,7 +15,7 @@ namespace mlir { namespace python { void populateDialectLinalgSubmodule(pybind11::module m); -void populateDialectSparseTensorSubmodule(pybind11::module m, +void populateDialectSparseTensorSubmodule(const pybind11::module &m, const pybind11::module &irModule); } // namespace python diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp index 814209197..901690018 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -42,7 +42,7 @@ class PyExecutionEngine { // Add an object to the list of referenced objects whose lifetime must exceed // those of the ExecutionEngine. - void addReferencedObject(pybind11::object obj) { + void addReferencedObject(const pybind11::object &obj) { referencedObjects.push_back(obj); } diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index c7cdc8243..16c7ca335 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -6,6 +6,8 @@ // //===----------------------------------------------------------------------===// +#include + #include "IRModule.h" #include "PybindUtils.h" @@ -30,7 +32,8 @@ static const char kDumpDocstring[] = /// Throws errors in case of failure, using "action" to describe what the caller /// was attempting to do. 
template -static void pyListToVector(py::list list, llvm::SmallVectorImpl &result, +static void pyListToVector(const py::list &list, + llvm::SmallVectorImpl &result, StringRef action) { result.reserve(py::len(list)); for (py::handle item : list) { @@ -203,7 +206,7 @@ class PyAffineAddExpr static constexpr const char *pyClassName = "AffineAddExpr"; using PyConcreteAffineExpr::PyConcreteAffineExpr; - static PyAffineAddExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { + static PyAffineAddExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { MlirAffineExpr expr = mlirAffineAddExprGet(lhs, rhs); return PyAffineAddExpr(lhs.getContext(), expr); } @@ -232,7 +235,7 @@ class PyAffineMulExpr static constexpr const char *pyClassName = "AffineMulExpr"; using PyConcreteAffineExpr::PyConcreteAffineExpr; - static PyAffineMulExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { + static PyAffineMulExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { MlirAffineExpr expr = mlirAffineMulExprGet(lhs, rhs); return PyAffineMulExpr(lhs.getContext(), expr); } @@ -261,7 +264,7 @@ class PyAffineModExpr static constexpr const char *pyClassName = "AffineModExpr"; using PyConcreteAffineExpr::PyConcreteAffineExpr; - static PyAffineModExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { + static PyAffineModExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { MlirAffineExpr expr = mlirAffineModExprGet(lhs, rhs); return PyAffineModExpr(lhs.getContext(), expr); } @@ -290,7 +293,7 @@ class PyAffineFloorDivExpr static constexpr const char *pyClassName = "AffineFloorDivExpr"; using PyConcreteAffineExpr::PyConcreteAffineExpr; - static PyAffineFloorDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { + static PyAffineFloorDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { MlirAffineExpr expr = mlirAffineFloorDivExprGet(lhs, rhs); return PyAffineFloorDivExpr(lhs.getContext(), expr); } @@ -319,7 +322,7 @@ class PyAffineCeilDivExpr static constexpr const char *pyClassName = "AffineCeilDivExpr"; using 
PyConcreteAffineExpr::PyConcreteAffineExpr; - static PyAffineCeilDivExpr get(PyAffineExpr lhs, PyAffineExpr rhs) { + static PyAffineCeilDivExpr get(PyAffineExpr lhs, const PyAffineExpr &rhs) { MlirAffineExpr expr = mlirAffineCeilDivExprGet(lhs, rhs); return PyAffineCeilDivExpr(lhs.getContext(), expr); } @@ -375,7 +378,7 @@ class PyAffineMapExprList public: static constexpr const char *pyClassName = "AffineExprList"; - PyAffineMapExprList(PyAffineMap map, intptr_t startIndex = 0, + PyAffineMapExprList(const PyAffineMap &map, intptr_t startIndex = 0, intptr_t length = -1, intptr_t step = 1) : Sliceable(startIndex, length == -1 ? mlirAffineMapGetNumResults(map) : length, @@ -423,7 +426,8 @@ namespace { class PyIntegerSetConstraint { public: - PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) : set(set), pos(pos) {} + PyIntegerSetConstraint(PyIntegerSet set, intptr_t pos) + : set(std::move(set)), pos(pos) {} PyAffineExpr getExpr() { return PyAffineExpr(set.getContext(), @@ -449,7 +453,7 @@ class PyIntegerSetConstraintList public: static constexpr const char *pyClassName = "IntegerSetConstraintList"; - PyIntegerSetConstraintList(PyIntegerSet set, intptr_t startIndex = 0, + PyIntegerSetConstraintList(const PyIntegerSet &set, intptr_t startIndex = 0, intptr_t length = -1, intptr_t step = 1) : Sliceable(startIndex, length == -1 ? 
mlirIntegerSetGetNumConstraints(set) : length, @@ -692,7 +696,8 @@ void mlir::python::populateIRAffine(py::module &m) { DefaultingPyMlirContext context) { SmallVector affineExprs; pyListToVector( - exprs, affineExprs, "attempting to create an AffineMap"); + std::move(exprs), affineExprs, + "attempting to create an AffineMap"); MlirAffineMap map = mlirAffineMapGet(context->get(), dimCount, symbolCount, affineExprs.size(), affineExprs.data()); diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 56d16b337..fd44ffe6b 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -6,6 +6,8 @@ // //===----------------------------------------------------------------------===// +#include + #include "IRModule.h" #include "PybindUtils.h" @@ -116,7 +118,7 @@ class PyArrayAttribute : public PyConcreteAttribute { class PyArrayAttributeIterator { public: - PyArrayAttributeIterator(PyAttribute attr) : attr(attr) {} + PyArrayAttributeIterator(PyAttribute attr) : attr(std::move(attr)) {} PyArrayAttributeIterator &dunderIter() { return *this; } @@ -459,7 +461,7 @@ class PyDenseElementsAttribute arrayInfo.format); } - static PyDenseElementsAttribute getSplat(PyType shapedType, + static PyDenseElementsAttribute getSplat(const PyType &shapedType, PyAttribute &elementAttr) { auto contextWrapper = PyMlirContext::forContext(mlirTypeGetContext(shapedType)); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 864144226..ccdd159fd 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -21,6 +21,8 @@ #include "llvm/ADT/SmallVector.h" #include +#include + namespace py = pybind11; using namespace mlir; using namespace mlir::python; @@ -176,7 +178,7 @@ static MlirStringRef toMlirStringRef(const std::string &s) { struct PyGlobalDebugFlag { static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } - 
static bool get(py::object) { return mlirIsGlobalDebugEnabled(); } + static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); } static void bind(py::module &m) { // Debug flags. @@ -320,7 +322,7 @@ class PyBlockList { throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); } - PyBlock appendBlock(py::args pyArgTypes) { + PyBlock appendBlock(const py::args &pyArgTypes) { operation->checkValid(); llvm::SmallVector argTypes; argTypes.reserve(pyArgTypes.size()); @@ -503,9 +505,9 @@ pybind11::object PyMlirContext::contextEnter() { return PyThreadContextEntry::pushContext(*this); } -void PyMlirContext::contextExit(pybind11::object excType, - pybind11::object excVal, - pybind11::object excTb) { +void PyMlirContext::contextExit(const pybind11::object &excType, + const pybind11::object &excVal, + const pybind11::object &excTb) { PyThreadContextEntry::popContext(*this); } @@ -689,8 +691,9 @@ py::object PyLocation::contextEnter() { return PyThreadContextEntry::pushLocation(*this); } -void PyLocation::contextExit(py::object excType, py::object excVal, - py::object excTb) { +void PyLocation::contextExit(const pybind11::object &excType, + const pybind11::object &excVal, + const pybind11::object &excTb) { PyThreadContextEntry::popLocation(*this); } @@ -945,11 +948,11 @@ py::object PyOperation::createFromCapsule(py::object capsule) { } py::object PyOperation::create( - std::string name, llvm::Optional> results, + const std::string &name, llvm::Optional> results, llvm::Optional> operands, llvm::Optional attributes, llvm::Optional> successors, int regions, - DefaultingPyLocation location, py::object maybeIp) { + DefaultingPyLocation location, const py::object &maybeIp) { llvm::SmallVector mlirOperands; llvm::SmallVector mlirResults; llvm::SmallVector mlirSuccessors; @@ -1105,7 +1108,7 @@ void PyOperation::erase() { //------------------------------------------------------------------------------ py::object -PyOpView::buildGeneric(py::object cls, 
py::list resultTypeList, +PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList, py::list operandList, llvm::Optional attributes, llvm::Optional> successors, @@ -1359,16 +1362,17 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList, /*operands=*/std::move(operands), /*attributes=*/std::move(attributes), /*successors=*/std::move(successors), - /*regions=*/*regions, location, maybeIp); + /*regions=*/*regions, location, + std::move(maybeIp)); } -PyOpView::PyOpView(py::object operationObject) +PyOpView::PyOpView(const py::object &operationObject) // Casting through the PyOperationBase base-class and then back to the // Operation lets us accept any PyOperationBase subclass. : operation(py::cast(operationObject).getOperation()), operationObject(operation.getRef().getObject()) {} -py::object PyOpView::createRawSubclass(py::object userClass) { +py::object PyOpView::createRawSubclass(const py::object &userClass) { // This is... a little gross. The typical pattern is to have a pure python // class that extends OpView like: // class AddFOp(_cext.ir.OpView): @@ -1465,9 +1469,9 @@ py::object PyInsertionPoint::contextEnter() { return PyThreadContextEntry::pushInsertionPoint(*this); } -void PyInsertionPoint::contextExit(pybind11::object excType, - pybind11::object excVal, - pybind11::object excTb) { +void PyInsertionPoint::contextExit(const pybind11::object &excType, + const pybind11::object &excVal, + const pybind11::object &excTb) { PyThreadContextEntry::popInsertionPoint(*this); } @@ -1954,7 +1958,8 @@ class PyOpResultList : public Sliceable { /// attributes, or by index, producing named attributes. 
class PyOpAttributeMap { public: - PyOpAttributeMap(PyOperationRef operation) : operation(operation) {} + PyOpAttributeMap(PyOperationRef operation) + : operation(std::move(operation)) {} PyAttribute dunderGetItemNamed(const std::string &name) { MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), @@ -1979,7 +1984,7 @@ class PyOpAttributeMap { mlirIdentifierStr(namedAttr.name).length)); } - void dunderSetItem(const std::string &name, PyAttribute attr) { + void dunderSetItem(const std::string &name, const PyAttribute &attr) { mlirOperationSetAttributeByName(operation->get(), toMlirStringRef(name), attr); } diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index 564f36b9d..1fc66fef4 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -6,6 +6,8 @@ // //===----------------------------------------------------------------------===// +#include + #include "IRModule.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/Interfaces.h" @@ -58,7 +60,7 @@ class PyConcreteOpInterface { /// operation or a subclass of OpView. In the latter case, only the static /// methods of the interface are accessible to the caller. PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context) - : obj(object) { + : obj(std::move(object)) { try { operation = &py::cast(obj); } catch (py::cast_error &err) { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index d5e8eb4ae..df4aaebf3 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -203,8 +203,9 @@ class PyMlirContext { /// Enter and exit the context manager. 
pybind11::object contextEnter(); - void contextExit(pybind11::object excType, pybind11::object excVal, - pybind11::object excTb); + void contextExit(const pybind11::object &excType, + const pybind11::object &excVal, + const pybind11::object &excTb); private: PyMlirContext(MlirContext context); @@ -316,8 +317,9 @@ class PyLocation : public BaseContextObject { /// Enter and exit the context manager. pybind11::object contextEnter(); - void contextExit(pybind11::object excType, pybind11::object excVal, - pybind11::object excTb); + void contextExit(const pybind11::object &excType, + const pybind11::object &excVal, + const pybind11::object &excTb); /// Gets a capsule wrapping the void* within the MlirLocation. pybind11::object getCapsule(); @@ -482,11 +484,11 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// Creates an operation. See corresponding python docstring. static pybind11::object - create(std::string name, llvm::Optional> results, + create(const std::string &name, llvm::Optional> results, llvm::Optional> operands, llvm::Optional attributes, llvm::Optional> successors, int regions, - DefaultingPyLocation location, pybind11::object ip); + DefaultingPyLocation location, const pybind11::object &ip); /// Creates an OpView suitable for this operation. pybind11::object createOpView(); @@ -524,15 +526,15 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// python types. 
class PyOpView : public PyOperationBase { public: - PyOpView(pybind11::object operationObject); + PyOpView(const pybind11::object &operationObject); PyOperation &getOperation() override { return operation; } - static pybind11::object createRawSubclass(pybind11::object userClass); + static pybind11::object createRawSubclass(const pybind11::object &userClass); pybind11::object getOperationObject() { return operationObject; } static pybind11::object - buildGeneric(pybind11::object cls, pybind11::list resultTypeList, + buildGeneric(const pybind11::object &cls, pybind11::list resultTypeList, pybind11::list operandList, llvm::Optional attributes, llvm::Optional> successors, @@ -607,8 +609,9 @@ class PyInsertionPoint { /// Enter and exit the context manager. pybind11::object contextEnter(); - void contextExit(pybind11::object excType, pybind11::object excVal, - pybind11::object excTb); + void contextExit(const pybind11::object &excType, + const pybind11::object &excVal, + const pybind11::object &excTb); PyBlock &getBlock() { return block; } From 5c6e621548a2a05882649f092b1bace580c01bcf Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Sun, 2 Jan 2022 22:02:14 +0000 Subject: [PATCH 186/915] Apply clang-tidy fixes for performance-for-range-copy to MLIR (NFC) --- mlir/lib/Bindings/Python/IRCore.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index ccdd159fd..be2abcdd5 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1155,7 +1155,7 @@ PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList, resultTypes.reserve(resultTypeList.size()); if (resultSegmentSpecObj.is_none()) { // Non-variadic result unpacking. 
- for (auto it : llvm::enumerate(resultTypeList)) { + for (const auto &it : llvm::enumerate(resultTypeList)) { try { resultTypes.push_back(py::cast(it.value())); if (!resultTypes.back()) @@ -1179,7 +1179,7 @@ PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList, .str()); } resultSegmentLengths.reserve(resultTypeList.size()); - for (auto it : + for (const auto &it : llvm::enumerate(llvm::zip(resultTypeList, resultSegmentSpec))) { int segmentSpec = std::get<1>(it.value()); if (segmentSpec == 1 || segmentSpec == 0) { @@ -1240,7 +1240,7 @@ PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList, operands.reserve(operands.size()); if (operandSegmentSpecObj.is_none()) { // Non-sized operand unpacking. - for (auto it : llvm::enumerate(operandList)) { + for (const auto &it : llvm::enumerate(operandList)) { try { operands.push_back(py::cast(it.value())); if (!operands.back()) @@ -1264,7 +1264,7 @@ PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList, .str()); } operandSegmentLengths.reserve(operandList.size()); - for (auto it : + for (const auto &it : llvm::enumerate(llvm::zip(operandList, operandSegmentSpec))) { int segmentSpec = std::get<1>(it.value()); if (segmentSpec == 1 || segmentSpec == 0) { From f408942206b37e1ac53513cd13c7f12c2f72041b Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Sun, 2 Jan 2022 22:02:18 +0000 Subject: [PATCH 187/915] Apply clang-tidy fixes for performance-move-const-arg to MLIR (NFC) --- mlir/lib/Bindings/Python/IRAffine.cpp | 3 +-- mlir/lib/Bindings/Python/IRCore.cpp | 7 +++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index 16c7ca335..0da936e85 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -696,8 +696,7 @@ void mlir::python::populateIRAffine(py::module &m) { DefaultingPyMlirContext context) { SmallVector affineExprs; pyListToVector( - 
std::move(exprs), affineExprs, - "attempting to create an AffineMap"); + exprs, affineExprs, "attempting to create an AffineMap"); MlirAffineMap map = mlirAffineMapGet(context->get(), dimCount, symbolCount, affineExprs.size(), affineExprs.data()); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index be2abcdd5..b9d31b27b 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1357,13 +1357,12 @@ PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList, } // Delegate to create. - return PyOperation::create(std::move(name), + return PyOperation::create(name, /*results=*/std::move(resultTypes), /*operands=*/std::move(operands), /*attributes=*/std::move(attributes), /*successors=*/std::move(successors), - /*regions=*/*regions, location, - std::move(maybeIp)); + /*regions=*/*regions, location, maybeIp); } PyOpView::PyOpView(const py::object &operationObject) @@ -1705,7 +1704,7 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from, if (userData.gotException) { std::string message("Exception raised in callback: "); message.append(userData.exceptionWhat); - throw std::runtime_error(std::move(message)); + throw std::runtime_error(message); } } From 545af0522ee53e93e9e5c479091b864987b87e6c Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Sun, 2 Jan 2022 22:02:20 +0000 Subject: [PATCH 188/915] Apply clang-tidy fixes for performance-unnecessary-value-param to MLIR (NFC) --- mlir/lib/Bindings/Python/IRCore.cpp | 13 ++++++------- mlir/lib/Bindings/Python/IRModule.h | 2 +- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index b9d31b27b..686153227 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1107,13 +1107,12 @@ void PyOperation::erase() { // PyOpView //------------------------------------------------------------------------------ -py::object 
-PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList, - py::list operandList, - llvm::Optional attributes, - llvm::Optional> successors, - llvm::Optional regions, - DefaultingPyLocation location, py::object maybeIp) { +py::object PyOpView::buildGeneric( + const py::object &cls, py::list resultTypeList, py::list operandList, + llvm::Optional attributes, + llvm::Optional> successors, + llvm::Optional regions, DefaultingPyLocation location, + const py::object &maybeIp) { PyMlirContextRef context = location->getContext(); // Class level operation construction metadata. std::string name = py::cast(cls.attr("OPERATION_NAME")); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index df4aaebf3..117435d63 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -539,7 +539,7 @@ class PyOpView : public PyOperationBase { llvm::Optional attributes, llvm::Optional> successors, llvm::Optional regions, DefaultingPyLocation location, - pybind11::object maybeIp); + const pybind11::object &maybeIp); private: PyOperation &operation; // For efficient, cast-free access from C++ From e6a5749bb67f9aab85ad5c3cad6d8813051ac130 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Mon, 3 Jan 2022 06:17:00 +0000 Subject: [PATCH 189/915] Remove misused RAII gil_scoped_release/gil_scoped_acquire: without name they don't have any effect I'm not sure what is the right fix here, but adding a name to all these lead to many segfaults. 
Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D116506 --- mlir/lib/Bindings/Python/IRCore.cpp | 1 - mlir/lib/Bindings/Python/IRModule.cpp | 8 -------- 2 files changed, 9 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 686153227..b39a1ea84 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -868,7 +868,6 @@ void PyOperationBase::print(py::object fileObject, bool binary, mlirOpPrintingFlagsPrintGenericOpForm(flags); PyFileAccumulator accum(fileObject, binary); - py::gil_scoped_release(); mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), accum.getUserData()); mlirOpPrintingFlagsDestroy(flags); diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index 633ffe4e1..ba6b2d29f 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -35,7 +35,6 @@ PyGlobals::PyGlobals() { PyGlobals::~PyGlobals() { instance = nullptr; } void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { - py::gil_scoped_acquire(); if (loadedDialectModulesCache.contains(dialectNamespace)) return; // Since re-entrancy is possible, make a copy of the search prefixes. 
@@ -46,7 +45,6 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { moduleName.append(dialectNamespace.data(), dialectNamespace.size()); try { - py::gil_scoped_release(); loaded = py::module::import(moduleName.c_str()); } catch (py::error_already_set &e) { if (e.matches(PyExc_ModuleNotFoundError)) { @@ -64,7 +62,6 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, py::object pyClass) { - py::gil_scoped_acquire(); py::object &found = dialectClassMap[dialectNamespace]; if (found) { throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + @@ -77,7 +74,6 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, void PyGlobals::registerOperationImpl(const std::string &operationName, py::object pyClass, py::object rawOpViewClass) { - py::gil_scoped_acquire(); py::object &found = operationClassMap[operationName]; if (found) { throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + @@ -90,7 +86,6 @@ void PyGlobals::registerOperationImpl(const std::string &operationName, llvm::Optional PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { - py::gil_scoped_acquire(); loadDialectModule(dialectNamespace); // Fast match against the class map first (common case). const auto foundIt = dialectClassMap.find(dialectNamespace); @@ -109,7 +104,6 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { llvm::Optional PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { { - py::gil_scoped_acquire(); auto foundIt = rawOpViewClassMapCache.find(operationName); if (foundIt != rawOpViewClassMapCache.end()) { if (foundIt->second.is_none()) @@ -126,7 +120,6 @@ PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { // Attempt to find from the canonical map and cache. 
{ - py::gil_scoped_acquire(); auto foundIt = rawOpViewClassMap.find(operationName); if (foundIt != rawOpViewClassMap.end()) { if (foundIt->second.is_none()) @@ -143,7 +136,6 @@ PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { } void PyGlobals::clearImportCache() { - py::gil_scoped_acquire(); loadedDialectModulesCache.clear(); rawOpViewClassMapCache.clear(); } From 75bcffc4a8f2b3a70814082e323a84ac928d783e Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Tue, 4 Jan 2022 18:38:30 +0100 Subject: [PATCH 190/915] [mlir] Fix incorrect top-level comment in DialectSparseTensor.cpp --- mlir/lib/Bindings/Python/DialectSparseTensor.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 7de0b8156..c9e3cb639 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -1,4 +1,4 @@ -//===- DialectLinalg.cpp - 'sparse_tensor' dialect submodule --------------===// +//===- DialectSparseTensor.cpp - 'sparse_tensor' dialect submodule --------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. From 15ab95a607659c554c42d919acfa8a8ae40e75f8 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 3 Jan 2022 16:39:58 -0800 Subject: [PATCH 191/915] [mlir][python] Add bindings for diagnostic handler. I considered multiple approaches for this but settled on this one because I could make the lifetime management work in a reasonably easy way (others had issues with not being able to cast to a Python reference from a C++ constructor). We could stand to have more formatting helpers, but best to get the core mechanism in first. 
Differential Revision: https://reviews.llvm.org/D116568 --- mlir/lib/Bindings/Python/IRCore.cpp | 165 ++++++++++++++++++++++- mlir/lib/Bindings/Python/IRModule.h | 76 +++++++++++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 33 ++++- 3 files changed, 271 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index b39a1ea84..1a7eb46f7 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -511,6 +511,57 @@ void PyMlirContext::contextExit(const pybind11::object &excType, PyThreadContextEntry::popContext(*this); } +py::object PyMlirContext::attachDiagnosticHandler(py::object callback) { + // Note that ownership is transferred to the delete callback below by way of + // an explicit inc_ref (borrow). + PyDiagnosticHandler *pyHandler = + new PyDiagnosticHandler(get(), std::move(callback)); + py::object pyHandlerObject = + py::cast(pyHandler, py::return_value_policy::take_ownership); + pyHandlerObject.inc_ref(); + + // In these C callbacks, the userData is a PyDiagnosticHandler* that is + // guaranteed to be known to pybind. + auto handlerCallback = + +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult { + PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic); + py::object pyDiagnosticObject = + py::cast(pyDiagnostic, py::return_value_policy::take_ownership); + + auto *pyHandler = static_cast(userData); + bool result = false; + { + // Since this can be called from arbitrary C++ contexts, always get the + // gil. + py::gil_scoped_acquire gil; + try { + result = py::cast(pyHandler->callback(pyDiagnostic)); + } catch (std::exception &e) { + fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n", + e.what()); + pyHandler->hadError = true; + } + } + + pyDiagnostic->invalidate(); + return result ? 
mlirLogicalResultSuccess() : mlirLogicalResultFailure(); + }; + auto deleteCallback = +[](void *userData) { + auto *pyHandler = static_cast(userData); + assert(pyHandler->registeredID && "handler is not registered"); + pyHandler->registeredID.reset(); + + // Decrement reference, balancing the inc_ref() above. + py::object pyHandlerObject = + py::cast(pyHandler, py::return_value_policy::reference); + pyHandlerObject.dec_ref(); + }; + + pyHandler->registeredID = mlirContextAttachDiagnosticHandler( + get(), handlerCallback, static_cast(pyHandler), deleteCallback); + return pyHandlerObject; +} + PyMlirContext &DefaultingPyMlirContext::resolve() { PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); if (!context) { @@ -656,6 +707,78 @@ void PyThreadContextEntry::popLocation(PyLocation &location) { stack.pop_back(); } +//------------------------------------------------------------------------------ +// PyDiagnostic* +//------------------------------------------------------------------------------ + +void PyDiagnostic::invalidate() { + valid = false; + if (materializedNotes) { + for (auto ¬eObject : *materializedNotes) { + PyDiagnostic *note = py::cast(noteObject); + note->invalidate(); + } + } +} + +PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context, + py::object callback) + : context(context), callback(std::move(callback)) {} + +PyDiagnosticHandler::~PyDiagnosticHandler() {} + +void PyDiagnosticHandler::detach() { + if (!registeredID) + return; + MlirDiagnosticHandlerID localID = *registeredID; + mlirContextDetachDiagnosticHandler(context, localID); + assert(!registeredID && "should have unregistered"); + // Not strictly necessary but keeps stale pointers from being around to cause + // issues. 
+ context = {nullptr}; +} + +void PyDiagnostic::checkValid() { + if (!valid) { + throw std::invalid_argument( + "Diagnostic is invalid (used outside of callback)"); + } +} + +MlirDiagnosticSeverity PyDiagnostic::getSeverity() { + checkValid(); + return mlirDiagnosticGetSeverity(diagnostic); +} + +PyLocation PyDiagnostic::getLocation() { + checkValid(); + MlirLocation loc = mlirDiagnosticGetLocation(diagnostic); + MlirContext context = mlirLocationGetContext(loc); + return PyLocation(PyMlirContext::forContext(context), loc); +} + +py::str PyDiagnostic::getMessage() { + checkValid(); + py::object fileObject = py::module::import("io").attr("StringIO")(); + PyFileAccumulator accum(fileObject, /*binary=*/false); + mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData()); + return fileObject.attr("getvalue")(); +} + +py::tuple PyDiagnostic::getNotes() { + checkValid(); + if (materializedNotes) + return *materializedNotes; + intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic); + materializedNotes = py::tuple(numNotes); + for (intptr_t i = 0; i < numNotes; ++i) { + MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i); + py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag)); + PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.ptr()); + } + return *materializedNotes; +} + //------------------------------------------------------------------------------ // PyDialect, PyDialectDescriptor, PyDialects //------------------------------------------------------------------------------ @@ -2024,6 +2147,36 @@ class PyOpAttributeMap { //------------------------------------------------------------------------------ void mlir::python::populateIRCore(py::module &m) { + //---------------------------------------------------------------------------- + // Enums. 
+ //---------------------------------------------------------------------------- + py::enum_(m, "DiagnosticSeverity", py::module_local()) + .value("ERROR", MlirDiagnosticError) + .value("WARNING", MlirDiagnosticWarning) + .value("NOTE", MlirDiagnosticNote) + .value("REMARK", MlirDiagnosticRemark); + + //---------------------------------------------------------------------------- + // Mapping of Diagnostics. + //---------------------------------------------------------------------------- + py::class_(m, "Diagnostic", py::module_local()) + .def_property_readonly("severity", &PyDiagnostic::getSeverity) + .def_property_readonly("location", &PyDiagnostic::getLocation) + .def_property_readonly("message", &PyDiagnostic::getMessage) + .def_property_readonly("notes", &PyDiagnostic::getNotes) + .def("__str__", [](PyDiagnostic &self) -> py::str { + if (!self.isValid()) + return ""; + return self.getMessage(); + }); + + py::class_(m, "DiagnosticHandler", py::module_local()) + .def("detach", &PyDiagnosticHandler::detach) + .def_property_readonly("attached", &PyDiagnosticHandler::isAttached) + .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError) + .def("__enter__", &PyDiagnosticHandler::contextEnter) + .def("__exit__", &PyDiagnosticHandler::contextExit); + //---------------------------------------------------------------------------- // Mapping of MlirContext. 
//---------------------------------------------------------------------------- @@ -2079,6 +2232,9 @@ void mlir::python::populateIRCore(py::module &m) { [](PyMlirContext &self, bool value) { mlirContextSetAllowUnregisteredDialects(self.get(), value); }) + .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler, + py::arg("callback"), + "Attaches a diagnostic handler that will receive callbacks") .def( "enable_multithreading", [](PyMlirContext &self, bool enable) { @@ -2204,7 +2360,8 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("context") = py::none(), kContextGetFileLocationDocstring) .def_static( "fused", - [](const std::vector &pyLocations, llvm::Optional metadata, + [](const std::vector &pyLocations, + llvm::Optional metadata, DefaultingPyMlirContext context) { if (pyLocations.empty()) throw py::value_error("No locations provided"); @@ -2236,6 +2393,12 @@ void mlir::python::populateIRCore(py::module &m) { "context", [](PyLocation &self) { return self.getContext().getObject(); }, "Context that owns the Location") + .def( + "emit_error", + [](PyLocation &self, std::string message) { + mlirEmitError(self, message.c_str()); + }, + py::arg("message"), "Emits an error at this location") .def("__repr__", [](PyLocation &self) { PyPrintAccumulator printAccum; mlirLocationPrint(self, printAccum.getCallback(), diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 117435d63..2f354d6d1 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -15,6 +15,7 @@ #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" +#include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" #include "llvm/ADT/DenseMap.h" @@ -24,6 +25,8 @@ namespace mlir { namespace python { class PyBlock; +class PyDiagnostic; +class PyDiagnosticHandler; class PyInsertionPoint; class PyLocation; class DefaultingPyLocation; @@ -207,6 +210,10 @@ class PyMlirContext { const 
pybind11::object &excVal, const pybind11::object &excTb); + /// Attaches a Python callback as a diagnostic handler, returning a + /// registration object (internally a PyDiagnosticHandler). + pybind11::object attachDiagnosticHandler(pybind11::object callback); + private: PyMlirContext(MlirContext context); // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, @@ -267,6 +274,75 @@ class BaseContextObject { PyMlirContextRef contextRef; }; +/// Python class mirroring the C MlirDiagnostic struct. Note that these structs +/// are only valid for the duration of a diagnostic callback and attempting +/// to access them outside of that will raise an exception. This applies to +/// nested diagnostics (in the notes) as well. +class PyDiagnostic { +public: + PyDiagnostic(MlirDiagnostic diagnostic) : diagnostic(diagnostic) {} + void invalidate(); + bool isValid() { return valid; } + MlirDiagnosticSeverity getSeverity(); + PyLocation getLocation(); + pybind11::str getMessage(); + pybind11::tuple getNotes(); + +private: + MlirDiagnostic diagnostic; + + void checkValid(); + /// If notes have been materialized from the diagnostic, then this will + /// be populated with the corresponding objects (all castable to + /// PyDiagnostic). + llvm::Optional materializedNotes; + bool valid = true; +}; + +/// Represents a diagnostic handler attached to the context. The handler's +/// callback will be invoked with PyDiagnostic instances until the detach() +/// method is called or the context is destroyed. A diagnostic handler can be +/// the subject of a `with` block, which will detach it when the block exits. +/// +/// Since diagnostic handlers can call back into Python code which can do +/// unsafe things (i.e. recursively emitting diagnostics, raising exceptions, +/// etc), this is generally not deemed to be a great user-level API. Users +/// should generally use some form of DiagnosticCollector. 
If the handler raises +/// any exceptions, they will just be emitted to stderr and dropped. +/// +/// The unique usage of this class means that its lifetime management is +/// different from most other parts of the API. Instances are always created +/// in an attached state and can transition to a detached state by either: +/// a) The context being destroyed and unregistering all handlers. +/// b) An explicit call to detach(). +/// The object may remain live from a Python perspective for an arbitrary time +/// after detachment, but there is nothing the user can do with it (since there +/// is no way to attach an existing handler object). +class PyDiagnosticHandler { +public: + PyDiagnosticHandler(MlirContext context, pybind11::object callback); + ~PyDiagnosticHandler(); + + bool isAttached() { return registeredID.hasValue(); } + bool getHadError() { return hadError; } + + /// Detaches the handler. Does nothing if not attached. + void detach(); + + pybind11::object contextEnter() { return pybind11::cast(this); } + void contextExit(pybind11::object excType, pybind11::object excVal, + pybind11::object excTb) { + detach(); + } + +private: + MlirContext context; + pybind11::object callback; + llvm::Optional registeredID; + bool hadError = false; + friend class PyMlirContext; +}; + /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in /// order to differentiate it from the `Dialect` base class which is extended by /// plugins which extend dialect functionality through extension python code. diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index e61e34a17..affe54c3e 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -7,7 +7,7 @@ # * Local edits to signatures and types that MyPy did not auto detect (or # detected incorrectly). 
-from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence +from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple from typing import overload @@ -43,6 +43,9 @@ __all__ = [ "Dialect", "DialectDescriptor", "Dialects", + "Diagnostic", + "DiagnosticHandler", + "DiagnosticSeverity", "DictAttr", "F16Type", "F32Type", @@ -425,8 +428,9 @@ class Context: def _get_live_count() -> int: ... def _get_live_module_count(self) -> int: ... def _get_live_operation_count(self) -> int: ... + def attach_diagnostic_handler(self, callback: Callable[["Diagnostic"], bool]) -> "DiagnosticHandler": ... def enable_multithreading(self, enable: bool) -> None: ... - def get_dialect_descriptor(name: dialect_name: str) -> "DialectDescriptor": ... + def get_dialect_descriptor(dialect_name: str) -> "DialectDescriptor": ... def is_registered_operation(self, operation_name: str) -> bool: ... def __enter__(self) -> "Context": ... def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... @@ -479,6 +483,31 @@ class Dialects: def __getattr__(self, arg0: str) -> "Dialect": ... def __getitem__(self, arg0: str) -> "Dialect": ... +class Diagnostic: + @property + def severity(self) -> "DiagnosticSeverity": ... + @property + def location(self) -> "Location": ... + @property + def message(self) -> str: ... + @property + def notes(self) -> Tuple["Diagnostic"]: ... + +class DiagnosticHandler: + def detach(self) -> None: ... + @property + def attached(self) -> bool: ... + @property + def had_error(self) -> bool: ... + def __enter__(self) -> "DiagnosticHandler": ... + def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... + +class DiagnosticSeverity: + ERROR: "DiagnosticSeverity" + WARNING: "DiagnosticSeverity" + NOTE: "DiagnosticSeverity" + REMARK: "DiagnosticSeverity" + # TODO: Auto-generated. Audit and fix. class DictAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... 
From e89ed2c80a2d6969d8a07e1dde9d4704391704e7 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 4 Jan 2022 15:37:33 -0800 Subject: [PATCH 192/915] [mlir] Retain metadata for single loc fusedloc If a fusedloc is created with a single location then no fusedloc was previously created and single location returned instead. In the case where there is a metadata associated with the location this results in discarding the metadata. Instead only canonicalize where there is no loss of information. Differential Revision: https://reviews.llvm.org/D115605 --- mlir/lib/Bindings/Python/IRCore.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 1a7eb46f7..1a9604882 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2363,8 +2363,6 @@ void mlir::python::populateIRCore(py::module &m) { [](const std::vector &pyLocations, llvm::Optional metadata, DefaultingPyMlirContext context) { - if (pyLocations.empty()) - throw py::value_error("No locations provided"); llvm::SmallVector locations; locations.reserve(pyLocations.size()); for (auto &pyLocation : pyLocations) From 8e0c07e1d5549f2e55d399a93600c066e061e0a9 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Wed, 5 Jan 2022 13:06:45 +0100 Subject: [PATCH 193/915] [mlir] Use public PybindAdaptors in Linalg dialect bindings Previously, the Python bindings for the Linalg dialect relied on the internal implementation of core bindings. Most of that functionality was moved, and the remaining one does not need access to the implementation: it used to accept a dialect pointer as argument, but it can always be extracted from the operation that it also accepts; operations are available through PybindAdaptors in an opaque way. Change the bindings in that direction. This enables the decoupling of the Linalg dialect Python extension from the core IR Python extension. 
Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D116649 --- mlir/include/mlir-c/Dialect/Linalg.h | 4 ++-- mlir/lib/Bindings/Python/DialectLinalg.cpp | 13 +++---------- mlir/lib/CAPI/Dialect/Linalg.cpp | 11 +++++------ mlir/python/mlir/dialects/_linalg_ops_ext.py | 3 +-- .../mlir/dialects/linalg/opdsl/lang/emitter.py | 3 +-- 5 files changed, 12 insertions(+), 22 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h index 27f2f7bc8..2fe1872be 100644 --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -18,9 +18,9 @@ extern "C" { #endif /// Apply the special region builder for the builtin named Linalg op. -/// Assert that `op` is a builtin named Linalg op. +/// Assert that `mlirOp` is a builtin named Linalg op. MLIR_CAPI_EXPORTED void -mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, MlirOperation op); +mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp); MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg); diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp index a2a54249e..a16825615 100644 --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -7,24 +7,17 @@ //===----------------------------------------------------------------------===// #include "Dialects.h" -#include "IRModule.h" #include "mlir-c/Dialect/Linalg.h" #include "mlir-c/IR.h" - -// TODO: Port this to operate only on the public PybindAdaptors.h -#include "PybindUtils.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" namespace py = pybind11; -using namespace mlir; -using namespace mlir::python; void mlir::python::populateDialectLinalgSubmodule(py::module m) { m.def( "fill_builtin_region", - [](PyDialectDescriptor &dialect, PyOperation &op) { - mlirLinalgFillBuiltinNamedOpRegion(dialect.get(), op.get()); - }, - py::arg("dialect"), py::arg("op"), + [](MlirOperation op) { 
mlirLinalgFillBuiltinNamedOpRegion(op); }, + py::arg("op"), "Fill the region for `op`, which is assumed to be a builtin named Linalg " "op."); } diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index dd796167b..6c5ba9a88 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -15,20 +15,19 @@ using namespace mlir::linalg; /// Apply the special region builder for the builtin named Linalg op. /// Assert that `op` is a builtin named Linalg op. -void mlirLinalgFillBuiltinNamedOpRegion(MlirDialect linalgDialect, - MlirOperation mlirOp) { +void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) { Operation *op = unwrap(mlirOp); - + auto linalgOp = cast(op); + auto *dialect = static_cast(linalgOp->getDialect()); LinalgDialect::RegionBuilderFunType fun = - static_cast(unwrap(linalgDialect)) - ->getRegionBuilder(op->getName().getStringRef()); + dialect->getRegionBuilder(op->getName().getStringRef()); + assert(fun && "Expected a builtin named Linalg op."); assert(op->getNumRegions() == 1 && "Expected Linalg op with 1 region"); assert(op->getRegion(0).getBlocks().empty() && "Expected Linalg op with 0 blocks"); SmallVector argTypes; - auto linalgOp = cast(op); for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) argTypes.push_back(getElementTypeOrSelf(opOperand->get().getType())); diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py index d6c57547e..90f922724 100644 --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -34,8 +34,7 @@ def __init__(self, output: Value, value: Value, *, loc=None, ip=None): loc=loc, ip=ip) OpView.__init__(self, op) - linalgDialect = Context.current.get_dialect_descriptor("linalg") - fill_builtin_region(linalgDialect, self.operation) + fill_builtin_region(self.operation) class InitTensorOp: """Extends the linalg.init_tensor op.""" diff --git 
a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 933c26ad9..c3cfdfac9 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -173,8 +173,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str, f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}") named_op = getattr(linalg, op_class_name)(ins, outs, result_types) - linalgDialect = ctx.get_dialect_descriptor("linalg") - fill_builtin_region(linalgDialect, named_op.operation) + fill_builtin_region(named_op.operation) # Note: mlir-linalg-ods-yaml-gen.cpp uses a special linalg.memoized_indexing_maps # attribute that the non-yaml path does not. The non-yaml path hardcodes the # indexing_maps in C++ directly. From 474bec25f0840da27c6f7a337f6ebc56664695f3 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 3 Jan 2022 19:01:07 +0100 Subject: [PATCH 194/915] [mlir] Introduce C API for the Quantization dialect types Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D116546 --- mlir/include/mlir-c/Dialect/Quant.h | 199 ++++++++++++++++++++++++ mlir/lib/CAPI/Dialect/CMakeLists.txt | 9 ++ mlir/lib/CAPI/Dialect/Quant.cpp | 224 +++++++++++++++++++++++++++ 3 files changed, 432 insertions(+) create mode 100644 mlir/include/mlir-c/Dialect/Quant.h create mode 100644 mlir/lib/CAPI/Dialect/Quant.cpp diff --git a/mlir/include/mlir-c/Dialect/Quant.h b/mlir/include/mlir-c/Dialect/Quant.h new file mode 100644 index 000000000..c45d93af4 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/Quant.h @@ -0,0 +1,199 @@ +//===-- mlir-c/Dialect/LLVM.h - C API for LLVM --------------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_QUANT_H +#define MLIR_C_DIALECT_QUANT_H + +#include "mlir-c/IR.h" +#include "mlir-c/Registration.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(quant, quant); + +//===---------------------------------------------------------------------===// +// QuantizedType +//===---------------------------------------------------------------------===// + +/// Returns `true` if the given type is a quantization dialect type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAQuantizedType(MlirType type); + +/// Returns the bit flag used to indicate signedness of a quantized type. +MLIR_CAPI_EXPORTED unsigned mlirQuantizedTypeGetSignedFlag(); + +/// Returns the minimum possible value stored by a quantized type. +MLIR_CAPI_EXPORTED int64_t mlirQuantizedTypeGetDefaultMinimumForInteger( + bool isSigned, unsigned integralWidth); + +/// Returns the maximum possible value stored by a quantized type. +MLIR_CAPI_EXPORTED int64_t mlirQuantizedTypeGetDefaultMaximumForInteger( + bool isSigned, unsigned integralWidth); + +/// Gets the original type approximated by the given quantized type. +MLIR_CAPI_EXPORTED MlirType mlirQuantizedTypeGetExpressedType(MlirType type); + +/// Gets the flags associated with the given quantized type. +MLIR_CAPI_EXPORTED unsigned mlirQuantizedTypeGetFlags(MlirType type); + +/// Returns `true` if the given type is signed, `false` otherwise. +MLIR_CAPI_EXPORTED bool mlirQuantizedTypeIsSigned(MlirType type); + +/// Returns the underlying type used to store the values. +MLIR_CAPI_EXPORTED MlirType mlirQuantizedTypeGetStorageType(MlirType type); + +/// Returns the minimum value that the storage type of the given quantized type +/// can take. 
+MLIR_CAPI_EXPORTED int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type); + +/// Returns the maximum value that the storage type of the given quantized type +/// can take. +MLIR_CAPI_EXPORTED int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type); + +/// Returns the integral bitwidth that the storage type of the given quantized +/// type can represent exactly. +MLIR_CAPI_EXPORTED unsigned +mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type); + +/// Returns `true` if the `candidate` type is compatible with the given +/// quantized `type`. +MLIR_CAPI_EXPORTED bool +mlirQuantizedTypeIsCompatibleExpressedType(MlirType type, MlirType candidate); + +/// Returns the element type of the given quantized type as another quantized +/// type. +MLIR_CAPI_EXPORTED MlirType +mlirQuantizedTypeGetQuantizedElementType(MlirType type); + +/// Casts from a type based on the storage type of the given type to a +/// corresponding type based on the given type. Returns a null type if the cast +/// is not valid. +MLIR_CAPI_EXPORTED MlirType +mlirQuantizedTypeCastFromStorageType(MlirType type, MlirType candidate); + +/// Casts from a type based on a quantized type to a corresponding typed based +/// on the storage type. Returns a null type if the cast is not valid. +MLIR_CAPI_EXPORTED MlirType mlirQuantizedTypeCastToStorageType(MlirType type); + +/// Casts from a type based on the expressed type of the given type to a +/// corresponding type based on the given type. Returns a null type if the cast +/// is not valid. +MLIR_CAPI_EXPORTED MlirType +mlirQuantizedTypeCastFromExpressedType(MlirType type, MlirType candidate); + +/// Casts from a type based on a quantized type to a corresponding typed based +/// on the expressed type. Returns a null type if the cast is not valid. 
+MLIR_CAPI_EXPORTED MlirType mlirQuantizedTypeCastToExpressedType(MlirType type); + +/// Casts from a type based on the expressed type of the given quantized type to +/// equivalent type based on storage type of the same quantized type. +MLIR_CAPI_EXPORTED MlirType +mlirQuantizedTypeCastExpressedToStorageType(MlirType type, MlirType candidate); + +//===---------------------------------------------------------------------===// +// AnyQuantizedType +//===---------------------------------------------------------------------===// + +/// Returns `true` if the given type is an AnyQuantizedType. +MLIR_CAPI_EXPORTED bool mlirTypeIsAAnyQuantizedType(MlirType type); + +/// Creates an instance of AnyQuantizedType with the given parameters in the +/// same context as `storageType` and returns it. The instance is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirAnyQuantizedTypeGet(unsigned flags, + MlirType storageType, + MlirType expressedType, + int64_t storageTypeMin, + int64_t storageTypeMax); + +//===---------------------------------------------------------------------===// +// UniformQuantizedType +//===---------------------------------------------------------------------===// + +/// Returns `true` if the given type is a UniformQuantizedType. +MLIR_CAPI_EXPORTED bool mlirTypeIsAUniformQuantizedType(MlirType type); + +/// Creates an instance of UniformQuantizedType with the given parameters in the +/// same context as `storageType` and returns it. The instance is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirUniformQuantizedTypeGet( + unsigned flags, MlirType storageType, MlirType expressedType, double scale, + int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax); + +/// Returns the scale of the given uniform quantized type. +MLIR_CAPI_EXPORTED double mlirUniformQuantizedTypeGetScale(MlirType type); + +/// Returns the zero point of the given uniform quantized type. 
+MLIR_CAPI_EXPORTED int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type); + +/// Returns `true` if the given uniform quantized type is fixed-point. +MLIR_CAPI_EXPORTED bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type); + +//===---------------------------------------------------------------------===// +// UniformQuantizedPerAxisType +//===---------------------------------------------------------------------===// + +/// Returns `true` if the given type is a UniformQuantizedPerAxisType. +MLIR_CAPI_EXPORTED bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type); + +/// Creates an instance of UniformQuantizedPerAxisType with the given parameters +/// in the same context as `storageType` and returns it. `scales` and +/// `zeroPoints` point to `nDims` number of elements. The instance is owned +/// by the context. +MLIR_CAPI_EXPORTED MlirType mlirUniformQuantizedPerAxisTypeGet( + unsigned flags, MlirType storageType, MlirType expressedType, + intptr_t nDims, double *scales, int64_t *zeroPoints, + int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax); + +/// Returns the number of axes in the given quantized per-axis type. +MLIR_CAPI_EXPORTED intptr_t +mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type); + +/// Returns `pos`-th scale of the given quantized per-axis type. +MLIR_CAPI_EXPORTED double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, + intptr_t pos); + +/// Returns `pos`-th zero point of the given quantized per-axis type. +MLIR_CAPI_EXPORTED int64_t +mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type, intptr_t pos); + +/// Returns the index of the quantized dimension in the given quantized per-axis +/// type. +MLIR_CAPI_EXPORTED int32_t +mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type); + +/// Returns `true` if the given uniform quantized per-axis type is fixed-point. 
+MLIR_CAPI_EXPORTED bool +mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type); + +//===---------------------------------------------------------------------===// +// CalibratedQuantizedType +//===---------------------------------------------------------------------===// + +/// Returns `true` if the given type is a CalibratedQuantizedType. +MLIR_CAPI_EXPORTED bool mlirTypeIsACalibratedQuantizedType(MlirType type); + +/// Creates an instance of CalibratedQuantizedType with the given parameters +/// in the same context as `expressedType` and returns it. The instance is owned +/// by the context. +MLIR_CAPI_EXPORTED MlirType +mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min, double max); + +/// Returns the min value of the given calibrated quantized type. +MLIR_CAPI_EXPORTED double mlirCalibratedQuantizedTypeGetMin(MlirType type); + +/// Returns the max value of the given calibrated quantized type. +MLIR_CAPI_EXPORTED double mlirCalibratedQuantizedTypeGetMax(MlirType type); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_QUANT_H diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index 4f11bc52c..1d3e2727a 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -97,3 +97,12 @@ add_mlir_upstream_c_api_library(MLIRCAPITensor MLIRCAPIIR MLIRTensor ) + +add_mlir_upstream_c_api_library(MLIRCAPIQuant + Quant.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRQuant +) diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp new file mode 100644 index 000000000..3fc00a72d --- /dev/null +++ b/mlir/lib/CAPI/Dialect/Quant.cpp @@ -0,0 +1,224 @@ +//===- LLVM.cpp - C Interface for Quant dialect ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Quant.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/Quant/QuantOps.h" +#include "mlir/Dialect/Quant/QuantTypes.h" + +using namespace mlir; + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantizationDialect) + +//===---------------------------------------------------------------------===// +// QuantizedType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAQuantizedType(MlirType type) { + return unwrap(type).isa(); +} + +unsigned mlirQuantizedTypeGetSignedFlag() { + return quant::QuantizationFlags::Signed; +} + +int64_t mlirQuantizedTypeGetDefaultMinimumForInteger(bool isSigned, + unsigned integralWidth) { + return quant::QuantizedType::getDefaultMinimumForInteger(isSigned, + integralWidth); +} + +int64_t mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned, + unsigned integralWidth) { + return quant::QuantizedType::getDefaultMaximumForInteger(isSigned, + integralWidth); +} + +MlirType mlirQuantizedTypeGetExpressedType(MlirType type) { + return wrap(unwrap(type).cast().getExpressedType()); +} + +unsigned mlirQuantizedTypeGetFlags(MlirType type) { + return unwrap(type).cast().getFlags(); +} + +bool mlirQuantizedTypeIsSigned(MlirType type) { + return unwrap(type).cast().isSigned(); +} + +MlirType mlirQuantizedTypeGetStorageType(MlirType type) { + return wrap(unwrap(type).cast().getStorageType()); +} + +int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type) { + return unwrap(type).cast().getStorageTypeMin(); +} + +int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type) { + return unwrap(type).cast().getStorageTypeMax(); +} + +unsigned mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type) { + return unwrap(type) + .cast() + .getStorageTypeIntegralWidth(); +} + +bool 
mlirQuantizedTypeIsCompatibleExpressedType(MlirType type, + MlirType candidate) { + return unwrap(type).cast().isCompatibleExpressedType( + unwrap(candidate)); +} + +MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) { + return wrap(quant::QuantizedType::getQuantizedElementType(unwrap(type))); +} + +MlirType mlirQuantizedTypeCastFromStorageType(MlirType type, + MlirType candidate) { + return wrap(unwrap(type).cast().castFromStorageType( + unwrap(candidate))); +} + +MlirType mlirQuantizedTypeCastToStorageType(MlirType type) { + return wrap(quant::QuantizedType::castToStorageType( + unwrap(type).cast())); +} + +MlirType mlirQuantizedTypeCastFromExpressedType(MlirType type, + MlirType candidate) { + return wrap(unwrap(type).cast().castFromExpressedType( + unwrap(candidate))); +} + +MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) { + return wrap(quant::QuantizedType::castToExpressedType(unwrap(type))); +} + +MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type, + MlirType candidate) { + return wrap( + unwrap(type).cast().castExpressedToStorageType( + unwrap(candidate))); +} + +//===---------------------------------------------------------------------===// +// AnyQuantizedType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAAnyQuantizedType(MlirType type) { + return unwrap(type).isa(); +} + +MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType, + MlirType expressedType, int64_t storageTypeMin, + int64_t storageTypeMax) { + return wrap(quant::AnyQuantizedType::get(flags, unwrap(storageType), + unwrap(expressedType), + storageTypeMin, storageTypeMax)); +} + +//===---------------------------------------------------------------------===// +// UniformQuantizedType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAUniformQuantizedType(MlirType type) { + return unwrap(type).isa(); +} + +MlirType 
mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType, + MlirType expressedType, double scale, + int64_t zeroPoint, int64_t storageTypeMin, + int64_t storageTypeMax) { + return wrap(quant::UniformQuantizedType::get( + flags, unwrap(storageType), unwrap(expressedType), scale, zeroPoint, + storageTypeMin, storageTypeMax)); +} + +double mlirUniformQuantizedTypeGetScale(MlirType type) { + return unwrap(type).cast().getScale(); +} + +int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type) { + return unwrap(type).cast().getZeroPoint(); +} + +bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) { + return unwrap(type).cast().isFixedPoint(); +} + +//===---------------------------------------------------------------------===// +// UniformQuantizedPerAxisType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) { + return unwrap(type).isa(); +} + +MlirType mlirUniformQuantizedPerAxisTypeGet( + unsigned flags, MlirType storageType, MlirType expressedType, + intptr_t nDims, double *scales, int64_t *zeroPoints, + int32_t quantizedDimension, int64_t storageTypeMin, + int64_t storageTypeMax) { + return wrap(quant::UniformQuantizedPerAxisType::get( + flags, unwrap(storageType), unwrap(expressedType), + llvm::makeArrayRef(scales, nDims), llvm::makeArrayRef(zeroPoints, nDims), + quantizedDimension, storageTypeMin, storageTypeMax)); +} + +intptr_t mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type) { + return unwrap(type) + .cast() + .getScales() + .size(); +} + +double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, intptr_t pos) { + return unwrap(type) + .cast() + .getScales()[pos]; +} + +int64_t mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type, + intptr_t pos) { + return unwrap(type) + .cast() + .getZeroPoints()[pos]; +} + +int32_t mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type) { + return unwrap(type) + .cast() + 
.getQuantizedDimension(); +} + +bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) { + return unwrap(type).cast().isFixedPoint(); +} + +//===---------------------------------------------------------------------===// +// CalibratedQuantizedType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsACalibratedQuantizedType(MlirType type) { + return unwrap(type).isa(); +} + +MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min, + double max) { + return wrap( + quant::CalibratedQuantizedType::get(unwrap(expressedType), min, max)); +} + +double mlirCalibratedQuantizedTypeGetMin(MlirType type) { + return unwrap(type).cast().getMin(); +} + +double mlirCalibratedQuantizedTypeGetMax(MlirType type) { + return unwrap(type).cast().getMax(); +} From 3b9ca44c5635acf780f77f18e07fa2cb7c48eec3 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Tue, 4 Jan 2022 17:43:44 +0100 Subject: [PATCH 195/915] [mlir] Introduce Python bindings for the quantization dialect So far, only the custom dialect types are exposed. The build and packaging is same as for Linalg and SparseTensor, and in need of refactoring that is beyond the scope of this patch. 
Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D116605 --- .../mlir/Bindings/Python/PybindAdaptors.h | 2 + mlir/lib/Bindings/Python/DialectQuant.cpp | 307 ++++++++++++++++++ mlir/lib/Bindings/Python/Dialects.h | 2 + mlir/lib/Bindings/Python/MainModule.cpp | 2 + mlir/python/CMakeLists.txt | 11 + .../mlir/_mlir_libs/_mlir/dialects/quant.pyi | 123 +++++++ mlir/python/mlir/dialects/quant.py | 5 + 7 files changed, 452 insertions(+) create mode 100644 mlir/lib/Bindings/Python/DialectQuant.cpp create mode 100644 mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi create mode 100644 mlir/python/mlir/dialects/quant.py diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 61b982193..811e54ab4 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -309,6 +309,8 @@ class pure_subclass { return *this; } + py::object get_class() const { return thisClass; } + protected: py::object superClass; py::object thisClass; diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp new file mode 100644 index 000000000..f2fad706a --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectQuant.cpp @@ -0,0 +1,307 @@ +//===- DialectQuant.cpp - 'quant' dialect submodule -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Dialects.h" +#include "mlir-c/Dialect/Quant.h" +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" + +namespace py = pybind11; +using namespace llvm; +using namespace mlir; +using namespace mlir::python::adaptors; + +void mlir::python::populateDialectQuantSubmodule(const py::module &m, + const py::module &irModule) { + auto typeClass = irModule.attr("Type"); + + //===-------------------------------------------------------------------===// + // QuantizedType + //===-------------------------------------------------------------------===// + + auto quantizedType = mlir_type_subclass(m, "QuantizedType", + mlirTypeIsAQuantizedType, typeClass); + quantizedType.def_staticmethod( + "default_minimum_for_integer", + [](bool isSigned, unsigned integralWidth) { + return mlirQuantizedTypeGetDefaultMinimumForInteger(isSigned, + integralWidth); + }, + "Default minimum value for the integer with the specified signedness and " + "bit width.", + py::arg("is_signed"), py::arg("integral_width")); + quantizedType.def_staticmethod( + "default_maximum_for_integer", + [](bool isSigned, unsigned integralWidth) { + return mlirQuantizedTypeGetDefaultMaximumForInteger(isSigned, + integralWidth); + }, + "Default maximum value for the integer with the specified signedness and " + "bit width.", + py::arg("is_signed"), py::arg("integral_width")); + quantizedType.def_property_readonly( + "expressed_type", + [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); }, + "Type expressed by this quantized type."); + quantizedType.def_property_readonly( + "flags", [](MlirType type) { return mlirQuantizedTypeGetFlags(type); }, + "Flags of this quantized type (named accessors should be preferred to " + "this)"); + quantizedType.def_property_readonly( + "is_signed", + [](MlirType type) { return 
mlirQuantizedTypeIsSigned(type); }, + "Signedness of this quantized type."); + quantizedType.def_property_readonly( + "storage_type", + [](MlirType type) { return mlirQuantizedTypeGetStorageType(type); }, + "Storage type backing this quantized type."); + quantizedType.def_property_readonly( + "storage_type_min", + [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMin(type); }, + "The minimum value held by the storage type of this quantized type."); + quantizedType.def_property_readonly( + "storage_type_max", + [](MlirType type) { return mlirQuantizedTypeGetStorageTypeMax(type); }, + "The maximum value held by the storage type of this quantized type."); + quantizedType.def_property_readonly( + "storage_type_integral_width", + [](MlirType type) { + return mlirQuantizedTypeGetStorageTypeIntegralWidth(type); + }, + "The bitwidth of the storage type of this quantized type."); + quantizedType.def( + "is_compatible_expressed_type", + [](MlirType type, MlirType candidate) { + return mlirQuantizedTypeIsCompatibleExpressedType(type, candidate); + }, + "Checks whether the candidate type can be expressed by this quantized " + "type.", + py::arg("candidate")); + quantizedType.def_property_readonly( + "quantized_element_type", + [](MlirType type) { + return mlirQuantizedTypeGetQuantizedElementType(type); + }, + "Element type of this quantized type expressed as quantized type."); + quantizedType.def( + "cast_from_storage_type", + [](MlirType type, MlirType candidate) { + MlirType castResult = + mlirQuantizedTypeCastFromStorageType(type, candidate); + if (!mlirTypeIsNull(castResult)) + return castResult; + throw py::type_error("Invalid cast."); + }, + "Casts from a type based on the storage type of this quantized type to a " + "corresponding type based on the quantized type. 
Raises TypeError if the " + "cast is not valid.", + py::arg("candidate")); + quantizedType.def_staticmethod( + "cast_to_storage_type", + [](MlirType type) { + MlirType castResult = mlirQuantizedTypeCastToStorageType(type); + if (!mlirTypeIsNull(castResult)) + return castResult; + throw py::type_error("Invalid cast."); + }, + "Casts from a type based on a quantized type to a corresponding type " + "based on the storage type of this quantized type. Raises TypeError if " + "the cast is not valid.", + py::arg("type")); + quantizedType.def( + "cast_from_expressed_type", + [](MlirType type, MlirType candidate) { + MlirType castResult = + mlirQuantizedTypeCastFromExpressedType(type, candidate); + if (!mlirTypeIsNull(castResult)) + return castResult; + throw py::type_error("Invalid cast."); + }, + "Casts from a type based on the expressed type of this quantized type to " + "a corresponding type based on the quantized type. Raises TypeError if " + "the cast is not valid.", + py::arg("candidate")); + quantizedType.def_staticmethod( + "cast_to_expressed_type", + [](MlirType type) { + MlirType castResult = mlirQuantizedTypeCastToExpressedType(type); + if (!mlirTypeIsNull(castResult)) + return castResult; + throw py::type_error("Invalid cast."); + }, + "Casts from a type based on a quantized type to a corresponding type " + "based on the expressed type of this quantized type. Raises TypeError if " + "the cast is not valid.", + py::arg("type")); + quantizedType.def( + "cast_expressed_to_storage_type", + [](MlirType type, MlirType candidate) { + MlirType castResult = + mlirQuantizedTypeCastExpressedToStorageType(type, candidate); + if (!mlirTypeIsNull(castResult)) + return castResult; + throw py::type_error("Invalid cast."); + }, + "Casts from a type based on the expressed type of this quantized type to " + "a corresponding type based on the storage type. 
Raises TypeError if the " + "cast is not valid.", + py::arg("candidate")); + + quantizedType.get_class().attr("FLAG_SIGNED") = + mlirQuantizedTypeGetSignedFlag(); + + //===-------------------------------------------------------------------===// + // AnyQuantizedType + //===-------------------------------------------------------------------===// + + auto anyQuantizedType = + mlir_type_subclass(m, "AnyQuantizedType", mlirTypeIsAAnyQuantizedType, + quantizedType.get_class()); + anyQuantizedType.def_classmethod( + "get", + [](py::object cls, unsigned flags, MlirType storageType, + MlirType expressedType, int64_t storageTypeMin, + int64_t storageTypeMax) { + return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType, + storageTypeMin, storageTypeMax)); + }, + "Gets an instance of AnyQuantizedType in the same context as the " + "provided storage type.", + py::arg("cls"), py::arg("flags"), py::arg("storage_type"), + py::arg("expressed_type"), py::arg("storage_type_min"), + py::arg("storage_type_max")); + + //===-------------------------------------------------------------------===// + // UniformQuantizedType + //===-------------------------------------------------------------------===// + + auto uniformQuantizedType = mlir_type_subclass( + m, "UniformQuantizedType", mlirTypeIsAUniformQuantizedType, + quantizedType.get_class()); + uniformQuantizedType.def_classmethod( + "get", + [](py::object cls, unsigned flags, MlirType storageType, + MlirType expressedType, double scale, int64_t zeroPoint, + int64_t storageTypeMin, int64_t storageTypeMax) { + return cls(mlirUniformQuantizedTypeGet(flags, storageType, + expressedType, scale, zeroPoint, + storageTypeMin, storageTypeMax)); + }, + "Gets an instance of UniformQuantizedType in the same context as the " + "provided storage type.", + py::arg("cls"), py::arg("flags"), py::arg("storage_type"), + py::arg("expressed_type"), py::arg("scale"), py::arg("zero_point"), + py::arg("storage_type_min"), 
py::arg("storage_type_max")); + uniformQuantizedType.def_property_readonly( + "scale", + [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); }, + "The scale designates the difference between the real values " + "corresponding to consecutive quantized values differing by 1."); + uniformQuantizedType.def_property_readonly( + "zero_point", + [](MlirType type) { return mlirUniformQuantizedTypeGetZeroPoint(type); }, + "The storage value corresponding to the real value 0 in the affine " + "equation."); + uniformQuantizedType.def_property_readonly( + "is_fixed_point", + [](MlirType type) { return mlirUniformQuantizedTypeIsFixedPoint(type); }, + "Fixed point values are real numbers divided by a scale."); + + //===-------------------------------------------------------------------===// + // UniformQuantizedPerAxisType + //===-------------------------------------------------------------------===// + auto uniformQuantizedPerAxisType = mlir_type_subclass( + m, "UniformQuantizedPerAxisType", mlirTypeIsAUniformQuantizedPerAxisType, + quantizedType.get_class()); + uniformQuantizedPerAxisType.def_classmethod( + "get", + [](py::object cls, unsigned flags, MlirType storageType, + MlirType expressedType, std::vector scales, + std::vector zeroPoints, int32_t quantizedDimension, + int64_t storageTypeMin, int64_t storageTypeMax) { + if (scales.size() != zeroPoints.size()) + throw py::value_error( + "Mismatching number of scales and zero points."); + auto nDims = static_cast(scales.size()); + return cls(mlirUniformQuantizedPerAxisTypeGet( + flags, storageType, expressedType, nDims, scales.data(), + zeroPoints.data(), quantizedDimension, storageTypeMin, + storageTypeMax)); + }, + "Gets an instance of UniformQuantizedPerAxisType in the same context as " + "the provided storage type.", + py::arg("cls"), py::arg("flags"), py::arg("storage_type"), + py::arg("expressed_type"), py::arg("scales"), py::arg("zero_points"), + py::arg("quantized_dimension"), 
py::arg("storage_type_min"), + py::arg("storage_type_max")); + uniformQuantizedPerAxisType.def_property_readonly( + "scales", + [](MlirType type) { + intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); + std::vector scales; + scales.reserve(nDim); + for (intptr_t i = 0; i < nDim; ++i) { + double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i); + scales.push_back(scale); + } + }, + "The scales designate the difference between the real values " + "corresponding to consecutive quantized values differing by 1. The ith " + "scale corresponds to the ith slice in the quantized_dimension."); + uniformQuantizedPerAxisType.def_property_readonly( + "zero_points", + [](MlirType type) { + intptr_t nDim = mlirUniformQuantizedPerAxisTypeGetNumDims(type); + std::vector zeroPoints; + zeroPoints.reserve(nDim); + for (intptr_t i = 0; i < nDim; ++i) { + int64_t zeroPoint = + mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i); + zeroPoints.push_back(zeroPoint); + } + }, + "the storage values corresponding to the real value 0 in the affine " + "equation. 
The ith zero point corresponds to the ith slice in the " + "quantized_dimension."); + uniformQuantizedPerAxisType.def_property_readonly( + "quantized_dimension", + [](MlirType type) { + return mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(type); + }, + "Specifies the dimension of the shape that the scales and zero points " + "correspond to."); + uniformQuantizedPerAxisType.def_property_readonly( + "is_fixed_point", + [](MlirType type) { + return mlirUniformQuantizedPerAxisTypeIsFixedPoint(type); + }, + "Fixed point values are real numbers divided by a scale."); + + //===-------------------------------------------------------------------===// + // CalibratedQuantizedType + //===-------------------------------------------------------------------===// + + auto calibratedQuantizedType = mlir_type_subclass( + m, "CalibratedQuantizedType", mlirTypeIsACalibratedQuantizedType, + quantizedType.get_class()); + calibratedQuantizedType.def_classmethod( + "get", + [](py::object cls, MlirType expressedType, double min, double max) { + return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max)); + }, + "Gets an instance of CalibratedQuantizedType in the same context as the " + "provided expressed type.", + py::arg("cls"), py::arg("expressed_type"), py::arg("min"), + py::arg("max")); + calibratedQuantizedType.def_property_readonly("min", [](MlirType type) { + return mlirCalibratedQuantizedTypeGetMin(type); + }); + calibratedQuantizedType.def_property_readonly("max", [](MlirType type) { + return mlirCalibratedQuantizedTypeGetMax(type); + }); +} diff --git a/mlir/lib/Bindings/Python/Dialects.h b/mlir/lib/Bindings/Python/Dialects.h index c1725074c..a130903c6 100644 --- a/mlir/lib/Bindings/Python/Dialects.h +++ b/mlir/lib/Bindings/Python/Dialects.h @@ -17,6 +17,8 @@ namespace python { void populateDialectLinalgSubmodule(pybind11::module m); void populateDialectSparseTensorSubmodule(const pybind11::module &m, const pybind11::module &irModule); +void 
populateDialectQuantSubmodule(const pybind11::module &m, + const pybind11::module &irModule); } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 896ee432e..f55482676 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -107,4 +107,6 @@ PYBIND11_MODULE(_mlir, m) { populateDialectLinalgSubmodule(linalgModule); populateDialectSparseTensorSubmodule( dialectsModule.def_submodule("sparse_tensor"), irModule); + populateDialectQuantSubmodule(dialectsModule.def_submodule("quant"), + irModule); } diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index f7b84b033..1fb98c540 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -25,6 +25,8 @@ declare_mlir_python_sources(MLIRPythonSources.Core _mlir_libs/_mlir/__init__.pyi _mlir_libs/_mlir/ir.pyi _mlir_libs/_mlir/passmanager.pyi + # TODO: this should be split out into a separate library. + _mlir_libs/_mlir/dialects/quant.pyi ) declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine @@ -115,6 +117,13 @@ declare_mlir_dialect_python_bindings( dialects/_memref_ops_ext.py DIALECT_NAME memref) +declare_mlir_python_sources( + MLIRPythonSources.Dialects.quant + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + SOURCES + dialects/quant.py) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" @@ -184,6 +193,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core SOURCES DialectLinalg.cpp # TODO: Break this out. DialectSparseTensor.cpp # TODO: Break this out. + DialectQuant.cpp # TODO: Break this out. MainModule.cpp IRAffine.cpp IRAttributes.cpp @@ -212,6 +222,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core MLIRCAPILinalg # TODO: Remove when above is removed. MLIRCAPISparseTensor # TODO: Remove when above is removed. 
MLIRCAPIStandard + MLIRCAPIQuant # TODO: Remove when above is removed. ) declare_mlir_python_extension(MLIRPythonExtension.AllPassesRegistration diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi new file mode 100644 index 000000000..c9c66d52b --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi @@ -0,0 +1,123 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import List + +from mlir.ir import Type + +__all__ = [ + "QuantizedType", + "AnyQuantizedType", + "UniformQuantizedType", + "UniformQuantizedPerAxisType", + "CalibratedQuantizedType", +] + +class QuantizedType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def default_minimum_for_integer(is_signed: bool, integral_width: int) -> int: + ... + + @staticmethod + def default_maximum_for_integer(is_signed: bool, integral_width: int) -> int: + ... + + @property + def expressed_type(self) -> Type: ... + + @property + def flags(self) -> int: ... + + @property + def is_signed(self) -> bool: ... + + @property + def storage_type(self) -> Type: ... + + @property + def storage_type_min(self) -> int: ... + + @property + def storage_type_max(self) -> int: ... + + @property + def storage_type_integral_width(self) -> int: ... + + def is_compatible_expressed_type(self, candidate: Type) -> bool: ... + + @property + def quantized_element_type(self) -> Type: ... + + def cast_from_storage_type(self, candidate: Type) -> Type: ... + + @staticmethod + def cast_to_storage_type(type: Type) -> Type: ... + + def cast_from_expressed_type(self, candidate: Type) -> Type: ... + + @staticmethod + def cast_to_expressed_type(type: Type) -> Type: ... + + def cast_expressed_to_storage_type(self, candidate: Type) -> Type: ... 
+ + +class AnyQuantizedType(QuantizedType): + + @classmethod + def get(cls, flags: int, storage_type: Type, expressed_type: Type, + storage_type_min: int, storage_type_max: int) -> Type: + ... + + +class UniformQuantizedType(QuantizedType): + + @classmethod + def get(cls, flags: int, storage_type: Type, expressed_type: Type, + scale: float, zero_point: int, storage_type_min: int, + storage_type_max: int) -> Type: ... + + @property + def scale(self) -> float: ... + + @property + def zero_point(self) -> int: ... + + @property + def is_fixed_point(self) -> bool: ... + + +class UniformQuantizedPerAxisType(QuantizedType): + + @classmethod + def get(cls, flags: int, storage_type: Type, expressed_type: Type, + scales: List[float], zero_points: List[int], quantized_dimension: int, + storage_type_min: int, storage_type_max: int): + ... + + @property + def scales(self) -> List[float]: ... + + @property + def zero_points(self) -> List[int]: ... + + @property + def quantized_dimension(self) -> int: ... + + @property + def is_fixed_point(self) -> bool: ... + + +class CalibratedQuantizedType(QuantizedType): + + @classmethod + def get(cls, expressed_type: Type, min: float, max: float): ... + + @property + def min(self) -> float: ... + + @property + def max(self) -> float: ... \ No newline at end of file diff --git a/mlir/python/mlir/dialects/quant.py b/mlir/python/mlir/dialects/quant.py new file mode 100644 index 000000000..92990b1c5 --- /dev/null +++ b/mlir/python/mlir/dialects/quant.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._mlir_libs._mlir.dialects.quant import * From dc9d6e5f97ff3a24a7de8c4cc5f4078462354952 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Wed, 5 Jan 2022 11:21:21 +0100 Subject: [PATCH 196/915] [mlir] Split out Python bindings for dialects into separate libs Historically, the bindings for the Linalg dialect were included into the "core" bindings library because they depended on the C++ implementation of the "core" bindings. The other dialects followed the pattern. Now that this dependency is gone, split out each dialect into a separate Python extension library. Depends On D116649, D116605 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D116662 --- mlir/lib/Bindings/Python/DialectLinalg.cpp | 9 +++- mlir/lib/Bindings/Python/DialectQuant.cpp | 16 +++--- .../Bindings/Python/DialectSparseTensor.cpp | 14 ++--- mlir/lib/Bindings/Python/Dialects.h | 26 ---------- mlir/lib/Bindings/Python/MainModule.cpp | 10 ---- mlir/python/CMakeLists.txt | 51 +++++++++++++++---- mlir/python/mlir/dialects/_linalg_ops_ext.py | 2 +- mlir/python/mlir/dialects/linalg/__init__.py | 25 +++++---- .../dialects/linalg/opdsl/lang/emitter.py | 3 +- mlir/python/mlir/dialects/quant.py | 2 +- mlir/python/mlir/dialects/sparse_tensor.py | 2 +- 11 files changed, 82 insertions(+), 78 deletions(-) delete mode 100644 mlir/lib/Bindings/Python/Dialects.h diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp index a16825615..2e54ebeb6 100644 --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -6,14 +6,13 @@ // //===----------------------------------------------------------------------===// -#include "Dialects.h" #include "mlir-c/Dialect/Linalg.h" #include "mlir-c/IR.h" #include "mlir/Bindings/Python/PybindAdaptors.h" namespace py = pybind11; -void mlir::python::populateDialectLinalgSubmodule(py::module m) { +static void 
populateDialectLinalgSubmodule(py::module m) { m.def( "fill_builtin_region", [](MlirOperation op) { mlirLinalgFillBuiltinNamedOpRegion(op); }, @@ -21,3 +20,9 @@ void mlir::python::populateDialectLinalgSubmodule(py::module m) { "Fill the region for `op`, which is assumed to be a builtin named Linalg " "op."); } + +PYBIND11_MODULE(_mlirDialectsLinalg, m) { + m.doc() = "MLIR Linalg dialect."; + + populateDialectLinalgSubmodule(m); +} diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp index f2fad706a..844cbec4e 100644 --- a/mlir/lib/Bindings/Python/DialectQuant.cpp +++ b/mlir/lib/Bindings/Python/DialectQuant.cpp @@ -6,7 +6,6 @@ // //===----------------------------------------------------------------------===// -#include "Dialects.h" #include "mlir-c/Dialect/Quant.h" #include "mlir-c/IR.h" #include "mlir/Bindings/Python/PybindAdaptors.h" @@ -16,16 +15,13 @@ using namespace llvm; using namespace mlir; using namespace mlir::python::adaptors; -void mlir::python::populateDialectQuantSubmodule(const py::module &m, - const py::module &irModule) { - auto typeClass = irModule.attr("Type"); - +static void populateDialectQuantSubmodule(const py::module &m) { //===-------------------------------------------------------------------===// // QuantizedType //===-------------------------------------------------------------------===// - auto quantizedType = mlir_type_subclass(m, "QuantizedType", - mlirTypeIsAQuantizedType, typeClass); + auto quantizedType = + mlir_type_subclass(m, "QuantizedType", mlirTypeIsAQuantizedType); quantizedType.def_staticmethod( "default_minimum_for_integer", [](bool isSigned, unsigned integralWidth) { @@ -305,3 +301,9 @@ void mlir::python::populateDialectQuantSubmodule(const py::module &m, return mlirCalibratedQuantizedTypeGetMax(type); }); } + +PYBIND11_MODULE(_mlirDialectsQuant, m) { + m.doc() = "MLIR Quantization dialect"; + + populateDialectQuantSubmodule(m); +} \ No newline at end of file diff --git 
a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index c9e3cb639..b24d024d1 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -6,7 +6,6 @@ // //===----------------------------------------------------------------------===// -#include "Dialects.h" #include "mlir-c/Dialect/SparseTensor.h" #include "mlir-c/IR.h" #include "mlir/Bindings/Python/PybindAdaptors.h" @@ -16,18 +15,14 @@ using namespace llvm; using namespace mlir; using namespace mlir::python::adaptors; -void mlir::python::populateDialectSparseTensorSubmodule( - const py::module &m, const py::module &irModule) { - auto attributeClass = irModule.attr("Attribute"); - +static void populateDialectSparseTensorSubmodule(const py::module &m) { py::enum_(m, "DimLevelType", py::module_local()) .value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON); mlir_attribute_subclass(m, "EncodingAttr", - mlirAttributeIsASparseTensorEncodingAttr, - attributeClass) + mlirAttributeIsASparseTensorEncodingAttr) .def_classmethod( "get", [](py::object cls, @@ -72,3 +67,8 @@ void mlir::python::populateDialectSparseTensorSubmodule( return mlirSparseTensorEncodingAttrGetIndexBitWidth(self); }); } + +PYBIND11_MODULE(_mlirDialectsSparseTensor, m) { + m.doc() = "MLIR SparseTensor dialect."; + populateDialectSparseTensorSubmodule(m); +} diff --git a/mlir/lib/Bindings/Python/Dialects.h b/mlir/lib/Bindings/Python/Dialects.h deleted file mode 100644 index a130903c6..000000000 --- a/mlir/lib/Bindings/Python/Dialects.h +++ /dev/null @@ -1,26 +0,0 @@ -//===- Dialects.h - Declaration for dialect submodule factories -----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_BINDINGS_PYTHON_DIALECTS_H -#define MLIR_BINDINGS_PYTHON_DIALECTS_H - -#include - -namespace mlir { -namespace python { - -void populateDialectLinalgSubmodule(pybind11::module m); -void populateDialectSparseTensorSubmodule(const pybind11::module &m, - const pybind11::module &irModule); -void populateDialectQuantSubmodule(const pybind11::module &m, - const pybind11::module &irModule); - -} // namespace python -} // namespace mlir - -#endif // MLIR_BINDINGS_PYTHON_DIALECTS_H diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index f55482676..1d6d8fa01 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -10,7 +10,6 @@ #include "PybindUtils.h" -#include "Dialects.h" #include "Globals.h" #include "IRModule.h" #include "Pass.h" @@ -100,13 +99,4 @@ PYBIND11_MODULE(_mlir, m) { auto passModule = m.def_submodule("passmanager", "MLIR Pass Management Bindings"); populatePassManagerSubmodule(passModule); - - // Define and populate dialect submodules. - auto dialectsModule = m.def_submodule("dialects"); - auto linalgModule = dialectsModule.def_submodule("linalg"); - populateDialectLinalgSubmodule(linalgModule); - populateDialectSparseTensorSubmodule( - dialectsModule.def_submodule("sparse_tensor"), irModule); - populateDialectQuantSubmodule(dialectsModule.def_submodule("quant"), - irModule); } diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 1fb98c540..60d60d4af 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -25,8 +25,6 @@ declare_mlir_python_sources(MLIRPythonSources.Core _mlir_libs/_mlir/__init__.pyi _mlir_libs/_mlir/ir.pyi _mlir_libs/_mlir/passmanager.pyi - # TODO: this should be split out into a separate library. 
- _mlir_libs/_mlir/dialects/quant.pyi ) declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine @@ -122,7 +120,8 @@ declare_mlir_python_sources( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" SOURCES - dialects/quant.py) + dialects/quant.py + _mlir_libs/_mlir/dialects/quant.pyi) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -191,9 +190,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Core ADD_TO_PARENT MLIRPythonSources.Core ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - DialectLinalg.cpp # TODO: Break this out. - DialectSparseTensor.cpp # TODO: Break this out. - DialectQuant.cpp # TODO: Break this out. MainModule.cpp IRAffine.cpp IRAttributes.cpp @@ -205,7 +201,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Core Pass.cpp # Headers must be included explicitly so they are installed. - Dialects.h Globals.h IRModule.h Pass.h @@ -219,10 +214,46 @@ declare_mlir_python_extension(MLIRPythonExtension.Core MLIRCAPIRegistration # TODO: See about dis-aggregating # Dialects - MLIRCAPILinalg # TODO: Remove when above is removed. - MLIRCAPISparseTensor # TODO: Remove when above is removed. MLIRCAPIStandard - MLIRCAPIQuant # TODO: Remove when above is removed. 
+) + +declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind + MODULE_NAME _mlirDialectsLinalg + ADD_TO_PARENT MLIRPythonSources.Dialects.linalg + ROOT_DIR "${PYTHON_SOURCE_DIR}" + SOURCES + DialectLinalg.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPILinalg +) + +declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind + MODULE_NAME _mlirDialectsQuant + ADD_TO_PARENT MLIRPythonSources.Dialects.quant + ROOT_DIR "${PYTHON_SOURCE_DIR}" + SOURCES + DialectQuant.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPIQuant +) + +declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind + MODULE_NAME _mlirDialectsSparseTensor + ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor + ROOT_DIR "${PYTHON_SOURCE_DIR}" + SOURCES + DialectSparseTensor.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPISparseTensor ) declare_mlir_python_extension(MLIRPythonExtension.AllPassesRegistration diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py index 90f922724..167a9232d 100644 --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -6,7 +6,7 @@ from typing import Optional, Sequence, Union from ..ir import * from ._ods_common import get_default_loc_context - from .._mlir_libs._mlir.dialects.linalg import fill_builtin_region + from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index 976718337..eadb8420c 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -2,6 +2,9 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Re-export the objects provided by pybind. +from ..._mlir_libs._mlirDialectsLinalg import * + # These are the backing OpView classes generated from the linalg tablegen # definitions following these steps: # DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py. @@ -15,39 +18,39 @@ # C=TensorDef(U, S.M, S.N, output=True)): # ``` # using the linalg-py eDSL. -# The linalg-py eDSL builds a python representation (PyRepr) that is +# The linalg-py eDSL builds a python representation (PyRepr) that is # used in following ways: # 1. PyRepr -> YAML to generate the C++ and Python .td files. These # then turn into the core C++ Op classes and Python OpView classes -# respectively (made available in _linalg_ops_gen). The generic OpView class +# respectively (made available in _linalg_ops_gen). The generic OpView class # mechanism makes the C++ classes available to python through the CAPI. # PyRepr -> YAML currently occurs before compiler compile time. # The other steps in this category occur at compiler compile time. -# 2. PyRepr -> linalg.core_named_ops calls: piggybacks on the +# 2. PyRepr -> linalg.core_named_ops calls: piggybacks on the # _linalg_ops_gen classes and the OpView mechanism to build IR at # runtime in python: # a. by default, the Named Op Form is emitted, e.g.: # `linalg.matmul(lhs, rhs, outs=[out])` creates the following IR: # ``` -# %1 = linalg.matmul ins(%arg0, %arg1 : tensor<4x16xf32>, tensor<16x8xf32>) +# %1 = linalg.matmul ins(%arg0, %arg1 : tensor<4x16xf32>, tensor<16x8xf32>) # outs(%0 : tensor<4x8xf32>) -# -> tensor<4x8xf32> +# -> tensor<4x8xf32> # ``` # b. 
by setting emit_generic=True, the Generic Op Form is emitted, e.g.: # `linalg.matmul(lhs, rhs, outs=[out], emit_generic=True)` creates the following IR: # ``` -# %1 = linalg.generic {indexing_maps = [...], iterator_types = [...]} -# ins(%arg0, %arg1 : tensor<4x16xf32>, tensor<16x8xf32>) +# %1 = linalg.generic {indexing_maps = [...], iterator_types = [...]} +# ins(%arg0, %arg1 : tensor<4x16xf32>, tensor<16x8xf32>) # outs(%0 : tensor<4x8xf32>) { -# ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): +# ^bb0(%arg2: f32, %arg3: f32, %arg4: f32): # ... # linalg.yield %3 : f32 -# } -> tensor<4x8xf32> +# } -> tensor<4x8xf32> # ``` # 3. PyRepr -> Runtime Custom Op definitions: directly generates a # linalg.generic form like in 2.b. -# !!!WARNING!!!: if one creates a runtime custom op with the same name +# !!!WARNING!!!: if one creates a runtime custom op with the same name # as an existing core named op, step 2. will likely take precedence. -# TODO: guard against surprises and fail create Runtime Custom Ops with +# TODO: guard against surprises and fail create Runtime Custom Ops with # the same name as existing Core Named Ops. from .opdsl.ops.core_named_ops import * diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index c3cfdfac9..aa44194b5 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -5,7 +5,6 @@ from typing import Dict, List, Sequence, Tuple, Union from .....ir import * -from ....._mlir_libs._mlir.dialects.linalg import fill_builtin_region from .... import linalg from .... 
import std @@ -173,7 +172,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str, f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}") named_op = getattr(linalg, op_class_name)(ins, outs, result_types) - fill_builtin_region(named_op.operation) + linalg.fill_builtin_region(named_op.operation) # Note: mlir-linalg-ods-yaml-gen.cpp uses a special linalg.memoized_indexing_maps # attribute that the non-yaml path does not. The non-yaml path hardcodes the # indexing_maps in C++ directly. diff --git a/mlir/python/mlir/dialects/quant.py b/mlir/python/mlir/dialects/quant.py index 92990b1c5..bf1fc5f2d 100644 --- a/mlir/python/mlir/dialects/quant.py +++ b/mlir/python/mlir/dialects/quant.py @@ -2,4 +2,4 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from .._mlir_libs._mlir.dialects.quant import * +from .._mlir_libs._mlirDialectsQuant import * diff --git a/mlir/python/mlir/dialects/sparse_tensor.py b/mlir/python/mlir/dialects/sparse_tensor.py index 4f6b675ec..769418e04 100644 --- a/mlir/python/mlir/dialects/sparse_tensor.py +++ b/mlir/python/mlir/dialects/sparse_tensor.py @@ -3,5 +3,5 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._sparse_tensor_ops_gen import * -from .._mlir_libs._mlir.dialects.sparse_tensor import * +from .._mlir_libs._mlirDialectsSparseTensor import * from .._mlir_libs import _mlirSparseTensorPasses as _cextSparseTensorPasses From b2612265999f6117e98174998c54770915b1c413 Mon Sep 17 00:00:00 2001 From: Kazu Hirata Date: Thu, 6 Jan 2022 23:44:02 -0800 Subject: [PATCH 197/915] Ensure newlines at the end of files (NFC) --- mlir/lib/Bindings/Python/DialectQuant.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp index 844cbec4e..de042d1fb 100644 --- a/mlir/lib/Bindings/Python/DialectQuant.cpp +++ 
b/mlir/lib/Bindings/Python/DialectQuant.cpp @@ -306,4 +306,4 @@ PYBIND11_MODULE(_mlirDialectsQuant, m) { m.doc() = "MLIR Quantization dialect"; populateDialectQuantSubmodule(m); -} \ No newline at end of file +} From 446b803d49fd849ebcbaaf0ce1152a04a7c40fd2 Mon Sep 17 00:00:00 2001 From: gysit Date: Fri, 7 Jan 2022 08:59:16 +0000 Subject: [PATCH 198/915] [mlir][OpDSL] Rename `AttributeDef` to `IndexAttrDef`. Renaming `AttributeDef` to `IndexAttrDef` prepares OpDSL to support different kinds of attributes and more closely reflects the purpose of the attribute. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D115237 --- .../linalg/opdsl/lang/comprehension.py | 6 +- .../mlir/dialects/linalg/opdsl/lang/dsl.py | 5 +- .../linalg/opdsl/ops/core_named_ops.py | 203 +++++++++++------- 3 files changed, 128 insertions(+), 86 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 732cacfff..0edd5d13d 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -261,18 +261,17 @@ def to_scalar_expression(self) -> ScalarExpression: return ScalarArg(self.scalar_name).expr() -class AttributeDef: +class IndexAttrDef: """Index Attribute definition. Index attributes provide a way to define and set symbols that can be used in indexing expressions. Every attribute specifies a tuple of symbols that at compile-time are replaced by integer values. 
""" - yaml_tag = "!LinalgAttributeDef" def __init__(self, *sizes: SymbolDef): if any(not isinstance(size, SymbolDef) for size in sizes): - raise ValueError(f"AttributeDef requires sizes of type SymbolDef but got " + raise ValueError(f"IndexAttrDef requires sizes of type SymbolDef but got " f"{sizes}") self.operand_def = OperandDef(OperandKind.Attribute, I64, size_exprs=sizes) @@ -501,6 +500,7 @@ def __init__(self, cpp_name: str): ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface") ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface") + class OpMetadataDef(YAMLObject): """Metadata about the op (generally not behavior impacting).""" yaml_tag = "!LinalgOpMetadata" diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py index a65350ccd..459b1206a 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -20,6 +20,7 @@ StructuredOpOuts = Union[ir.Operation, ir.OpView, ir.OpResultList, Sequence[Union[ir.Value, ir.Operation, ir.OpView]]] + @contextmanager def bind_op_def(model: LinalgOpDef): if hasattr(_CONTEXT, "current_op_def"): @@ -128,12 +129,12 @@ def linalg_structured_op(dsl_func=None, sig = inspect.signature(dsl_func) for param_name, param in sig.parameters.items(): param_default = param.default - if isinstance(param_default, (TensorDef, ScalarDef, AttributeDef)): + if isinstance(param_default, (TensorDef, ScalarDef, IndexAttrDef)): tc_model.add_operand(param_name, param_default.operand_def) else: raise ValueError( f"@linalg_structured_op function parameters must be defaulted as " - f"TensorDef(...), ScalarDef(...), or AttributeDef(...): " + f"TensorDef(...), ScalarDef(...), or IndexAttrDef(...): " f"Found {param_name}: {param_default}") dsl_func_args.append(param_default) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py 
b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 85bed25fe..1e78e624f 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -20,6 +20,7 @@ def matmul( implements(ContractionOpInterface) C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + @linalg_structured_op def matmul_unsigned( A=TensorDef(T1, S.M, S.K), @@ -34,6 +35,7 @@ def matmul_unsigned( implements(ContractionOpInterface) C[D.m, D.n] += cast_unsigned(U, A[D.m, D.k]) * cast_unsigned(U, B[D.k, D.n]) + @linalg_structured_op def quantized_matmul( A=TensorDef(T1, S.M, S.K), @@ -49,13 +51,15 @@ def quantized_matmul( matmul. """ domain(D.m, D.n, D.k) - C[D.m, D.n] += (cast(U, A[D.m, D.k]) - cast(U, AZp)) * (cast(U, B[D.k, D.n]) - cast(U, BZp)) + C[D.m, D.n] += (cast(U, A[D.m, D.k]) - cast(U, AZp)) * ( + cast(U, B[D.k, D.n]) - cast(U, BZp)) + @linalg_structured_op -def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), - rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0), - accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, - output=True)): +def mmt4d( + lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), + rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0), + accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True)): """Performs a matrix-matrix-transpose multiplication of two 4D inputs. 
Differences from linalg.matmul: @@ -68,7 +72,10 @@ def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), """ domain(D.m, D.n, D.k, D.m0, D.n0, D.k0) implements(ContractionOpInterface) - accum[D.m, D.n, D.m0, D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) + accum[D.m, D.n, D.m0, + D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast( + TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) + @linalg_structured_op def batch_matmul( @@ -84,6 +91,7 @@ def batch_matmul( implements(ContractionOpInterface) C[D.b, D.m, D.n] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k, D.n]) + @linalg_structured_op def quantized_batch_matmul( A=TensorDef(T1, Batch, S.M, S.K), @@ -99,7 +107,8 @@ def quantized_batch_matmul( matmul. """ domain(D.b, D.m, D.n, D.k) - C[D.b, D.m, D.n] += (cast(U, A[D.b, D.m, D.k]) - cast(U, AZp)) * (cast(U, B[D.b, D.k, D.n]) - cast(U, BZp)) + C[D.b, D.m, D.n] += (cast(U, A[D.b, D.m, D.k]) - cast(U, AZp)) * ( + cast(U, B[D.b, D.k, D.n]) - cast(U, BZp)) @linalg_structured_op @@ -136,7 +145,7 @@ def vecmat( def batch_matvec( A=TensorDef(T1, Batch, S.M, S.K), B=TensorDef(T2, Batch, S.K), - C=TensorDef(U, Batch, S.M, output=True)): + C=TensorDef(U, Batch, S.M, output=True)): """Performs a batched matrix-vector multiplication. 
Numeric casting is performed on the operands to the inner multiply, promoting @@ -158,6 +167,7 @@ def dot( implements(ContractionOpInterface) C[None] += cast(U, A[D.m]) * cast(U, B[D.m]) + @linalg_structured_op def conv_1d( I=TensorDef(T1, S.OW + S.KW), @@ -170,8 +180,8 @@ def conv_1d( """ implements(ConvolutionOpInterface) domain(D.ow, D.kw) - O[D.ow] += cast( - U, I[D.ow + D.kw]) * cast(U, K[D.kw]) + O[D.ow] += cast(U, I[D.ow + D.kw]) * cast(U, K[D.kw]) + @linalg_structured_op def conv_2d( @@ -185,8 +195,8 @@ def conv_2d( """ implements(ConvolutionOpInterface) domain(D.oh, D.ow, D.kh, D.kw) - O[D.oh, D.ow] += cast( - U, I[D.oh + D.kh, D.ow + D.kw]) * cast(U, K[D.kh, D.kw]) + O[D.oh, D.ow] += cast(U, I[D.oh + D.kh, D.ow + D.kw]) * cast(U, K[D.kh, D.kw]) + @linalg_structured_op def conv_3d( @@ -200,16 +210,18 @@ def conv_3d( """ implements(ConvolutionOpInterface) domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw) - O[D.od, D.oh, D.ow] += cast( - U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]) * cast(U, K[D.kd, D.kh, D.kw]) + O[D.od, D.oh, + D.ow] += cast(U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]) * cast( + U, K[D.kd, D.kh, D.kw]) + @linalg_structured_op def conv_1d_nwc_wcf( I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KW, S.C, S.F), O=TensorDef(U, S.N, S.OW, S.F, output=True), - strides=AttributeDef(S.SW), - dilations=AttributeDef(S.DW)): + strides=IndexAttrDef(S.SW), + dilations=IndexAttrDef(S.DW)): """Performs 1-D convolution. 
Numeric casting is performed on the operands to the inner multiply, promoting @@ -217,17 +229,18 @@ def conv_1d_nwc_wcf( """ implements(ConvolutionOpInterface) domain(D.n, D.ow, D.f, D.kw, D.c) - O[D.n, D.ow, D.f] += cast( - U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c - ]) * cast(U, K[D.kw, D.c, D.f]) + O[D.n, D.ow, D.f] += cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) * cast( + U, K[D.kw, D.c, D.f]) + @linalg_structured_op def conv_2d_nhwc_hwcf( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, + S.C), K=TensorDef(T2, S.KH, S.KW, S.C, S.F), O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), - strides=AttributeDef(S.SH, S.SW), - dilations=AttributeDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW), + dilations=IndexAttrDef(S.DH, S.DW)): """Performs 2-D convolution. Layout: @@ -240,18 +253,20 @@ def conv_2d_nhwc_hwcf( implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.f] += cast( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c - ]) * cast(U, K[D.kh, D.kw, D.c, D.f]) + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, + D.c]) * cast(U, K[D.kh, D.kw, D.c, D.f]) + @linalg_structured_op def conv_2d_nhwc_hwcf_q( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, + S.C), K=TensorDef(T2, S.KH, S.KW, S.C, S.F), IZp=ScalarDef(I32), KZp=ScalarDef(I32), O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), - strides=AttributeDef(S.SH, S.SW), - dilations=AttributeDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW), + dilations=IndexAttrDef(S.DH, S.DW)): """Performs 2-D convolution with zero point offsets. 
Layout: @@ -264,17 +279,21 @@ def conv_2d_nhwc_hwcf_q( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.f] += (cast( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c - ]) - cast(U, IZp)) * (cast(U, K[D.kh, D.kw, D.c, D.f]) - cast(U, KZp)) + O[D.n, D.oh, D.ow, + D.f] += (cast( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) - + cast(U, IZp)) * ( + cast(U, K[D.kh, D.kw, D.c, D.f]) - cast(U, KZp)) + @linalg_structured_op def conv_2d_nchw_fchw( - I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), + I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW), K=TensorDef(T2, S.F, S.C, S.KH, S.KW), O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True), - strides=AttributeDef(S.SH, S.SW), - dilations=AttributeDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW), + dilations=IndexAttrDef(S.DH, S.DW)): """Performs 2-D convolution. Layout: @@ -287,17 +306,18 @@ def conv_2d_nchw_fchw( implements(ConvolutionOpInterface) domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) O[D.n, D.f, D.oh, D.ow] += cast( - U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW - ]) * cast(U, K[D.f, D.c, D.kh, D.kw]) + U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW]) * cast(U, K[D.f, D.c, D.kh, D.kw]) + @linalg_structured_op def conv_3d_ndhwc_dhwcf( - I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, - S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F), O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True), - strides=AttributeDef(S.SD, S.SH, S.SW), - dilations=AttributeDef(S.DD, S.DH, S.DW)): + strides=IndexAttrDef(S.SD, S.SH, S.SW), + dilations=IndexAttrDef(S.DD, S.DH, S.DW)): """Performs 3-D convolution. 
Numeric casting is performed on the operands to the inner multiply, promoting @@ -306,16 +326,18 @@ def conv_3d_ndhwc_dhwcf( implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) O[D.n, D.od, D.oh, D.ow, D.f] += cast( - U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c - ]) * cast(U, K[D.kd, D.kh, D.kw, D.c, D.f]) + U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, D.c]) * cast( + U, K[D.kd, D.kh, D.kw, D.c, D.f]) + @linalg_structured_op def depthwise_conv_1d_nwc_wc( I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC), K=TensorDef(T2, S.KW, S.IC), O=TensorDef(U, S.N, S.OW, S.IC, output=True), - strides=AttributeDef(S.SW), - dilations=AttributeDef(S.DW)): + strides=IndexAttrDef(S.SW), + dilations=IndexAttrDef(S.DW)): """Performs depth-wise 1-D convolution. Numeric casting is performed on the operands to the inner multiply, promoting @@ -328,13 +350,15 @@ def depthwise_conv_1d_nwc_wc( cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \ cast(U, K[D.kw, D.ic]) + @linalg_structured_op def depthwise_conv_2d_nhwc_hwc( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, + S.IC), K=TensorDef(T2, S.KH, S.KW, S.IC), O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True), - strides=AttributeDef(S.SH, S.SW), - dilations=AttributeDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW), + dilations=IndexAttrDef(S.DH, S.DW)): """Performs depth-wise 2-D convolution. 
Numeric casting is performed on the operands to the inner multiply, promoting @@ -347,15 +371,17 @@ def depthwise_conv_2d_nhwc_hwc( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) * cast(U, K[D.kh, D.kw, D.ic]) + @linalg_structured_op def depthwise_conv_2d_nhwc_hwc_q( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, + S.IC), K=TensorDef(T2, S.KH, S.KW, S.IC), IZp=ScalarDef(I32), KZp=ScalarDef(I32), O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True), - strides=AttributeDef(S.SH, S.SW), - dilations=AttributeDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW), + dilations=IndexAttrDef(S.DH, S.DW)): """Performs depth-wise 2-D convolution. Numeric casting is performed on the operands to the inner multiply, promoting @@ -368,13 +394,15 @@ def depthwise_conv_2d_nhwc_hwc_q( D.ic]) - cast(U, IZp)) * (cast(U, K[D.kh, D.kw, D.ic]) - cast(U, KZp))) + @linalg_structured_op def depthwise_conv_2d_nhwc_hwcm( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, + S.IC), K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), - strides=AttributeDef(S.SH, S.SW), - dilations=AttributeDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW), + dilations=IndexAttrDef(S.DH, S.DW)): """Performs depth-wise 2-D convolution. 
Numeric casting is performed on the operands to the inner multiply, promoting @@ -386,15 +414,17 @@ def depthwise_conv_2d_nhwc_hwcm( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) * cast(U, K[D.kh, D.kw, D.ic, D.cm]) + @linalg_structured_op def depthwise_conv_2d_nhwc_hwcm_q( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, + S.IC), K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), IZp=ScalarDef(I32), KZp=ScalarDef(I32), O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), - strides=AttributeDef(S.SH, S.SW), - dilations=AttributeDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW), + dilations=IndexAttrDef(S.DH, S.DW)): """Performs depth-wise 2-D convolution. Numeric casting is performed on the operands to the inner multiply, promoting @@ -410,11 +440,12 @@ def depthwise_conv_2d_nhwc_hwcm_q( @linalg_structured_op def pooling_nhwc_sum( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, + S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=AttributeDef(S.SH, S.SW), - dilations=AttributeDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW), + dilations=IndexAttrDef(S.DH, S.DW)): """Performs sum pooling. 
Numeric casting is performed on the input operand, promoting it to the same @@ -428,11 +459,12 @@ def pooling_nhwc_sum( @linalg_structured_op def pooling_nhwc_max( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, + S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=AttributeDef(S.SH, S.SW), - dilations=AttributeDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW), + dilations=IndexAttrDef(S.DH, S.DW)): """Performs max pooling. Numeric casting is performed on the input operand, promoting it to the same @@ -444,13 +476,15 @@ def pooling_nhwc_max( cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) + @linalg_structured_op def pooling_nhwc_max_unsigned( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, + S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=AttributeDef(S.SH, S.SW), - dilations=AttributeDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW), + dilations=IndexAttrDef(S.DH, S.DW)): """Performs unsigned max pooling. 
Numeric casting is performed on the input operand, promoting it to the same @@ -462,13 +496,15 @@ def pooling_nhwc_max_unsigned( cast_unsigned( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) + @linalg_structured_op def pooling_nchw_max( - I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), + I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True), - strides=AttributeDef(S.SH, S.SW), - dilations=AttributeDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW), + dilations=IndexAttrDef(S.DH, S.DW)): """Performs max pooling. Numeric casting is performed on the input operand, promoting it to the same @@ -477,16 +513,18 @@ def pooling_nchw_max( implements(ConvolutionOpInterface) domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) O[D.n, D.c, D.oh, D.ow] = ReduceFn.max(D.kh, D.kw)( - cast(U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - ])) + cast(U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW,])) + @linalg_structured_op def pooling_nhwc_min( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, + S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=AttributeDef(S.SH, S.SW), - dilations=AttributeDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW), + dilations=IndexAttrDef(S.DH, S.DW)): """Performs min pooling. 
Numeric casting is performed on the input operand, promoting it to the same @@ -498,13 +536,15 @@ def pooling_nhwc_min( cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) + @linalg_structured_op def pooling_nhwc_min_unsigned( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, + S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=AttributeDef(S.SH, S.SW), - dilations=AttributeDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW), + dilations=IndexAttrDef(S.DH, S.DW)): """Performs unsigned min pooling. Numeric casting is performed on the input operand, promoting it to the same @@ -516,14 +556,15 @@ def pooling_nhwc_min_unsigned( cast_unsigned( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) + @linalg_structured_op def pooling_ndhwc_sum( I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), - strides=AttributeDef(S.SD, S.SH, S.SW), - dilations=AttributeDef(S.DD, S.DH, S.DW)): + strides=IndexAttrDef(S.SD, S.SH, S.SW), + dilations=IndexAttrDef(S.DD, S.DH, S.DW)): """Performs 3D sum pooling. Numeric casting is performed on the input operand, promoting it to the same @@ -542,8 +583,8 @@ def pooling_ndhwc_max( S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), - strides=AttributeDef(S.SD, S.SH, S.SW), - dilations=AttributeDef(S.DD, S.DH, S.DW)): + strides=IndexAttrDef(S.SD, S.SH, S.SW), + dilations=IndexAttrDef(S.DD, S.DH, S.DW)): """Performs 3D max pooling. 
Numeric casting is performed on the input operand, promoting it to the same @@ -563,8 +604,8 @@ def pooling_ndhwc_min( S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), - strides=AttributeDef(S.SD, S.SH, S.SW), - dilations=AttributeDef(S.DD, S.DH, S.DW)): + strides=IndexAttrDef(S.SD, S.SH, S.SW), + dilations=IndexAttrDef(S.DD, S.DH, S.DW)): """Performs 3D min pooling. Numeric casting is performed on the input operand, promoting it to the same From fb1adbe4e75d5f9a043076795734f6e4182d8d78 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Fri, 7 Jan 2022 13:22:32 +0100 Subject: [PATCH 199/915] [mlir][python] Use a named object Currently, the object would be immediately destroyed after creation. Found by ClangTidy bugprone-unused-raii. --- mlir/lib/Bindings/Python/PybindUtils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h index 2fdcf695b..457d8090c 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -140,7 +140,7 @@ class PyFileAccumulator { MlirStringCallback getCallback() { return [](MlirStringRef part, void *userData) { - pybind11::gil_scoped_acquire(); + pybind11::gil_scoped_acquire acquire; PyFileAccumulator *accum = static_cast(userData); if (accum->binary) { // Note: Still has to copy and not avoidable with this API. From c1a8058e2955a751a94d759c1580a2c53fc5a925 Mon Sep 17 00:00:00 2001 From: gysit Date: Fri, 7 Jan 2022 12:23:11 +0000 Subject: [PATCH 200/915] [mlir][OpDSL] Add `TypeFn` class. This revision introduces a the `TypeFn` class that similar to the `PrimFn` class contains an extensible set of type conversion functions. Having the same mechanism for both type conversion functions and arithmetic functions improves code consistency. 
Additionally, having an explicit function class and function name is a prerequisite to specify a conversion or arithmetic function via attribute. In a follow up commits, we will introduce function attributes to make OpDSL operations more generic. In particular, the goal is to handle signed and unsigned computation in one operations. Today, there is a linalg.matmul and a linalg.matmul_unsigned. The commit implements the following changes: - Introduce the class of type conversion functions `TypeFn` - Replace the hardwired cast and cast_unsigned ops by the `TypeFn` counterparts - Adapt the python and C++ code generation paths to support the new cast operations Example: ``` cast(U, A[D.m, D.k]) ``` changes to ``` TypeFn.cast(U, A[D.m, D.k]) ``` Depends On D115237 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D115239 --- .../linalg/opdsl/lang/comprehension.py | 83 +++++++---- .../dialects/linalg/opdsl/lang/emitter.py | 55 ++++--- .../dialects/linalg/opdsl/lang/scalar_expr.py | 68 ++++----- .../linalg/opdsl/ops/core_named_ops.py | 135 ++++++++++-------- 4 files changed, 193 insertions(+), 148 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 0edd5d13d..be7fc02d0 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -314,6 +314,39 @@ def __repr__(self): return f"{defs_repr} = {values_repr}" +class TypeFnType: + """Type conversion function. + + A type conversion function takes a target type and a tensor expression and + returns the casted tensor expression. 
+ """ + + def __init__(self, fn_name: str): + self.fn_name = fn_name + + def __call__(self, type_var: TypeVar, + arg: TensorExpression) -> "TensorTypeFn": + return TensorTypeFn(self, type_var, arg) + + def __repr__(self): + return f"{self.fn_name}" + + +class TypeFn: + """Type conversion function namespace. + + As the integer types are signless, signedness is implement by different cast + functions that treat integers as signed (`cast`) or unsigned + (`cast_unsigned`) values. + + Examples: + - cast(I32 -> I64) -> `arith.ExtSIOp` + - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp` + """ + cast = TypeFnType("cast") + cast_unsigned = TypeFnType("cast_unsigned") + + class PrimFnType: """Primitive operations.""" @@ -391,6 +424,26 @@ def __repr__(self): return f"{repr(self.prim)}({', '.join(repr(a) for a in self.args)})" +class TensorTypeFn(TensorExpression): + """Application of a type conversion function.""" + + def __init__(self, type_fn: TypeFn, type_var: TypeVar, arg: TensorExpression): + self.type_fn = type_fn + self.type_var = type_var + self.arg = arg + + def to_scalar_expression(self) -> ScalarExpression: + return ScalarTypeFn(self.type_fn.fn_name, self.type_var, + self.arg.to_scalar_expression()).expr() + + def visit_tensor_exprs(self, callback): + super().visit_tensor_exprs(callback) + self.arg.visit_tensor_exprs(callback) + + def __repr__(self): + return f"{repr(self.type_fn)}({type_var}, {self.arg})" + + class const(TensorExpression): """Returns the given constant floating point or integer value.""" @@ -433,36 +486,6 @@ def __repr__(self): return f"index({repr(self.dim)})" -class cast(TensorExpression): - """Casts the element type to a type (typically symbolic TypeVar).""" - - def __init__(self, to_type: TypeVar, operand: TensorExpression): - self.to_type = to_type - self.operand = operand - - def to_scalar_expression(self) -> ScalarExpression: - return ScalarSymbolicCast(self.to_type, self.operand.to_scalar_expression(), - False).expr() - - def 
visit_tensor_exprs(self, callback): - super().visit_tensor_exprs(callback) - self.operand.visit_tensor_exprs(callback) - - def __repr__(self): - return f"cast({self.to_type}, {repr(self.operand)})" - - -class cast_unsigned(cast): - """Casts the element type to an unsigned type (typically symbolic TypeVar).""" - - def to_scalar_expression(self) -> ScalarExpression: - return ScalarSymbolicCast(self.to_type, self.operand.to_scalar_expression(), - True).expr() - - def __repr__(self): - return f"cast_unsigned({self.to_type}, {repr(self.operand)})" - - class ReduceApply(TensorExpression): """Application of a reduction. diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index aa44194b5..df91b9670 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -2,7 +2,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Dict, List, Sequence, Tuple, Union +from typing import Callable, Dict, List, Sequence, Tuple, Union from .....ir import * @@ -24,6 +24,7 @@ ValueList = Union[Sequence[Value], OpResultList] + def isa(cls: Type, ty: Type): try: cls(ty) @@ -221,24 +222,38 @@ def expression(self, expr: ScalarExpression) -> Value: IntegerType.get_signless(64), expr.scalar_index.dim) return linalg.IndexOp(dim_attr).result elif expr.scalar_apply: - try: - fn = getattr(self, f"_eval_{expr.scalar_apply.fn_name}") - except AttributeError: - raise ValueError( - f"Function '{expr.scalar_apply.fn_name}' is not a known " - "scalar body function") + fn = self._get_function(f"_eval_{expr.scalar_apply.fn_name}") operand_values = [ self.expression(operand) for operand in expr.scalar_apply.operands ] return fn(*operand_values) - elif expr.symbolic_cast: - operand_value = self.expression(expr.symbolic_cast.operand) - return self.cast(expr.symbolic_cast.to_type.name, 
operand_value, - expr.symbolic_cast.is_unsigned_cast) + elif expr.type_fn: + fn = self._get_function(f"_typefn_{expr.type_fn.fn_name}") + operand = self.expression(expr.type_fn.operand) + return fn(expr.type_fn.type_var.name, operand) raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") - def cast(self, type_var_name: str, operand: Value, - is_unsigned_cast: bool) -> Value: + def yield_outputs(self, *output_names: str): + output_values = [] + for n in output_names: + try: + output_values.append(self.yield_mapping[n]) + except KeyError: + raise ValueError(f"Body assignments do not assign all outputs: " + f"missing '{n}'") + linalg.YieldOp(output_values) + + def _get_function(self, fn_name: str) -> Callable: + try: + fn = getattr(self, f"{fn_name}") + except AttributeError: + raise ValueError(f"Function '{fn_name}' is not a known function") + return fn + + def _cast(self, + type_var_name: str, + operand: Value, + is_unsigned_cast: bool = False) -> Value: try: to_type = self.type_mapping[type_var_name] except KeyError: @@ -289,15 +304,11 @@ def _cast_to_floating_point(self, to_type: Type, operand: Value, raise ValueError(f"Unable to cast body expression from {operand_type} to " f"{to_type}") - def yield_outputs(self, *output_names: str): - output_values = [] - for n in output_names: - try: - output_values.append(self.yield_mapping[n]) - except KeyError: - raise ValueError(f"Body assignments do not assign all outputs: " - f"missing '{n}'") - linalg.YieldOp(output_values) + def _typefn_cast(self, type_var_name: str, operand: Value) -> Value: + return self._cast(type_var_name, operand, False) + + def _typefn_cast_unsigned(self, type_var_name: str, operand: Value) -> Value: + return self._cast(type_var_name, operand, True) def _eval_add(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py index 
6de3333fb..c6b1b3885 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py @@ -21,11 +21,11 @@ __all__ = [ "ScalarAssign", "ScalarApplyFn", + "ScalarTypeFn", "ScalarArg", "ScalarConst", "ScalarIndex", "ScalarExpression", - "ScalarSymbolicCast", ] @@ -43,6 +43,22 @@ def __repr__(self): return f"ScalarApplyFn<{self.fn_name}>({', '.join(self.operands)})" +class ScalarTypeFn: + """A type of ScalarExpression that applies a type conversion function.""" + + def __init__(self, fn_name: str, type_var: TypeVar, + operand: "ScalarExpression"): + self.fn_name = fn_name + self.type_var = type_var + self.operand = operand + + def expr(self) -> "ScalarExpression": + return ScalarExpression(type_fn=self) + + def __repr__(self): + return f"ScalarTypeFn<{self.fn_name}>({self.type_var}, {self.operand})" + + class ScalarArg: """A type of ScalarExpression that references a named argument.""" @@ -82,27 +98,12 @@ def __repr__(self): return f"(ScalarIndex({self.dim})" -class ScalarSymbolicCast: - """A type of ScalarExpression that symbolically casts an operand to a TypeVar.""" - - def __init__(self, to_type: TypeVar, operand: "ScalarExpression", - is_unsigned_cast: bool): - self.to_type = to_type - self.operand = operand - self.is_unsigned_cast = is_unsigned_cast - - def expr(self) -> "ScalarExpression": - return ScalarExpression(symbolic_cast=self) - - def __repr__(self): - return f"ScalarSymbolicCast({self.to_type}, {self.operand}, {self.is_unsigned_cast})" - - class ScalarExpression(YAMLObject): """An expression on scalar values. 
Can be one of: - ScalarApplyFn + - ScalarTypeFn - ScalarArg - ScalarConst - ScalarIndex @@ -112,19 +113,19 @@ class ScalarExpression(YAMLObject): def __init__(self, scalar_apply: Optional[ScalarApplyFn] = None, + type_fn: Optional[ScalarTypeFn] = None, scalar_arg: Optional[ScalarArg] = None, scalar_const: Optional[ScalarConst] = None, - scalar_index: Optional[ScalarIndex] = None, - symbolic_cast: Optional[ScalarSymbolicCast] = None): - if (bool(scalar_apply) + bool(scalar_arg) + bool(scalar_const) + - bool(scalar_index) + bool(symbolic_cast)) != 1: - raise ValueError("One of 'scalar_apply', 'scalar_arg', 'scalar_const', " - "'scalar_index', 'symbolic_cast' must be specified") + scalar_index: Optional[ScalarIndex] = None): + if (bool(scalar_apply) + bool(type_fn) + bool(scalar_arg) + + bool(scalar_const) + bool(scalar_index)) != 1: + raise ValueError("One of 'scalar_apply', 'type_fn', 'scalar_arg', " + "'scalar_const', 'scalar_index', must be specified") self.scalar_apply = scalar_apply + self.type_fn = type_fn self.scalar_arg = scalar_arg self.scalar_const = scalar_const self.scalar_index = scalar_index - self.symbolic_cast = symbolic_cast def to_yaml_custom_dict(self): if self.scalar_apply: @@ -133,21 +134,22 @@ def to_yaml_custom_dict(self): fn_name=self.scalar_apply.fn_name, operands=list(self.scalar_apply.operands), )) + if self.type_fn: + # Note that even though operands must be arity 1, we write it the + # same way as for apply because it allows handling code to be more + # generic vs having a special form. 
+ return dict( + type_fn=dict( + fn_name=self.type_fn.fn_name, + type_var=self.type_fn.type_var.name, + operands=[self.type_fn.operand], + )) elif self.scalar_arg: return dict(scalar_arg=self.scalar_arg.arg) elif self.scalar_const: return dict(scalar_const=self.scalar_const.value) elif self.scalar_index: return dict(scalar_index=self.scalar_index.dim) - elif self.symbolic_cast: - # Note that even though operands must be arity 1, we write it the - # same way as for apply because it allows handling code to be more - # generic vs having a special form. - return dict( - symbolic_cast=dict( - type_var=self.symbolic_cast.to_type.name, - operands=[self.symbolic_cast.operand], - is_unsigned_cast=self.symbolic_cast.is_unsigned_cast)) else: raise ValueError(f"Unexpected ScalarExpression type: {self}") diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 1e78e624f..173af1a3f 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -18,7 +18,7 @@ def matmul( """ domain(D.m, D.n, D.k) implements(ContractionOpInterface) - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]) @linalg_structured_op @@ -33,7 +33,8 @@ def matmul_unsigned( """ domain(D.m, D.n, D.k) implements(ContractionOpInterface) - C[D.m, D.n] += cast_unsigned(U, A[D.m, D.k]) * cast_unsigned(U, B[D.k, D.n]) + C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned( + U, B[D.k, D.n]) @linalg_structured_op @@ -51,8 +52,8 @@ def quantized_matmul( matmul. 
""" domain(D.m, D.n, D.k) - C[D.m, D.n] += (cast(U, A[D.m, D.k]) - cast(U, AZp)) * ( - cast(U, B[D.k, D.n]) - cast(U, BZp)) + C[D.m, D.n] += (TypeFn.cast(U, A[D.m, D.k]) - TypeFn.cast(U, AZp)) * ( + TypeFn.cast(U, B[D.k, D.n]) - TypeFn.cast(U, BZp)) @linalg_structured_op @@ -72,9 +73,9 @@ def mmt4d( """ domain(D.m, D.n, D.k, D.m0, D.n0, D.k0) implements(ContractionOpInterface) - accum[D.m, D.n, D.m0, - D.n0] += cast(TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * cast( - TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) + accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast( + TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * TypeFn.cast( + TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) @linalg_structured_op @@ -89,7 +90,8 @@ def batch_matmul( """ domain(D.b, D.m, D.n, D.k) implements(ContractionOpInterface) - C[D.b, D.m, D.n] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k, D.n]) + C[D.b, D.m, + D.n] += TypeFn.cast(U, A[D.b, D.m, D.k]) * TypeFn.cast(U, B[D.b, D.k, D.n]) @linalg_structured_op @@ -107,8 +109,9 @@ def quantized_batch_matmul( matmul. 
""" domain(D.b, D.m, D.n, D.k) - C[D.b, D.m, D.n] += (cast(U, A[D.b, D.m, D.k]) - cast(U, AZp)) * ( - cast(U, B[D.b, D.k, D.n]) - cast(U, BZp)) + C[D.b, D.m, + D.n] += (TypeFn.cast(U, A[D.b, D.m, D.k]) - TypeFn.cast(U, AZp)) * ( + TypeFn.cast(U, B[D.b, D.k, D.n]) - TypeFn.cast(U, BZp)) @linalg_structured_op @@ -123,7 +126,7 @@ def matvec( """ domain(D.m, D.n) implements(ContractionOpInterface) - x[D.m] += cast(U, A[D.m, D.n]) * cast(U, y[D.n]) + x[D.m] += TypeFn.cast(U, A[D.m, D.n]) * TypeFn.cast(U, y[D.n]) @linalg_structured_op @@ -138,7 +141,7 @@ def vecmat( """ domain(D.n, D.m) implements(ContractionOpInterface) - x[D.n] += cast(U, y[D.m]) * cast(U, A[D.m, D.n]) + x[D.n] += TypeFn.cast(U, y[D.m]) * TypeFn.cast(U, A[D.m, D.n]) @linalg_structured_op @@ -153,7 +156,7 @@ def batch_matvec( """ domain(D.b, D.m, D.k) implements(ContractionOpInterface) - C[D.b, D.m] += cast(U, A[D.b, D.m, D.k]) * cast(U, B[D.b, D.k]) + C[D.b, D.m] += TypeFn.cast(U, A[D.b, D.m, D.k]) * TypeFn.cast(U, B[D.b, D.k]) @linalg_structured_op @@ -165,7 +168,7 @@ def dot( them to the same data type as the accumulator/output. 
""" implements(ContractionOpInterface) - C[None] += cast(U, A[D.m]) * cast(U, B[D.m]) + C[None] += TypeFn.cast(U, A[D.m]) * TypeFn.cast(U, B[D.m]) @linalg_structured_op @@ -180,7 +183,7 @@ def conv_1d( """ implements(ConvolutionOpInterface) domain(D.ow, D.kw) - O[D.ow] += cast(U, I[D.ow + D.kw]) * cast(U, K[D.kw]) + O[D.ow] += TypeFn.cast(U, I[D.ow + D.kw]) * TypeFn.cast(U, K[D.kw]) @linalg_structured_op @@ -195,7 +198,8 @@ def conv_2d( """ implements(ConvolutionOpInterface) domain(D.oh, D.ow, D.kh, D.kw) - O[D.oh, D.ow] += cast(U, I[D.oh + D.kh, D.ow + D.kw]) * cast(U, K[D.kh, D.kw]) + O[D.oh, D.ow] += TypeFn.cast(U, I[D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast( + U, K[D.kh, D.kw]) @linalg_structured_op @@ -211,8 +215,8 @@ def conv_3d( implements(ConvolutionOpInterface) domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw) O[D.od, D.oh, - D.ow] += cast(U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]) * cast( - U, K[D.kd, D.kh, D.kw]) + D.ow] += TypeFn.cast(U, I[D.od + D.kd, D.oh + D.kh, D.ow + + D.kw]) * TypeFn.cast(U, K[D.kd, D.kh, D.kw]) @linalg_structured_op @@ -229,8 +233,9 @@ def conv_1d_nwc_wcf( """ implements(ConvolutionOpInterface) domain(D.n, D.ow, D.f, D.kw, D.c) - O[D.n, D.ow, D.f] += cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) * cast( - U, K[D.kw, D.c, D.f]) + O[D.n, D.ow, + D.f] += TypeFn.cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW, + D.c]) * TypeFn.cast(U, K[D.kw, D.c, D.f]) @linalg_structured_op @@ -252,9 +257,9 @@ def conv_2d_nhwc_hwcf( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.f] += cast( + O[D.n, D.oh, D.ow, D.f] += TypeFn.cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.c]) * cast(U, K[D.kh, D.kw, D.c, D.f]) + D.c]) * TypeFn.cast(U, K[D.kh, D.kw, D.c, D.f]) @linalg_structured_op @@ -280,10 +285,10 @@ def conv_2d_nhwc_hwcf_q( implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, - D.f] += (cast( + D.f] += (TypeFn.cast( U, 
I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) - - cast(U, IZp)) * ( - cast(U, K[D.kh, D.kw, D.c, D.f]) - cast(U, KZp)) + TypeFn.cast(U, IZp)) * ( + TypeFn.cast(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast(U, KZp)) @linalg_structured_op @@ -305,9 +310,9 @@ def conv_2d_nchw_fchw( """ implements(ConvolutionOpInterface) domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.f, D.oh, D.ow] += cast( + O[D.n, D.f, D.oh, D.ow] += TypeFn.cast( U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW]) * cast(U, K[D.f, D.c, D.kh, D.kw]) + D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast(U, K[D.f, D.c, D.kh, D.kw]) @linalg_structured_op @@ -325,9 +330,9 @@ def conv_3d_ndhwc_dhwcf( """ implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) - O[D.n, D.od, D.oh, D.ow, D.f] += cast( + O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, D.c]) * cast( + D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast( U, K[D.kd, D.kh, D.kw, D.c, D.f]) @@ -347,8 +352,8 @@ def depthwise_conv_1d_nwc_wc( implements(ConvolutionOpInterface) domain(D.n, D.ow, D.ic, D.kw) O[D.n, D.ow, D.ic] += \ - cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \ - cast(U, K[D.kw, D.ic]) + TypeFn.cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \ + TypeFn.cast(U, K[D.kw, D.ic]) @linalg_structured_op @@ -367,9 +372,9 @@ def depthwise_conv_2d_nhwc_hwc( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.ic] += cast( + O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.ic]) * cast(U, K[D.kh, D.kw, D.ic]) + D.ic]) * TypeFn.cast(U, K[D.kh, D.kw, D.ic]) @linalg_structured_op @@ -389,10 +394,11 @@ def depthwise_conv_2d_nhwc_hwc_q( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.ic] += ( - (cast(U, I[D.n, D.oh * 
S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.ic]) - cast(U, IZp)) * - (cast(U, K[D.kh, D.kw, D.ic]) - cast(U, KZp))) + O[D.n, D.oh, D.ow, + D.ic] += ((TypeFn.cast( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) - + TypeFn.cast(U, IZp)) * + (TypeFn.cast(U, K[D.kh, D.kw, D.ic]) - TypeFn.cast(U, KZp))) @linalg_structured_op @@ -410,9 +416,9 @@ def depthwise_conv_2d_nhwc_hwcm( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.ic, D.cm] += cast( + O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.ic]) * cast(U, K[D.kh, D.kw, D.ic, D.cm]) + D.ic]) * TypeFn.cast(U, K[D.kh, D.kw, D.ic, D.cm]) @linalg_structured_op @@ -432,10 +438,11 @@ def depthwise_conv_2d_nhwc_hwcm_q( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.ic, D.cm] += ( - (cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.ic]) - cast(U, IZp)) * - (cast(U, K[D.kh, D.kw, D.ic, D.cm]) - cast(U, KZp))) + O[D.n, D.oh, D.ow, D.ic, + D.cm] += ((TypeFn.cast( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) - + TypeFn.cast(U, IZp)) * + (TypeFn.cast(U, K[D.kh, D.kw, D.ic, D.cm]) - TypeFn.cast(U, KZp))) @linalg_structured_op @@ -453,7 +460,7 @@ def pooling_nhwc_sum( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.c] += cast( + O[D.n, D.oh, D.ow, D.c] += TypeFn.cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) @@ -473,8 +480,8 @@ def pooling_nhwc_max( implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)( - cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.c])) + TypeFn.cast( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @linalg_structured_op @@ -493,7 +500,7 @@ 
def pooling_nhwc_max_unsigned( implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)( - cast_unsigned( + TypeFn.cast_unsigned( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -513,8 +520,9 @@ def pooling_nchw_max( implements(ConvolutionOpInterface) domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) O[D.n, D.c, D.oh, D.ow] = ReduceFn.max(D.kh, D.kw)( - cast(U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW,])) + TypeFn.cast( + U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW,])) @linalg_structured_op @@ -533,8 +541,8 @@ def pooling_nhwc_min( implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)( - cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.c])) + TypeFn.cast( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @linalg_structured_op @@ -553,7 +561,7 @@ def pooling_nhwc_min_unsigned( implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)( - cast_unsigned( + TypeFn.cast_unsigned( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -572,7 +580,7 @@ def pooling_ndhwc_sum( """ implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) - O[D.n, D.od, D.oh, D.ow, D.c] += cast( + O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) @@ -593,7 +601,7 @@ def pooling_ndhwc_max( implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max(D.kd, D.kh, D.kw)( - cast( + TypeFn.cast( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -614,7 +622,7 @@ def pooling_ndhwc_min( 
implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min(D.kd, D.kh, D.kw)( - cast( + TypeFn.cast( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -636,14 +644,15 @@ def fill_rng_2d( the range of the generated random numbers. """ domain(D.m, D.n) - multiplier = cast(I32, const(1103515245)) - increment = cast(I32, const(12345)) - rand1 = (cast(I32, index(D.m)) + seed) * multiplier + increment - rand2 = (cast(I32, index(D.n)) + rand1) * multiplier + increment - inv_range = cast(F64, const(2.3283064e-10)) - offset = cast(F64, const(2147483647)) + multiplier = TypeFn.cast(I32, const(1103515245)) + increment = TypeFn.cast(I32, const(12345)) + rand1 = (TypeFn.cast(I32, index(D.m)) + seed) * multiplier + increment + rand2 = (TypeFn.cast(I32, index(D.n)) + rand1) * multiplier + increment + inv_range = TypeFn.cast(F64, const(2.3283064e-10)) + offset = TypeFn.cast(F64, const(2147483647)) scaling = (max - min) * inv_range - O[D.m, D.n] = cast(T, (offset + cast(F64, rand2)) * scaling + min) + O[D.m, D.n] = TypeFn.cast(T, + (offset + TypeFn.cast(F64, rand2)) * scaling + min) @linalg_structured_op @@ -656,4 +665,4 @@ def soft_plus_2d( """ domain(D.m, D.n) O[D.m, D.n] = \ - PrimFn.log(cast(U, const(1.0)) + PrimFn.exp(cast(U, I[D.m, D.n]))) + PrimFn.log(TypeFn.cast(U, const(1.0)) + PrimFn.exp(TypeFn.cast(U, I[D.m, D.n]))) From 3d2974137d9d790fa4bbd32a80e314372d41bc3d Mon Sep 17 00:00:00 2001 From: gysit Date: Fri, 7 Jan 2022 12:37:52 +0000 Subject: [PATCH 201/915] [mlir][OpDSL] Rename `PrimFn` to `ArithFn`. The revision renames `PrimFn` to `ArithFn`. The name resembles the newly introduced arith dialect that implements most of the arithmetic functions. An exception are log/exp that are part of the math dialect. 
Depends On D115239 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D115240 --- .../linalg/opdsl/lang/comprehension.py | 91 +++++++++++-------- .../dialects/linalg/opdsl/lang/emitter.py | 24 ++--- .../dialects/linalg/opdsl/lang/scalar_expr.py | 30 +++--- .../linalg/opdsl/ops/core_named_ops.py | 2 +- 4 files changed, 80 insertions(+), 67 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index be7fc02d0..fd0fa7266 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -77,13 +77,13 @@ def visit_scalar_def(expr): self.visit_tensor_exprs(visit_scalar_def) def __add__(self, rhs: "TensorExpression") -> "TensorExpression": - return PrimFn.add(self, rhs) + return ArithFn.add(self, rhs) def __mul__(self, rhs) -> "TensorExpression": - return PrimFn.mul(self, rhs) + return ArithFn.mul(self, rhs) def __sub__(self, rhs) -> "TensorExpression": - return PrimFn.sub(self, rhs) + return ArithFn.sub(self, rhs) def __hash__(self): return hash(id(self)) @@ -347,42 +347,55 @@ class TypeFn: cast_unsigned = TypeFnType("cast_unsigned") -class PrimFnType: - """Primitive operations.""" +class ArithFnType: + """Arithmetic function. - def __init__(self, prim_name: str): - self.prim_name = prim_name + An arithmetic function takes one ore more tensor expressions and returns the + function evaluation result. 
+ """ - def __call__(self, *args): - return PrimApply(self, args) + def __init__(self, fn_name: str): + self.fn_name = fn_name + + def __call__(self, *args) -> "TensorArithFn": + return TensorArithFn(self, args) def reduce(self, *reduce_dims: DimDef): - """Shortcut to create a Reduce operation from this primitive.""" + """Shortcut to create a Reduce operation from this function.""" return ReduceFnType(self, *reduce_dims) def __repr__(self): - return f"{self.prim_name}" + return f"{self.fn_name}" -class PrimFn: - add = PrimFnType("add") - exp = PrimFnType("exp") - log = PrimFnType("log") - mul = PrimFnType("mul") - max = PrimFnType("max") - min = PrimFnType("min") - sub = PrimFnType("sub") - max_unsigned = PrimFnType("max_unsigned") - min_unsigned = PrimFnType("min_unsigned") +class ArithFn: + """Arithmetic function namespace. + + As the integer types are signless, signedness is implement by different + functions that treat integers as signed or unsigned values. + + Examples: + - max -> `arith.MaxSIOp` + - max_unsinged -> `arith.MaxUIOp` + """ + add = ArithFnType("add") + exp = ArithFnType("exp") + log = ArithFnType("log") + mul = ArithFnType("mul") + max = ArithFnType("max") + min = ArithFnType("min") + sub = ArithFnType("sub") + max_unsigned = ArithFnType("max_unsigned") + min_unsigned = ArithFnType("min_unsigned") class ReduceFnType: """A reduction operator that reduces into its LHS from its RHS.""" - def __init__(self, operator: PrimFnType, *reduce_dims: DimDef): - """Initializes the ReduceFn with a primitive function and dims.""" - if not isinstance(operator, PrimFnType): - raise ValueError(f"Reduce expected a Prim operator but got {operator}") + def __init__(self, operator: ArithFnType, *reduce_dims: DimDef): + """Initializes the ReduceFn with an airthmetic function and dims.""" + if not isinstance(operator, ArithFnType): + raise ValueError(f"Reduce expected a ArithFnType but got {operator}") self.operator = operator self.reduce_dims = tuple(reduce_dims) @@ 
-390,28 +403,28 @@ def __call__(self, *args: TensorExpression): return ReduceApply(self, args) def __repr__(self): - return (f"reduce_{self.operator.prim_name}" + return (f"reduce_{self.operator.fn_name}" f"({', '.join(repr(d) for d in self.reduce_dims)})") class ReduceFn: - add = PrimFn.add.reduce - mul = PrimFn.mul.reduce - max = PrimFn.max.reduce - min = PrimFn.min.reduce - max_unsigned = PrimFn.max_unsigned.reduce - min_unsigned = PrimFn.min_unsigned.reduce + add = ArithFn.add.reduce + mul = ArithFn.mul.reduce + max = ArithFn.max.reduce + min = ArithFn.min.reduce + max_unsigned = ArithFn.max_unsigned.reduce + min_unsigned = ArithFn.min_unsigned.reduce -class PrimApply(TensorExpression): - """Application of a primitive.""" +class TensorArithFn(TensorExpression): + """Application of an arithmetic function.""" - def __init__(self, prim: PrimFnType, args: Sequence[TensorExpression]): - self.prim = prim + def __init__(self, arith_fn: ArithFnType, args: Sequence[TensorExpression]): + self.arith_fn = arith_fn self.args = tuple(args) def to_scalar_expression(self) -> ScalarExpression: - return ScalarApplyFn(self.prim.prim_name, + return ScalarArithFn(self.arith_fn.fn_name, *[arg.to_scalar_expression() for arg in self.args ]).expr() @@ -421,7 +434,7 @@ def visit_tensor_exprs(self, callback): arg.visit_tensor_exprs(callback) def __repr__(self): - return f"{repr(self.prim)}({', '.join(repr(a) for a in self.args)})" + return f"{repr(self.arith_fn)}({', '.join(repr(a) for a in self.args)})" class TensorTypeFn(TensorExpression): @@ -503,7 +516,7 @@ def to_scalar_expression(self) -> ScalarExpression: f"bound to its lhs: {self}") full_args = [self.lhs.to_scalar_expression() ] + [arg.to_scalar_expression() for arg in self.args] - return ScalarApplyFn(self.reduce.operator.prim_name, *full_args).expr() + return ScalarArithFn(self.reduce.operator.fn_name, *full_args).expr() def visit_tensor_exprs(self, callback): for arg in self.args: diff --git 
a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index df91b9670..22568c8b6 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -221,10 +221,10 @@ def expression(self, expr: ScalarExpression) -> Value: dim_attr = IntegerAttr.get( IntegerType.get_signless(64), expr.scalar_index.dim) return linalg.IndexOp(dim_attr).result - elif expr.scalar_apply: - fn = self._get_function(f"_eval_{expr.scalar_apply.fn_name}") + elif expr.arith_fn: + fn = self._get_function(f"_arithfn_{expr.arith_fn.fn_name}") operand_values = [ - self.expression(operand) for operand in expr.scalar_apply.operands + self.expression(operand) for operand in expr.arith_fn.operands ] return fn(*operand_values) elif expr.type_fn: @@ -310,59 +310,59 @@ def _typefn_cast(self, type_var_name: str, operand: Value) -> Value: def _typefn_cast_unsigned(self, type_var_name: str, operand: Value) -> Value: return self._cast(type_var_name, operand, True) - def _eval_add(self, lhs: Value, rhs: Value) -> Value: + def _arithfn_add(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.AddFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.AddIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'add' operand: {lhs}") - def _eval_exp(self, x: Value) -> Value: + def _arithfn_exp(self, x: Value) -> Value: if _is_floating_point_type(x.type): return math.ExpOp(x).result raise NotImplementedError("Unsupported 'exp' operand: {x}") - def _eval_log(self, x: Value) -> Value: + def _arithfn_log(self, x: Value) -> Value: if _is_floating_point_type(x.type): return math.LogOp(x).result raise NotImplementedError("Unsupported 'log' operand: {x}") - def _eval_sub(self, lhs: Value, rhs: Value) -> Value: + def _arithfn_sub(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return 
arith.SubFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.SubIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'sub' operand: {lhs}") - def _eval_mul(self, lhs: Value, rhs: Value) -> Value: + def _arithfn_mul(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MulFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MulIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'mul' operand: {lhs}") - def _eval_max(self, lhs: Value, rhs: Value) -> Value: + def _arithfn_max(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MaxFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MaxSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max' operand: {lhs}") - def _eval_max_unsigned(self, lhs: Value, rhs: Value) -> Value: + def _arithfn_max_unsigned(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MaxFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MaxUIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max_unsigned' operand: {lhs}") - def _eval_min(self, lhs: Value, rhs: Value) -> Value: + def _arithfn_min(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MinFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MinSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'min' operand: {lhs}") - def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value: + def _arithfn_min_unsigned(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MinFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py 
b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py index c6b1b3885..2a30e6e78 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py @@ -20,7 +20,7 @@ __all__ = [ "ScalarAssign", - "ScalarApplyFn", + "ScalarArithFn", "ScalarTypeFn", "ScalarArg", "ScalarConst", @@ -29,18 +29,18 @@ ] -class ScalarApplyFn: - """A type of ScalarExpression that applies a named function to operands.""" +class ScalarArithFn: + """A type of ScalarExpression that applies an arithmetic function.""" def __init__(self, fn_name: str, *operands: "ScalarExpression"): self.fn_name = fn_name self.operands = operands def expr(self) -> "ScalarExpression": - return ScalarExpression(scalar_apply=self) + return ScalarExpression(arith_fn=self) def __repr__(self): - return f"ScalarApplyFn<{self.fn_name}>({', '.join(self.operands)})" + return f"ScalarArithFn<{self.fn_name}>({', '.join(self.operands)})" class ScalarTypeFn: @@ -102,7 +102,7 @@ class ScalarExpression(YAMLObject): """An expression on scalar values. 
Can be one of: - - ScalarApplyFn + - ScalarArithFn - ScalarTypeFn - ScalarArg - ScalarConst @@ -112,27 +112,27 @@ class ScalarExpression(YAMLObject): yaml_tag = "!ScalarExpression" def __init__(self, - scalar_apply: Optional[ScalarApplyFn] = None, + arith_fn: Optional[ScalarArithFn] = None, type_fn: Optional[ScalarTypeFn] = None, scalar_arg: Optional[ScalarArg] = None, scalar_const: Optional[ScalarConst] = None, scalar_index: Optional[ScalarIndex] = None): - if (bool(scalar_apply) + bool(type_fn) + bool(scalar_arg) + - bool(scalar_const) + bool(scalar_index)) != 1: - raise ValueError("One of 'scalar_apply', 'type_fn', 'scalar_arg', " + if (bool(arith_fn) + bool(type_fn) + bool(scalar_arg) + bool(scalar_const) + + bool(scalar_index)) != 1: + raise ValueError("One of 'arith_fn', 'type_fn', 'scalar_arg', " "'scalar_const', 'scalar_index', must be specified") - self.scalar_apply = scalar_apply + self.arith_fn = arith_fn self.type_fn = type_fn self.scalar_arg = scalar_arg self.scalar_const = scalar_const self.scalar_index = scalar_index def to_yaml_custom_dict(self): - if self.scalar_apply: + if self.arith_fn: return dict( - scalar_apply=dict( - fn_name=self.scalar_apply.fn_name, - operands=list(self.scalar_apply.operands), + arith_fn=dict( + fn_name=self.arith_fn.fn_name, + operands=list(self.arith_fn.operands), )) if self.type_fn: # Note that even though operands must be arity 1, we write it the diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 173af1a3f..afc078d50 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -665,4 +665,4 @@ def soft_plus_2d( """ domain(D.m, D.n) O[D.m, D.n] = \ - PrimFn.log(TypeFn.cast(U, const(1.0)) + PrimFn.exp(TypeFn.cast(U, I[D.m, D.n]))) + ArithFn.log(TypeFn.cast(U, const(1.0)) + ArithFn.exp(TypeFn.cast(U, I[D.m, D.n]))) From 
ce9473ac0c870e701c8b98562691cb2999ff4d69 Mon Sep 17 00:00:00 2001 From: gysit Date: Fri, 7 Jan 2022 12:49:44 +0000 Subject: [PATCH 202/915] [mlir][OpDSL] Separate `ReduceFn` and `ReduceFnUse`. The revision distinguishes `ReduceFn` and `ReduceFnUse`. The latter has the reduction dimensions attached while the former specifies the arithmetic function only. This separation allows us to adapt the reduction syntax a little bit and specify the reduction dimensions using square brackets (in contrast to the round brackets used for the values to reduce). It als is a preparation to add reduction function attributes to OpDSL. A reduction function attribute shall only specify the arithmetic function and not the reduction dimensions. Example: ``` ReduceFn.max_unsigned(D.kh, D.kw)(...) ``` changes to: ``` ReduceFn.max_unsigned[D.kh, D.kw](...) ``` Depends On D115240 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D115241 --- .../linalg/opdsl/lang/comprehension.py | 84 +++++++++++-------- .../linalg/opdsl/ops/core_named_ops.py | 14 ++-- 2 files changed, 56 insertions(+), 42 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index fd0fa7266..ddbebb29f 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -43,8 +43,8 @@ def visit_affine_exprs(expr): if isinstance(expr, TensorUse): for ind in expr.indices: ind.visit_affine_exprs(visit_dim_def) - if isinstance(expr, ReduceApply): - for ind in expr.reduce.reduce_dims: + if isinstance(expr, TensorReduceFn): + for ind in expr.reduce_fn.reduce_dims: ind.visit_affine_exprs(visit_dim_def) self.visit_tensor_exprs(visit_affine_exprs) @@ -114,8 +114,8 @@ def tensor_name(self) -> str: assert name is not None, "TensorDef not attached" return name - def __iadd__(self, rhs: TensorExpression) -> TensorExpression: - return 
ReduceFn.add(*self._compute_reduce_dims(rhs))(rhs) + def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn": + return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs) def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]: """For implicit reductions, computes default reduction dims. @@ -285,7 +285,7 @@ def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]): # Find the lhs to reduction rhs. for assign, value in bindings: - if isinstance(value, ReduceApply): + if isinstance(value, TensorReduceFn): if value.lhs: raise ValueError(f"Reduction expression already assigns: {value}") value.lhs = assign @@ -297,8 +297,8 @@ def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]: """Gets the reduction dims for the comprehension or None.""" result = set() for use in self.values: - if isinstance(use, ReduceApply): - result.add(use.reduce.reduce_dims) + if isinstance(use, TensorReduceFn): + result.add(use.reduce_use.reduce_dims) else: result.add(tuple()) return result @@ -360,10 +360,6 @@ def __init__(self, fn_name: str): def __call__(self, *args) -> "TensorArithFn": return TensorArithFn(self, args) - def reduce(self, *reduce_dims: DimDef): - """Shortcut to create a Reduce operation from this function.""" - return ReduceFnType(self, *reduce_dims) - def __repr__(self): return f"{self.fn_name}" @@ -389,31 +385,49 @@ class ArithFn: min_unsigned = ArithFnType("min_unsigned") -class ReduceFnType: - """A reduction operator that reduces into its LHS from its RHS.""" +class ReduceFnUse: + """Reduction function use. + + A reduction use specifies the reduction function and dimensions. 
+ """ - def __init__(self, operator: ArithFnType, *reduce_dims: DimDef): - """Initializes the ReduceFn with an airthmetic function and dims.""" - if not isinstance(operator, ArithFnType): - raise ValueError(f"Reduce expected a ArithFnType but got {operator}") - self.operator = operator - self.reduce_dims = tuple(reduce_dims) + def __init__(self, arith_fn: ArithFnType, *reduce_dims: DimDef): + self.arith_fn = arith_fn + self.reduce_dims = reduce_dims def __call__(self, *args: TensorExpression): - return ReduceApply(self, args) + return TensorReduceFn(self, args) def __repr__(self): - return (f"reduce_{self.operator.fn_name}" + return (f"reduce_{self.arith_fn.fn_name}" f"({', '.join(repr(d) for d in self.reduce_dims)})") +class ReduceFnType: + """Reduction function. + + An arithmetic function that reduces its RHS into its LHS. + """ + + def __init__(self, arith_fn: ArithFnType): + if not isinstance(arith_fn, ArithFnType): + raise ValueError(f"Reduce expected a ArithFnType but got {arith_fn}") + self.arith_fn = arith_fn + + def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: + return ReduceFnUse(self.arith_fn, *reduce_dims) + + def __repr__(self): + return (f"reduce_{self.arith_fn.fn_name}") + + class ReduceFn: - add = ArithFn.add.reduce - mul = ArithFn.mul.reduce - max = ArithFn.max.reduce - min = ArithFn.min.reduce - max_unsigned = ArithFn.max_unsigned.reduce - min_unsigned = ArithFn.min_unsigned.reduce + add = ReduceFnType(ArithFn.add) + mul = ReduceFnType(ArithFn.mul) + max = ReduceFnType(ArithFn.max) + min = ReduceFnType(ArithFn.min) + max_unsigned = ReduceFnType(ArithFn.max_unsigned) + min_unsigned = ReduceFnType(ArithFn.min_unsigned) class TensorArithFn(TensorExpression): @@ -499,31 +513,31 @@ def __repr__(self): return f"index({repr(self.dim)})" -class ReduceApply(TensorExpression): - """Application of a reduction. +class TensorReduceFn(TensorExpression): + """Application of a reduction function. 
- This captures the lhs separately (initial value) separately from the rhs. + This captures the lhs (initial value) separately from the rhs. """ - def __init__(self, reduce: ReduceFnType, args: Sequence[TensorExpression]): - self.reduce = reduce + def __init__(self, reduce_use: ReduceFnUse, args: Sequence[TensorExpression]): + self.reduce_use = reduce_use self.lhs = None # type: Optional[TensorUse] self.args = tuple(args) def to_scalar_expression(self) -> ScalarExpression: if self.lhs is None: - raise ValueError(f"Cannot scalarize a ReduceApply that has not been " + raise ValueError(f"Cannot scalarize a TensorReduceFn that has not been " f"bound to its lhs: {self}") full_args = [self.lhs.to_scalar_expression() ] + [arg.to_scalar_expression() for arg in self.args] - return ScalarArithFn(self.reduce.operator.fn_name, *full_args).expr() + return ScalarArithFn(self.reduce_use.arith_fn.fn_name, *full_args).expr() def visit_tensor_exprs(self, callback): for arg in self.args: arg.visit_tensor_exprs(callback) def __repr__(self): - return f"{repr(self.reduce)}({', '.join(repr(a) for a in self.args)})" + return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})" class OpInterfaceDef: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index afc078d50..9fe370ffd 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -479,7 +479,7 @@ def pooling_nhwc_max( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.c] = ReduceFn.max(D.kh, D.kw)( + O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw]( TypeFn.cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -499,7 +499,7 @@ def pooling_nhwc_max_unsigned( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.c] = 
ReduceFn.max_unsigned(D.kh, D.kw)( + O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw]( TypeFn.cast_unsigned( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -519,7 +519,7 @@ def pooling_nchw_max( """ implements(ConvolutionOpInterface) domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) - O[D.n, D.c, D.oh, D.ow] = ReduceFn.max(D.kh, D.kw)( + O[D.n, D.c, D.oh, D.ow] = ReduceFn.max[D.kh, D.kw]( TypeFn.cast( U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,])) @@ -540,7 +540,7 @@ def pooling_nhwc_min( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.c] = ReduceFn.min(D.kh, D.kw)( + O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw]( TypeFn.cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -560,7 +560,7 @@ def pooling_nhwc_min_unsigned( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)( + O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw]( TypeFn.cast_unsigned( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -600,7 +600,7 @@ def pooling_ndhwc_max( """ implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) - O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max(D.kd, D.kh, D.kw)( + O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max[D.kd, D.kh, D.kw]( TypeFn.cast( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -621,7 +621,7 @@ def pooling_ndhwc_min( """ implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) - O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min(D.kd, D.kh, D.kw)( + O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min[D.kd, D.kh, D.kw]( TypeFn.cast( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) From 11370fd22287dd1ef828336b371fda53617fb229 Mon Sep 17 00:00:00 2001 
From: Mehdi Amini Date: Sat, 8 Jan 2022 20:37:39 +0000 Subject: [PATCH 203/915] Apply clang-tidy fixes for modernize-use-equals-default in IRCore.cpp (NFC) --- mlir/lib/Bindings/Python/IRCore.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 1a9604882..729431b98 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -725,7 +725,7 @@ PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context, py::object callback) : context(context), callback(std::move(callback)) {} -PyDiagnosticHandler::~PyDiagnosticHandler() {} +PyDiagnosticHandler::~PyDiagnosticHandler() = default; void PyDiagnosticHandler::detach() { if (!registeredID) From ac9619df544c047fb91b83a312893a4208fa7c4e Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Tue, 11 Jan 2022 17:04:39 +0000 Subject: [PATCH 204/915] [mlir][linalg] Improve pooling op iterator order consistency All named ops list iterators for accessing output first except pooling ops. This commit made the pooling ops consistent with the rest. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D115520 --- .../dialects/linalg/opdsl/ops/core_named_ops.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 9fe370ffd..d3651bd76 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -459,7 +459,7 @@ def pooling_nhwc_sum( data type as the accumulator/output. 
""" implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) O[D.n, D.oh, D.ow, D.c] += TypeFn.cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) @@ -478,7 +478,7 @@ def pooling_nhwc_max( data type as the accumulator/output. """ implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw]( TypeFn.cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -498,7 +498,7 @@ def pooling_nhwc_max_unsigned( data type as the accumulator/output. """ implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw]( TypeFn.cast_unsigned( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -539,7 +539,7 @@ def pooling_nhwc_min( data type as the accumulator/output. """ implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw]( TypeFn.cast( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -559,7 +559,7 @@ def pooling_nhwc_min_unsigned( data type as the accumulator/output. """ implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw]( TypeFn.cast_unsigned( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -579,7 +579,7 @@ def pooling_ndhwc_sum( data type as the accumulator/output. 
""" implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) + domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) @@ -599,7 +599,7 @@ def pooling_ndhwc_max( data type as the accumulator/output. """ implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) + domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max[D.kd, D.kh, D.kw]( TypeFn.cast( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, @@ -620,7 +620,7 @@ def pooling_ndhwc_min( data type as the accumulator/output. """ implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.c) + domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min[D.kd, D.kh, D.kw]( TypeFn.cast( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, From 71e8e8c45d480c00a2b798308082cb63fe249c44 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 12 Jan 2022 11:20:18 -0800 Subject: [PATCH 205/915] [mlir] Finish removing Identifier from the C++ API There have been a few API pieces remaining to allow for a smooth transition for downstream users, but these have been up for a few months now. After this only the C API will have reference to "Identifier", but those will be reworked in a followup. The main updates are: * Identifier -> StringAttr * StringAttr::get requires the context as the first parameter - i.e. 
`Identifier::get("...", ctx)` -> `StringAttr::get(ctx, "...")` Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D116626 --- mlir/lib/CAPI/IR/IR.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 955f5e0c1..13b7673e2 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -820,8 +820,8 @@ MlirLogicalResult mlirSymbolTableReplaceAllSymbolUses(MlirStringRef oldSymbol, MlirOperation from) { auto *cppFrom = unwrap(from); auto *context = cppFrom->getContext(); - auto oldSymbolAttr = StringAttr::get(unwrap(oldSymbol), context); - auto newSymbolAttr = StringAttr::get(unwrap(newSymbol), context); + auto oldSymbolAttr = StringAttr::get(context, unwrap(oldSymbol)); + auto newSymbolAttr = StringAttr::get(context, unwrap(newSymbol)); return wrap(SymbolTable::replaceAllSymbolUses(oldSymbolAttr, newSymbolAttr, unwrap(from))); } From 48625067dc2fd1f2db6648795b36ee468d707793 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Tue, 11 Jan 2022 20:21:37 +0000 Subject: [PATCH 206/915] Apply clang-tidy fixes for bugprone-macro-parentheses in Interop.h (NFC) --- mlir/include/mlir-c/Bindings/Python/Interop.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h index 7fcfd028b..7efe8500a 100644 --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -96,7 +96,8 @@ /// Gets a void* from a wrapped struct. Needed because const cast is different /// between C/C++. 
#ifdef __cplusplus -#define MLIR_PYTHON_GET_WRAPPED_POINTER(object) const_cast(object.ptr) +#define MLIR_PYTHON_GET_WRAPPED_POINTER(object) \ + (const_cast((object).ptr)) #else #define MLIR_PYTHON_GET_WRAPPED_POINTER(object) (void *)(object.ptr) #endif From f23a7f36a95cc2105094d91bf47bba458669eee5 Mon Sep 17 00:00:00 2001 From: Denys Shabalin Date: Thu, 13 Jan 2022 11:33:42 +0100 Subject: [PATCH 207/915] [mlir] Introduce C API for PDL dialect types This change introduces C API helper functions to work with PDL types. Modification closely follow the format of the https://reviews.llvm.org/D116546. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D117221 --- mlir/include/mlir-c/Dialect/PDL.h | 72 +++++++++++++++++++++++ mlir/include/mlir-c/Dialect/Quant.h | 2 +- mlir/lib/CAPI/Dialect/CMakeLists.txt | 9 +++ mlir/lib/CAPI/Dialect/PDL.cpp | 85 ++++++++++++++++++++++++++++ mlir/lib/CAPI/Dialect/Quant.cpp | 2 +- 5 files changed, 168 insertions(+), 2 deletions(-) create mode 100644 mlir/include/mlir-c/Dialect/PDL.h create mode 100644 mlir/lib/CAPI/Dialect/PDL.cpp diff --git a/mlir/include/mlir-c/Dialect/PDL.h b/mlir/include/mlir-c/Dialect/PDL.h new file mode 100644 index 000000000..5e0a2bc95 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/PDL.h @@ -0,0 +1,72 @@ +//===-- mlir-c/Dialect/PDL.h - C API for PDL Dialect --------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_PDL_H +#define MLIR_C_DIALECT_PDL_H + +#include "mlir-c/IR.h" +#include "mlir-c/Registration.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(PDL, pdl); + +//===---------------------------------------------------------------------===// +// PDLType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLType(MlirType type); + +//===---------------------------------------------------------------------===// +// AttributeType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLAttributeType(MlirType type); + +MLIR_CAPI_EXPORTED MlirType mlirPDLAttributeTypeGet(MlirContext ctx); + +//===---------------------------------------------------------------------===// +// OperationType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLOperationType(MlirType type); + +MLIR_CAPI_EXPORTED MlirType mlirPDLOperationTypeGet(MlirContext ctx); + +//===---------------------------------------------------------------------===// +// RangeType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLRangeType(MlirType type); + +MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGet(MlirType elementType); + +//===---------------------------------------------------------------------===// +// TypeType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLTypeType(MlirType type); + +MLIR_CAPI_EXPORTED MlirType mlirPDLTypeTypeGet(MlirContext ctx); + +//===---------------------------------------------------------------------===// +// ValueType 
+//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLValueType(MlirType type); + +MLIR_CAPI_EXPORTED MlirType mlirPDLValueTypeGet(MlirContext ctx); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_PDL_H diff --git a/mlir/include/mlir-c/Dialect/Quant.h b/mlir/include/mlir-c/Dialect/Quant.h index c45d93af4..eb529b845 100644 --- a/mlir/include/mlir-c/Dialect/Quant.h +++ b/mlir/include/mlir-c/Dialect/Quant.h @@ -1,4 +1,4 @@ -//===-- mlir-c/Dialect/LLVM.h - C API for LLVM --------------------*- C -*-===// +//===-- mlir-c/Dialect/Quant.h - C API for LLVM -------------------*- C -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM // Exceptions. diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index 1d3e2727a..f66f7b0b8 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -106,3 +106,12 @@ add_mlir_upstream_c_api_library(MLIRCAPIQuant MLIRCAPIIR MLIRQuant ) + +add_mlir_upstream_c_api_library(MLIRCAPIPDL + PDL.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRPDL +) diff --git a/mlir/lib/CAPI/Dialect/PDL.cpp b/mlir/lib/CAPI/Dialect/PDL.cpp new file mode 100644 index 000000000..42b4ec24b --- /dev/null +++ b/mlir/lib/CAPI/Dialect/PDL.cpp @@ -0,0 +1,85 @@ +//===- PDL.cpp - C Interface for PDL dialect ------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/PDL.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDL/IR/PDLOps.h" +#include "mlir/Dialect/PDL/IR/PDLTypes.h" + +using namespace mlir; + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(PDL, pdl, pdl::PDLDialect) + +//===---------------------------------------------------------------------===// +// PDLType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAPDLType(MlirType type) { + return unwrap(type).isa(); +} + +//===---------------------------------------------------------------------===// +// AttributeType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAPDLAttributeType(MlirType type) { + return unwrap(type).isa(); +} + +MlirType mlirPDLAttributeTypeGet(MlirContext ctx) { + return wrap(pdl::AttributeType::get(unwrap(ctx))); +} + +//===---------------------------------------------------------------------===// +// OperationType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAPDLOperationType(MlirType type) { + return unwrap(type).isa(); +} + +MlirType mlirPDLOperationTypeGet(MlirContext ctx) { + return wrap(pdl::OperationType::get(unwrap(ctx))); +} + +//===---------------------------------------------------------------------===// +// RangeType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAPDLRangeType(MlirType type) { + return unwrap(type).isa(); +} + +MlirType mlirPDLRangeTypeGet(MlirType elementType) { + return wrap(pdl::RangeType::get(unwrap(elementType))); +} + +//===---------------------------------------------------------------------===// +// TypeType +//===---------------------------------------------------------------------===// + 
+bool mlirTypeIsAPDLTypeType(MlirType type) { + return unwrap(type).isa(); +} + +MlirType mlirPDLTypeTypeGet(MlirContext ctx) { + return wrap(pdl::TypeType::get(unwrap(ctx))); +} + +//===---------------------------------------------------------------------===// +// ValueType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAPDLValueType(MlirType type) { + return unwrap(type).isa(); +} + +MlirType mlirPDLValueTypeGet(MlirContext ctx) { + return wrap(pdl::ValueType::get(unwrap(ctx))); +} diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp index 3fc00a72d..483536508 100644 --- a/mlir/lib/CAPI/Dialect/Quant.cpp +++ b/mlir/lib/CAPI/Dialect/Quant.cpp @@ -1,4 +1,4 @@ -//===- LLVM.cpp - C Interface for Quant dialect ---------------------------===// +//===- Quant.cpp - C Interface for Quant dialect --------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
From e557fe51b0d6bf6d56b58e18a2dab7d36b293183 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Thu, 13 Jan 2022 20:27:30 +0000 Subject: [PATCH 208/915] Apply clang-tidy fixes for llvm-header-guard in MLIR (NFC) Differential Revision: https://reviews.llvm.org/D117251 --- mlir/include/mlir-c/Conversion.h | 6 +++--- mlir/include/mlir-c/Dialect/SparseTensor.h | 6 +++--- mlir/include/mlir-c/ExecutionEngine.h | 2 +- mlir/include/mlir-c/Interfaces.h | 6 +++--- mlir/include/mlir/Bindings/Python/PybindAdaptors.h | 6 +++--- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir-c/Conversion.h b/mlir/include/mlir-c/Conversion.h index b69c41710..88c5143ad 100644 --- a/mlir/include/mlir-c/Conversion.h +++ b/mlir/include/mlir-c/Conversion.h @@ -12,11 +12,11 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_C_CONVERSIONS_H -#define MLIR_C_CONVERSIONS_H +#ifndef MLIR_C_CONVERSION_H +#define MLIR_C_CONVERSION_H #include "mlir-c/Support.h" #include "mlir/Conversion/Passes.capi.h.inc" -#endif // MLIR_C_CONVERSIONS_H +#endif // MLIR_C_CONVERSION_H diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index 16d932c16..68d72b917 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_C_DIALECT_SPARSE_TENSOR_H -#define MLIR_C_DIALECT_SPARSE_TENSOR_H +#ifndef MLIR_C_DIALECT_SPARSETENSOR_H +#define MLIR_C_DIALECT_SPARSETENSOR_H #include "mlir-c/AffineMap.h" #include "mlir-c/Registration.h" @@ -76,4 +76,4 @@ mlirSparseTensorEncodingAttrGetIndexBitWidth(MlirAttribute attr); #include "mlir/Dialect/SparseTensor/Transforms/Passes.capi.h.inc" -#endif // MLIR_C_DIALECT_SPARSE_TENSOR_H +#endif // MLIR_C_DIALECT_SPARSETENSOR_H diff --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h 
index cd3df8ebf..adb4e823e 100644 --- a/mlir/include/mlir-c/ExecutionEngine.h +++ b/mlir/include/mlir-c/ExecutionEngine.h @@ -87,4 +87,4 @@ mlirExecutionEngineDumpToObjectFile(MlirExecutionEngine jit, } #endif -#endif // EXECUTIONENGINE_H +#endif // MLIR_C_EXECUTIONENGINE_H diff --git a/mlir/include/mlir-c/Interfaces.h b/mlir/include/mlir-c/Interfaces.h index 878628342..233f828b9 100644 --- a/mlir/include/mlir-c/Interfaces.h +++ b/mlir/include/mlir-c/Interfaces.h @@ -12,8 +12,8 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_C_DIALECT_H -#define MLIR_C_DIALECT_H +#ifndef MLIR_C_INTERFACES_H +#define MLIR_C_INTERFACES_H #include "mlir-c/IR.h" #include "mlir-c/Support.h" @@ -64,4 +64,4 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( } #endif -#endif // MLIR_C_DIALECT_H +#endif // MLIR_C_INTERFACES_H diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 811e54ab4..f2ccbfed1 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -15,8 +15,8 @@ // Pybind-based internals of the core libraries). 
//===----------------------------------------------------------------------===// -#ifndef MLIR_BINDINGS_PYTHON_PYBIND_ADAPTORS_H -#define MLIR_BINDINGS_PYTHON_PYBIND_ADAPTORS_H +#ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H +#define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H #include #include @@ -423,4 +423,4 @@ class mlir_type_subclass : public pure_subclass { } // namespace python } // namespace mlir -#endif // MLIR_BINDINGS_PYTHON_PYBIND_ADAPTORS_H +#endif // MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H From b92b09046e902d4dc624fb8d49afc0a6c8584ce1 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Fri, 14 Jan 2022 01:35:35 +0000 Subject: [PATCH 209/915] Apply clang-tidy fixes for modernize-use-equals-default to MLIR (NFC) --- mlir/lib/Bindings/Python/IRModule.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 2f354d6d1..698fc6dc7 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -56,7 +56,7 @@ class PyObjectRef { } PyObjectRef(const PyObjectRef &other) : referrent(other.referrent), object(other.object /* copies */) {} - ~PyObjectRef() {} + ~PyObjectRef() = default; int getRefCount() { if (!object) From c11bb8541abe5e02ede138623b4fcaef9156b7dc Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Fri, 14 Jan 2022 01:35:38 +0000 Subject: [PATCH 210/915] Apply clang-tidy fixes for modernize-use-override to MLIR (NFC) --- mlir/lib/Bindings/Python/IRModule.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 698fc6dc7..4f1b65610 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -496,7 +496,7 @@ class PyOperation; using PyOperationRef = PyObjectRef; class PyOperation : public PyOperationBase, public BaseContextObject { public: - ~PyOperation(); + ~PyOperation() override; PyOperation &getOperation() override { 
return *this; } /// Returns a PyOperation for the given MlirOperation, optionally associating From 1a2a30ab6bafeaff0b74ec6f4d28e798c4e413b2 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Fri, 14 Jan 2022 01:35:39 +0000 Subject: [PATCH 211/915] Apply clang-tidy fixes for modernize-use-using to MLIR (NFC) --- mlir/include/mlir-c/Interfaces.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/Interfaces.h b/mlir/include/mlir-c/Interfaces.h index 233f828b9..7ab6b8af3 100644 --- a/mlir/include/mlir-c/Interfaces.h +++ b/mlir/include/mlir-c/Interfaces.h @@ -48,7 +48,7 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID(); /// transferring ownership to the caller. The first argument is the number of /// consecutive elements pointed to by the second argument. The third argument /// is an opaque pointer forwarded to the callback by the caller. -typedef void (*MlirTypesCallback)(intptr_t, MlirType *, void *); +using MlirTypesCallback = void (*)(intptr_t, MlirType *, void *); /// Infers the return types of the operation identified by its canonical given /// the arguments that will be supplied to its generic builder. 
Calls `callback` From 1d22e101cc6eb75b9d3fec9b853691e856a7c128 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Fri, 14 Jan 2022 01:35:55 +0000 Subject: [PATCH 212/915] Apply clang-tidy fixes for performance-unnecessary-value-param to MLIR (NFC) --- mlir/include/mlir/Bindings/Python/PybindAdaptors.h | 7 ++++--- mlir/include/mlir/CAPI/Utils.h | 4 +++- mlir/lib/Bindings/Python/IRModule.h | 10 ++++++---- mlir/lib/Bindings/Python/PybindUtils.h | 2 +- 4 files changed, 14 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index f2ccbfed1..05fb0844a 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -250,7 +250,7 @@ namespace adaptors { class pure_subclass { public: pure_subclass(py::handle scope, const char *derivedClassName, - py::object superClass) { + const py::object &superClass) { py::object pyType = py::reinterpret_borrow((PyObject *)&PyType_Type); py::object metaclass = pyType(superClass); @@ -335,7 +335,8 @@ class mlir_attribute_subclass : public pure_subclass { /// as the mlir.ir class (otherwise, it will trigger a recursive /// initialization). mlir_attribute_subclass(py::handle scope, const char *typeClassName, - IsAFunctionTy isaFunction, py::object superClass) + IsAFunctionTy isaFunction, + const py::object &superClass) : pure_subclass(scope, typeClassName, superClass) { // Casting constructor. Note that defining an __init__ method is special // and not yet generalized on pure_subclass (it requires a somewhat @@ -386,7 +387,7 @@ class mlir_type_subclass : public pure_subclass { /// as the mlir.ir class (otherwise, it will trigger a recursive /// initialization). 
mlir_type_subclass(py::handle scope, const char *typeClassName, - IsAFunctionTy isaFunction, py::object superClass) + IsAFunctionTy isaFunction, const py::object &superClass) : pure_subclass(scope, typeClassName, superClass) { // Casting constructor. Note that defining an __init__ method is special // and not yet generalized on pure_subclass (it requires a somewhat diff --git a/mlir/include/mlir/CAPI/Utils.h b/mlir/include/mlir/CAPI/Utils.h index 000516e7a..d78cdbf31 100644 --- a/mlir/include/mlir/CAPI/Utils.h +++ b/mlir/include/mlir/CAPI/Utils.h @@ -14,6 +14,8 @@ #ifndef MLIR_CAPI_UTILS_H #define MLIR_CAPI_UTILS_H +#include + #include "mlir-c/Support.h" #include "llvm/Support/raw_ostream.h" @@ -29,7 +31,7 @@ class CallbackOstream : public llvm::raw_ostream { public: CallbackOstream(std::function callback, void *opaqueData) - : raw_ostream(/*unbuffered=*/true), callback(callback), + : raw_ostream(/*unbuffered=*/true), callback(std::move(callback)), opaqueData(opaqueData), pos(0u) {} void write_impl(const char *ptr, size_t size) override { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 4f1b65610..b1424a994 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -9,6 +9,7 @@ #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H #define MLIR_BINDINGS_PYTHON_IRMODULES_H +#include #include #include "PybindUtils.h" @@ -330,8 +331,9 @@ class PyDiagnosticHandler { void detach(); pybind11::object contextEnter() { return pybind11::cast(this); } - void contextExit(pybind11::object excType, pybind11::object excVal, - pybind11::object excTb) { + void contextExit(const pybind11::object &excType, + const pybind11::object &excVal, + const pybind11::object &excTb) { detach(); } @@ -532,7 +534,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject { } bool isAttached() { return attached; } - void setAttached(pybind11::object parent = pybind11::object()) { + void setAttached(const 
pybind11::object &parent = pybind11::object()) { assert(!attached && "operation already attached"); attached = true; } @@ -876,7 +878,7 @@ class PyConcreteAttribute : public BaseTy { class PyValue { public: PyValue(PyOperationRef parentOperation, MlirValue value) - : parentOperation(parentOperation), value(value) {} + : parentOperation(std::move(parentOperation)), value(value) {} operator MlirValue() const { return value; } MlirValue get() { return value; } diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h index 457d8090c..75a72371e 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -133,7 +133,7 @@ struct PyPrintAccumulator { /// or binary. class PyFileAccumulator { public: - PyFileAccumulator(pybind11::object fileObject, bool binary) + PyFileAccumulator(const pybind11::object &fileObject, bool binary) : pyWriteFunction(fileObject.attr("write")), binary(binary) {} void *getUserData() { return this; } From 2674d6497f15f989e2b72117eaf43ed3ab90c43d Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Fri, 14 Jan 2022 01:36:05 +0000 Subject: [PATCH 213/915] Apply clang-tidy fixes for readability-simplify-boolean-expr to MLIR (NFC) --- .../mlir/Bindings/Python/PybindAdaptors.h | 35 ++++--------------- 1 file changed, 7 insertions(+), 28 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 05fb0844a..0340e9cc4 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -86,10 +86,7 @@ struct type_caster { bool load(handle src, bool) { py::object capsule = mlirApiObjectToCapsule(src); value = mlirPythonCapsuleToAttribute(capsule.ptr()); - if (mlirAttributeIsNull(value)) { - return false; - } - return true; + return !mlirAttributeIsNull(value); } static handle cast(MlirAttribute v, return_value_policy, handle) { py::object capsule = @@ 
-117,10 +114,7 @@ struct type_caster { } py::object capsule = mlirApiObjectToCapsule(src); value = mlirPythonCapsuleToContext(capsule.ptr()); - if (mlirContextIsNull(value)) { - return false; - } - return true; + return !mlirContextIsNull(value); } }; @@ -132,10 +126,7 @@ struct type_caster { bool load(handle src, bool) { py::object capsule = mlirApiObjectToCapsule(src); value = mlirPythonCapsuleToLocation(capsule.ptr()); - if (mlirLocationIsNull(value)) { - return false; - } - return true; + return !mlirLocationIsNull(value); } static handle cast(MlirLocation v, return_value_policy, handle) { py::object capsule = @@ -154,10 +145,7 @@ struct type_caster { bool load(handle src, bool) { py::object capsule = mlirApiObjectToCapsule(src); value = mlirPythonCapsuleToModule(capsule.ptr()); - if (mlirModuleIsNull(value)) { - return false; - } - return true; + return !mlirModuleIsNull(value); } static handle cast(MlirModule v, return_value_policy, handle) { py::object capsule = @@ -176,10 +164,7 @@ struct type_caster { bool load(handle src, bool) { py::object capsule = mlirApiObjectToCapsule(src); value = mlirPythonCapsuleToOperation(capsule.ptr()); - if (mlirOperationIsNull(value)) { - return false; - } - return true; + return !mlirOperationIsNull(value); } static handle cast(MlirOperation v, return_value_policy, handle) { if (v.ptr == nullptr) @@ -200,10 +185,7 @@ struct type_caster { bool load(handle src, bool) { py::object capsule = mlirApiObjectToCapsule(src); value = mlirPythonCapsuleToPassManager(capsule.ptr()); - if (mlirPassManagerIsNull(value)) { - return false; - } - return true; + return !mlirPassManagerIsNull(value); } }; @@ -214,10 +196,7 @@ struct type_caster { bool load(handle src, bool) { py::object capsule = mlirApiObjectToCapsule(src); value = mlirPythonCapsuleToType(capsule.ptr()); - if (mlirTypeIsNull(value)) { - return false; - } - return true; + return !mlirTypeIsNull(value); } static handle cast(MlirType t, return_value_policy, handle) { py::object 
capsule = From 8ee6d886c4c2f9a4ee236a83db9c02403b4f3bed Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Fri, 14 Jan 2022 07:47:47 +0000 Subject: [PATCH 214/915] Disable the MLIR ExecutionEngine library when the native target is not configured The execution engine would not be functional anyway, we're already disabling the tests, this also disable the rest of the code. Anecdotally this reduces the number of static library built when the builtin target is disabled goes from 236 to 218. Here is the complete list of LLVM targets built when running `ninja check-mlir`: libLLVMAggressiveInstCombine.a libLLVMAnalysis.a libLLVMAsmParser.a libLLVMBinaryFormat.a libLLVMBitReader.a libLLVMBitstreamReader.a libLLVMBitWriter.a libLLVMCore.a libLLVMDebugInfoCodeView.a libLLVMDebugInfoDWARF.a libLLVMDemangle.a libLLVMFileCheck.a libLLVMFrontendOpenMP.a libLLVMInstCombine.a libLLVMIRReader.a libLLVMMC.a libLLVMMCParser.a libLLVMObject.a libLLVMProfileData.a libLLVMRemarks.a libLLVMScalarOpts.a libLLVMSupport.a libLLVMTableGen.a libLLVMTableGenGlobalISel.a libLLVMTextAPI.a libLLVMTransformUtils.a Differential Revision: https://reviews.llvm.org/D117287 --- mlir/lib/CAPI/CMakeLists.txt | 7 ++++++- mlir/python/CMakeLists.txt | 25 ++++++++++++++----------- 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt index 5545de691..393b49ecb 100644 --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -10,12 +10,16 @@ endfunction() add_subdirectory(Debug) add_subdirectory(Dialect) add_subdirectory(Conversion) -add_subdirectory(ExecutionEngine) add_subdirectory(Interfaces) add_subdirectory(IR) add_subdirectory(Registration) add_subdirectory(Transforms) +# Only enable the ExecutionEngine if the native target is configured in. +if(TARGET ${LLVM_NATIVE_ARCH}) + add_subdirectory(ExecutionEngine) +endif() + # Build the optional CAPI dylib. 
if(MLIR_BUILD_MLIR_C_DYLIB) message(STATUS "Building MLIR-C dylib") @@ -33,3 +37,4 @@ if(MLIR_BUILD_MLIR_C_DYLIB) endif() endif() endif() + diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 60d60d4af..2a9d7a7f4 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -292,17 +292,20 @@ declare_mlir_python_extension(MLIRPythonExtension.Conversions MLIRCAPIConversion ) -declare_mlir_python_extension(MLIRPythonExtension.ExecutionEngine - MODULE_NAME _mlirExecutionEngine - ADD_TO_PARENT MLIRPythonSources.ExecutionEngine - ROOT_DIR "${PYTHON_SOURCE_DIR}" - SOURCES - ExecutionEngineModule.cpp - PRIVATE_LINK_LIBS - LLVMSupport - EMBED_CAPI_LINK_LIBS - MLIRCAPIExecutionEngine -) +# Only enable the ExecutionEngine if the native target is configured in. +if(TARGET ${LLVM_NATIVE_ARCH}) + declare_mlir_python_extension(MLIRPythonExtension.ExecutionEngine + MODULE_NAME _mlirExecutionEngine + ADD_TO_PARENT MLIRPythonSources.ExecutionEngine + ROOT_DIR "${PYTHON_SOURCE_DIR}" + SOURCES + ExecutionEngineModule.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIExecutionEngine + ) +endif() declare_mlir_python_extension(MLIRPythonExtension.GPUDialectPasses MODULE_NAME _mlirGPUPasses From de64ae5ff8f52e13e5934cc812a26b096ee17b0b Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 14 Jan 2022 17:25:41 +0100 Subject: [PATCH 215/915] [mlir] fix crash in PybindAdaptors.h The constructor function was being defined without indicating its "__init__" name, which made it interpret it as a regular fuction rather than a constructor. When overload resolution failed, Pybind would attempt to print the arguments actually passed to the function, including "self", which is not initialized since the constructor couldn't be called. This would result in "__repr__" being called with "self" referencing an uninitialized MLIR C API object, which in turn would cause undefined behavior when attempting to print in C++. 
Fix this by specifying the correct name. This in turn uncovers the fact that the mechanism used by PybindAdaptors.h to bind constructors directly as "__init__" functions taking "self" is deprecated by Pybind. The modern method requires using "py::init", which seems to rely on the C++ equivalent of the bound class to be available, which is not the case in PybindAdaptors.h. A deeper inspection shows that the deprecation concerns old-style pybind11 constructors that had to allocate the object using placement new with "self" as memory. The PybindAdaptors.h only provides extension classes and never allocates (the object construction is delegated to the base class), so it does not use the deprecated functionality. Use the implementation detail tag class to convince pybind11 that we are using the modern constructor binding method and suppress the warning. On top of that, the definition of the function was incorrectly indicated as the method on the "None" object instead of being the method of its parent class. This would result in a second problem when Pybind would attempt to print warnings pointing to the parent class since the "None" does not have a "__name__" field or its C API equivalent. Fix this by specifying the correct parent class by looking it up by name in the parent module. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D117325 --- .../mlir/Bindings/Python/PybindAdaptors.h | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 0340e9cc4..68942e560 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -320,7 +320,11 @@ class mlir_attribute_subclass : public pure_subclass { // Casting constructor. 
Note that defining an __init__ method is special // and not yet generalized on pure_subclass (it requires a somewhat // different cpp_function and other requirements on chaining to super - // __init__ make it more awkward to do generally). + // __init__ make it more awkward to do generally). It is marked as + // `is_new_style_constructor` to suppress the deprecation warning from + // pybind11 related to placement-new since we are not doing any allocation + // here but relying on the superclass constructor that does "new-style" + // allocation for pybind11. std::string captureTypeName( typeClassName); // As string in case if typeClassName is not static. py::cpp_function initCf( @@ -336,7 +340,9 @@ class mlir_attribute_subclass : public pure_subclass { } superClass.attr("__init__")(self, otherType); }, - py::arg("cast_from_type"), py::is_method(py::none()), + py::name("__init__"), py::arg("cast_from_type"), + py::is_method(scope.attr(typeClassName)), + py::detail::is_new_style_constructor(), "Casts the passed type to this specific sub-type."); thisClass.attr("__init__") = initCf; @@ -371,7 +377,11 @@ class mlir_type_subclass : public pure_subclass { // Casting constructor. Note that defining an __init__ method is special // and not yet generalized on pure_subclass (it requires a somewhat // different cpp_function and other requirements on chaining to super - // __init__ make it more awkward to do generally). + // __init__ make it more awkward to do generally). It is marked as + // `is_new_style_constructor` to suppress the deprecation warning from + // pybind11 related to placement-new since we are not doing any allocation + // here but relying on the superclass constructor that does "new-style" + // allocation for pybind11. std::string captureTypeName( typeClassName); // As string in case if typeClassName is not static. 
py::cpp_function initCf( @@ -387,7 +397,9 @@ class mlir_type_subclass : public pure_subclass { } superClass.attr("__init__")(self, otherType); }, - py::arg("cast_from_type"), py::is_method(py::none()), + py::name("__init__"), py::arg("cast_from_type"), + py::is_method(scope.attr(typeClassName)), + py::detail::is_new_style_constructor(), "Casts the passed type to this specific sub-type."); thisClass.attr("__init__") = initCf; From 93045c749cc41336a975b16e565edb61ff79b075 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 17 Jan 2022 10:55:36 +0100 Subject: [PATCH 216/915] Revert "[mlir] fix crash in PybindAdaptors.h" This reverts commit de64ae5ff8f52e13e5934cc812a26b096ee17b0b. Broke the buildbot. --- .../mlir/Bindings/Python/PybindAdaptors.h | 20 ++++--------------- 1 file changed, 4 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 68942e560..0340e9cc4 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -320,11 +320,7 @@ class mlir_attribute_subclass : public pure_subclass { // Casting constructor. Note that defining an __init__ method is special // and not yet generalized on pure_subclass (it requires a somewhat // different cpp_function and other requirements on chaining to super - // __init__ make it more awkward to do generally). It is marked as - // `is_new_style_constructor` to suppress the deprecation warning from - // pybind11 related to placement-new since we are not doing any allocation - // here but relying on the superclass constructor that does "new-style" - // allocation for pybind11. + // __init__ make it more awkward to do generally). std::string captureTypeName( typeClassName); // As string in case if typeClassName is not static. 
py::cpp_function initCf( @@ -340,9 +336,7 @@ class mlir_attribute_subclass : public pure_subclass { } superClass.attr("__init__")(self, otherType); }, - py::name("__init__"), py::arg("cast_from_type"), - py::is_method(scope.attr(typeClassName)), - py::detail::is_new_style_constructor(), + py::arg("cast_from_type"), py::is_method(py::none()), "Casts the passed type to this specific sub-type."); thisClass.attr("__init__") = initCf; @@ -377,11 +371,7 @@ class mlir_type_subclass : public pure_subclass { // Casting constructor. Note that defining an __init__ method is special // and not yet generalized on pure_subclass (it requires a somewhat // different cpp_function and other requirements on chaining to super - // __init__ make it more awkward to do generally). It is marked as - // `is_new_style_constructor` to suppress the deprecation warning from - // pybind11 related to placement-new since we are not doing any allocation - // here but relying on the superclass constructor that does "new-style" - // allocation for pybind11. + // __init__ make it more awkward to do generally). std::string captureTypeName( typeClassName); // As string in case if typeClassName is not static. py::cpp_function initCf( @@ -397,9 +387,7 @@ class mlir_type_subclass : public pure_subclass { } superClass.attr("__init__")(self, otherType); }, - py::name("__init__"), py::arg("cast_from_type"), - py::is_method(scope.attr(typeClassName)), - py::detail::is_new_style_constructor(), + py::arg("cast_from_type"), py::is_method(py::none()), "Casts the passed type to this specific sub-type."); thisClass.attr("__init__") = initCf; From 492190c1409e662d034bac1682add12d125c5bbe Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 17 Jan 2022 14:48:28 +0100 Subject: [PATCH 217/915] [mlir] fix crash in PybindAdaptors.h The constructor function was being defined without indicating its "__init__" name, which made it interpret it as a regular fuction rather than a constructor. 
When overload resolution failed, Pybind would attempt to print the arguments actually passed to the function, including "self", which is not initialized since the constructor couldn't be called. This would result in "__repr__" being called with "self" referencing an uninitialized MLIR C API object, which in turn would cause undefined behavior when attempting to print in C++. Fix this by specifying the correct name. This in turn uncovers the fact that the mechanism used by PybindAdaptors.h to bind constructors directly as "__init__" functions taking "self" is deprecated by Pybind. Instead, leverage the fact that the adaptors are intended for attributes/types that cannot have additional data members and are all ultimately instances of "PyAttribute"/"PyType" C++ class. In constructors of derived classes, construct an instance of the base class first, then steal its internal pointer to the C++ object to construct the instance of the derived class. On top of that, the definition of the function was incorrectly indicated as the method on the "None" object instead of being the method of its parent class. This would result in a second problem when Pybind would attempt to print warnings pointing to the parent class since the "None" does not have a "__name__" field or its C API equivalent. Fix this by specifying the correct parent class by looking it up by name in the parent module.
Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D117325 --- .../mlir/Bindings/Python/PybindAdaptors.h | 57 +++++++++++++++---- 1 file changed, 47 insertions(+), 10 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 0340e9cc4..becbf0180 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -291,6 +291,29 @@ class pure_subclass { py::object get_class() const { return thisClass; } protected: + /// Defers to the constructor of the superClass the configuration of the + /// pybind11 object from the given arguments. Pybind11 has a special handling + /// for constructors such that they don't accept a "self" reference, unlike + /// Python "__init__" calls. Therefore, one cannot just call the "__init__" of + /// the parent class, which would require access to "self". Instead, create an + /// instance of the superclass and take its instance pointer to the base C++ + /// object to populate the instance pointer of the constructed object. Since + /// we only deal with _pure_ subclasses, this should be sufficient as derived + /// classes cannot have more data fields. + template + static void deferToSuperclassConstructor(py::detail::value_and_holder &vh, + py::object superClass, + Args &&...args) { + py::object super = superClass(std::forward(args)...); + py::detail::type_info *ti = + py::detail::get_type_info((PyTypeObject *)superClass.ptr()); + auto *instance = reinterpret_cast(super.ptr()); + + // Take ownership of the value pointer from the base class. + vh.value_ptr() = instance->get_value_and_holder(ti, true).value_ptr(); + super.release(); + } + py::object superClass; py::object thisClass; }; @@ -320,12 +343,16 @@ class mlir_attribute_subclass : public pure_subclass { // Casting constructor. 
Note that defining an __init__ method is special // and not yet generalized on pure_subclass (it requires a somewhat // different cpp_function and other requirements on chaining to super - // __init__ make it more awkward to do generally). + // __init__ make it more awkward to do generally). It is marked as + // `is_new_style_constructor` to suppress the deprecation warning from + // pybind11 related to placement-new since we are not doing any allocation + // here but relying on the superclass constructor that does "new-style" + // allocation for pybind11. std::string captureTypeName( typeClassName); // As string in case if typeClassName is not static. py::cpp_function initCf( - [superClass, isaFunction, captureTypeName](py::object self, - py::object otherType) { + [superClass, isaFunction, captureTypeName]( + py::detail::value_and_holder &vh, py::object otherType) { MlirAttribute rawAttribute = py::cast(otherType); if (!isaFunction(rawAttribute)) { auto origRepr = py::repr(otherType).cast(); @@ -334,9 +361,12 @@ class mlir_attribute_subclass : public pure_subclass { " (from " + origRepr + ")") .str()); } - superClass.attr("__init__")(self, otherType); + pure_subclass::deferToSuperclassConstructor(vh, superClass, + otherType); }, - py::arg("cast_from_type"), py::is_method(py::none()), + py::name("__init__"), py::arg("cast_from_type"), + py::is_method(scope.attr(typeClassName)), + py::detail::is_new_style_constructor(), "Casts the passed type to this specific sub-type."); thisClass.attr("__init__") = initCf; @@ -371,12 +401,16 @@ class mlir_type_subclass : public pure_subclass { // Casting constructor. Note that defining an __init__ method is special // and not yet generalized on pure_subclass (it requires a somewhat // different cpp_function and other requirements on chaining to super - // __init__ make it more awkward to do generally). + // __init__ make it more awkward to do generally). 
It is marked as + // `is_new_style_constructor` to suppress the deprecation warning from + // pybind11 related to placement-new since we are not doing any allocation + // here but relying on the superclass constructor that does "new-style" + // allocation for pybind11. std::string captureTypeName( typeClassName); // As string in case if typeClassName is not static. py::cpp_function initCf( - [superClass, isaFunction, captureTypeName](py::object self, - py::object otherType) { + [superClass, isaFunction, captureTypeName]( + py::detail::value_and_holder &vh, py::object otherType) { MlirType rawType = py::cast(otherType); if (!isaFunction(rawType)) { auto origRepr = py::repr(otherType).cast(); @@ -385,9 +419,12 @@ class mlir_type_subclass : public pure_subclass { origRepr + ")") .str()); } - superClass.attr("__init__")(self, otherType); + pure_subclass::deferToSuperclassConstructor(vh, superClass, + otherType); }, - py::arg("cast_from_type"), py::is_method(py::none()), + py::name("__init__"), py::arg("cast_from_type"), + py::is_method(scope.attr(typeClassName)), + py::detail::is_new_style_constructor(), "Casts the passed type to this specific sub-type."); thisClass.attr("__init__") = initCf; From 6d6878b908ff1fab18c4eb24aa17e8cf076b516c Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Tue, 18 Jan 2022 17:05:29 +0100 Subject: [PATCH 218/915] Revert "[mlir] fix crash in PybindAdaptors.h" This reverts commit 492190c1409e662d034bac1682add12d125c5bbe. 
--- .../mlir/Bindings/Python/PybindAdaptors.h | 57 ++++--------------- 1 file changed, 10 insertions(+), 47 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index becbf0180..0340e9cc4 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -291,29 +291,6 @@ class pure_subclass { py::object get_class() const { return thisClass; } protected: - /// Defers to the constructor of the superClass the configuration of the - /// pybind11 object from the given arguments. Pybind11 has a special handling - /// for constructors such that they don't accept a "self" reference, unlike - /// Python "__init__" calls. Therefore, one cannot just call the "__init__" of - /// the parent class, which would require access to "self". Instead, create an - /// instance of the superclass and take its instance pointer to the base C++ - /// object to populate the instance pointer of the constructed object. Since - /// we only deal with _pure_ subclasses, this should be sufficient as derived - /// classes cannot have more data fields. - template - static void deferToSuperclassConstructor(py::detail::value_and_holder &vh, - py::object superClass, - Args &&...args) { - py::object super = superClass(std::forward(args)...); - py::detail::type_info *ti = - py::detail::get_type_info((PyTypeObject *)superClass.ptr()); - auto *instance = reinterpret_cast(super.ptr()); - - // Take ownership of the value pointer from the base class. - vh.value_ptr() = instance->get_value_and_holder(ti, true).value_ptr(); - super.release(); - } - py::object superClass; py::object thisClass; }; @@ -343,16 +320,12 @@ class mlir_attribute_subclass : public pure_subclass { // Casting constructor. 
Note that defining an __init__ method is special // and not yet generalized on pure_subclass (it requires a somewhat // different cpp_function and other requirements on chaining to super - // __init__ make it more awkward to do generally). It is marked as - // `is_new_style_constructor` to suppress the deprecation warning from - // pybind11 related to placement-new since we are not doing any allocation - // here but relying on the superclass constructor that does "new-style" - // allocation for pybind11. + // __init__ make it more awkward to do generally). std::string captureTypeName( typeClassName); // As string in case if typeClassName is not static. py::cpp_function initCf( - [superClass, isaFunction, captureTypeName]( - py::detail::value_and_holder &vh, py::object otherType) { + [superClass, isaFunction, captureTypeName](py::object self, + py::object otherType) { MlirAttribute rawAttribute = py::cast(otherType); if (!isaFunction(rawAttribute)) { auto origRepr = py::repr(otherType).cast(); @@ -361,12 +334,9 @@ class mlir_attribute_subclass : public pure_subclass { " (from " + origRepr + ")") .str()); } - pure_subclass::deferToSuperclassConstructor(vh, superClass, - otherType); + superClass.attr("__init__")(self, otherType); }, - py::name("__init__"), py::arg("cast_from_type"), - py::is_method(scope.attr(typeClassName)), - py::detail::is_new_style_constructor(), + py::arg("cast_from_type"), py::is_method(py::none()), "Casts the passed type to this specific sub-type."); thisClass.attr("__init__") = initCf; @@ -401,16 +371,12 @@ class mlir_type_subclass : public pure_subclass { // Casting constructor. Note that defining an __init__ method is special // and not yet generalized on pure_subclass (it requires a somewhat // different cpp_function and other requirements on chaining to super - // __init__ make it more awkward to do generally). 
It is marked as - // `is_new_style_constructor` to suppress the deprecation warning from - // pybind11 related to placement-new since we are not doing any allocation - // here but relying on the superclass constructor that does "new-style" - // allocation for pybind11. + // __init__ make it more awkward to do generally). std::string captureTypeName( typeClassName); // As string in case if typeClassName is not static. py::cpp_function initCf( - [superClass, isaFunction, captureTypeName]( - py::detail::value_and_holder &vh, py::object otherType) { + [superClass, isaFunction, captureTypeName](py::object self, + py::object otherType) { MlirType rawType = py::cast(otherType); if (!isaFunction(rawType)) { auto origRepr = py::repr(otherType).cast(); @@ -419,12 +385,9 @@ class mlir_type_subclass : public pure_subclass { origRepr + ")") .str()); } - pure_subclass::deferToSuperclassConstructor(vh, superClass, - otherType); + superClass.attr("__init__")(self, otherType); }, - py::name("__init__"), py::arg("cast_from_type"), - py::is_method(scope.attr(typeClassName)), - py::detail::is_new_style_constructor(), + py::arg("cast_from_type"), py::is_method(py::none()), "Casts the passed type to this specific sub-type."); thisClass.attr("__init__") = initCf; From d572aa5a6bbc84a2983b3f0dd4444f9373cf90a9 Mon Sep 17 00:00:00 2001 From: Denys Shabalin Date: Thu, 13 Jan 2022 10:53:21 +0100 Subject: [PATCH 219/915] [mlir] Introduce Python bindings for the PDL dialect This change adds full python bindings for PDL, including types and operations with additional mixins to make operation construction more similar to the PDL syntax. 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D117458 --- mlir/include/mlir-c/Dialect/PDL.h | 2 + mlir/lib/Bindings/Python/DialectPDL.cpp | 102 +++++++ mlir/lib/CAPI/Dialect/PDL.cpp | 4 + mlir/python/CMakeLists.txt | 22 ++ .../mlir/_mlir_libs/_mlir/dialects/pdl.pyi | 64 ++++ mlir/python/mlir/dialects/PDLOps.td | 15 + mlir/python/mlir/dialects/_ods_common.py | 5 +- mlir/python/mlir/dialects/_pdl_ops_ext.py | 284 ++++++++++++++++++ mlir/python/mlir/dialects/pdl.py | 6 + 9 files changed, 502 insertions(+), 2 deletions(-) create mode 100644 mlir/lib/Bindings/Python/DialectPDL.cpp create mode 100644 mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi create mode 100644 mlir/python/mlir/dialects/PDLOps.td create mode 100644 mlir/python/mlir/dialects/_pdl_ops_ext.py create mode 100644 mlir/python/mlir/dialects/pdl.py diff --git a/mlir/include/mlir-c/Dialect/PDL.h b/mlir/include/mlir-c/Dialect/PDL.h index 5e0a2bc95..1b1528999 100644 --- a/mlir/include/mlir-c/Dialect/PDL.h +++ b/mlir/include/mlir-c/Dialect/PDL.h @@ -49,6 +49,8 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAPDLRangeType(MlirType type); MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGet(MlirType elementType); +MLIR_CAPI_EXPORTED MlirType mlirPDLRangeTypeGetElementType(MlirType type); + //===---------------------------------------------------------------------===// // TypeType //===---------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp new file mode 100644 index 000000000..8d0b1014a --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectPDL.cpp @@ -0,0 +1,102 @@ +//===- DialectPDL.cpp - 'pdl' dialect submodule ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/PDL.h" +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" + +namespace py = pybind11; +using namespace llvm; +using namespace mlir; +using namespace mlir::python; +using namespace mlir::python::adaptors; + +void populateDialectPDLSubmodule(const pybind11::module &m) { + //===-------------------------------------------------------------------===// + // PDLType + //===-------------------------------------------------------------------===// + + auto pdlType = mlir_type_subclass(m, "PDLType", mlirTypeIsAPDLType); + + //===-------------------------------------------------------------------===// + // AttributeType + //===-------------------------------------------------------------------===// + + auto attributeType = + mlir_type_subclass(m, "AttributeType", mlirTypeIsAPDLAttributeType); + attributeType.def_classmethod( + "get", + [](py::object cls, MlirContext ctx) { + return cls(mlirPDLAttributeTypeGet(ctx)); + }, + "Get an instance of AttributeType in given context.", py::arg("cls"), + py::arg("context") = py::none()); + + //===-------------------------------------------------------------------===// + // OperationType + //===-------------------------------------------------------------------===// + + auto operationType = + mlir_type_subclass(m, "OperationType", mlirTypeIsAPDLOperationType); + operationType.def_classmethod( + "get", + [](py::object cls, MlirContext ctx) { + return cls(mlirPDLOperationTypeGet(ctx)); + }, + "Get an instance of OperationType in given context.", py::arg("cls"), + py::arg("context") = py::none()); + + //===-------------------------------------------------------------------===// + // RangeType + //===-------------------------------------------------------------------===// + + auto rangeType = mlir_type_subclass(m, "RangeType", 
mlirTypeIsAPDLRangeType); + rangeType.def_classmethod( + "get", + [](py::object cls, MlirType elementType) { + return cls(mlirPDLRangeTypeGet(elementType)); + }, + "Gets an instance of RangeType in the same context as the provided " + "element type.", + py::arg("cls"), py::arg("element_type")); + rangeType.def_property_readonly( + "element_type", + [](MlirType type) { return mlirPDLRangeTypeGetElementType(type); }, + "Get the element type."); + + //===-------------------------------------------------------------------===// + // TypeType + //===-------------------------------------------------------------------===// + + auto typeType = mlir_type_subclass(m, "TypeType", mlirTypeIsAPDLTypeType); + typeType.def_classmethod( + "get", + [](py::object cls, MlirContext ctx) { + return cls(mlirPDLTypeTypeGet(ctx)); + }, + "Get an instance of TypeType in given context.", py::arg("cls"), + py::arg("context") = py::none()); + + //===-------------------------------------------------------------------===// + // ValueType + //===-------------------------------------------------------------------===// + + auto valueType = mlir_type_subclass(m, "ValueType", mlirTypeIsAPDLValueType); + valueType.def_classmethod( + "get", + [](py::object cls, MlirContext ctx) { + return cls(mlirPDLValueTypeGet(ctx)); + }, + "Get an instance of TypeType in given context.", py::arg("cls"), + py::arg("context") = py::none()); +} + +PYBIND11_MODULE(_mlirDialectsPDL, m) { + m.doc() = "MLIR PDL dialect."; + populateDialectPDLSubmodule(m); +} diff --git a/mlir/lib/CAPI/Dialect/PDL.cpp b/mlir/lib/CAPI/Dialect/PDL.cpp index 42b4ec24b..497b2cb1f 100644 --- a/mlir/lib/CAPI/Dialect/PDL.cpp +++ b/mlir/lib/CAPI/Dialect/PDL.cpp @@ -60,6 +60,10 @@ MlirType mlirPDLRangeTypeGet(MlirType elementType) { return wrap(pdl::RangeType::get(unwrap(elementType))); } +MlirType mlirPDLRangeTypeGetElementType(MlirType type) { + return wrap(unwrap(type).cast().getElementType()); +} + 
//===---------------------------------------------------------------------===// // TypeType //===---------------------------------------------------------------------===// diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 2a9d7a7f4..77d6b0832 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -123,6 +123,15 @@ declare_mlir_python_sources( dialects/quant.py _mlir_libs/_mlir/dialects/quant.pyi) +declare_mlir_python_sources( + MLIRPythonSources.Dialects.pdl + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + SOURCES + dialects/pdl.py + dialects/_pdl_ops_ext.py + _mlir_libs/_mlir/dialects/pdl.pyi) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" @@ -243,6 +252,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind MLIRCAPIQuant ) +declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind + MODULE_NAME _mlirDialectsPDL + ADD_TO_PARENT MLIRPythonSources.Dialects.pdl + ROOT_DIR "${PYTHON_SOURCE_DIR}" + SOURCES + DialectPDL.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPIPDL +) + declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind MODULE_NAME _mlirDialectsSparseTensor ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi new file mode 100644 index 000000000..8ec944d19 --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi @@ -0,0 +1,64 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Optional + +from mlir.ir import Type, Context + +__all__ = [ + 'PDLType', + 'AttributeType', + 'OperationType', + 'RangeType', + 'TypeType', + 'ValueType', +] + + +class PDLType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + +class AttributeType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(context: Optional[Context] = None) -> AttributeType: ... + + +class OperationType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(context: Optional[Context] = None) -> OperationType: ... + + +class RangeType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(element_type: Type) -> RangeType: ... + + @property + def element_type(self) -> Type: ... + + +class TypeType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(context: Optional[Context] = None) -> TypeType: ... + + +class ValueType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(context: Optional[Context] = None) -> ValueType: ... diff --git a/mlir/python/mlir/dialects/PDLOps.td b/mlir/python/mlir/dialects/PDLOps.td new file mode 100644 index 000000000..e4e6a83cd --- /dev/null +++ b/mlir/python/mlir/dialects/PDLOps.td @@ -0,0 +1,15 @@ +//===-- PDLOps.td - Entry point for PDLOps bind ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_PDL_OPS +#define PYTHON_BINDINGS_PDL_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/PDL/IR/PDLOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 6bb84e978..0c66593ce 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -144,7 +144,8 @@ def get_op_result_or_value( def get_op_results_or_values( - arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _Sequence[_cext.ir.Value]] + arg: _Union[_cext.ir.OpView, _cext.ir.Operation, + _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]]] ) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]: """Returns the given sequence of values or the results of the given op. @@ -157,4 +158,4 @@ def get_op_results_or_values( elif isinstance(arg, _cext.ir.Operation): return arg.results else: - return arg + return [get_op_result_or_value(element) for element in arg] diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py new file mode 100644 index 000000000..364db5385 --- /dev/null +++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py @@ -0,0 +1,284 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * + from ..dialects import pdl +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Union, Optional, Sequence, List, Mapping +from ._ods_common import get_op_result_or_value as _get_value, get_op_results_or_values as _get_values + + +def _get_int_attr(bits: int, value: Union[IntegerAttr, int]) -> IntegerAttr: + """Converts the given value to signless integer attribute of given bit width.""" + if isinstance(value, int): + ty = IntegerType.get_signless(bits) + return IntegerAttr.get(ty, value) + else: + return value + + +def _get_array_attr(attrs: Union[ArrayAttr, Sequence[Attribute]]) -> ArrayAttr: + """Converts the given value to array attribute.""" + if isinstance(attrs, ArrayAttr): + return attrs + else: + return ArrayAttr.get(list(attrs)) + + +def _get_str_array_attr(attrs: Union[ArrayAttr, Sequence[str]]) -> ArrayAttr: + """Converts the given value to string array attribute.""" + if isinstance(attrs, ArrayAttr): + return attrs + else: + return ArrayAttr.get([StringAttr.get(s) for s in attrs]) + + +def _get_str_attr(name: Union[StringAttr, str]) -> Optional[StringAttr]: + """Converts the given value to string attribute.""" + if isinstance(name, str): + return StringAttr.get(name) + else: + return name + + +def _get_type_attr(type: Union[TypeAttr, Type]) -> TypeAttr: + """Converts the given value to type attribute.""" + if isinstance(type, Type): + return TypeAttr.get(type) + else: + return type + + +class ApplyNativeConstraintOp: + """Specialization for PDL apply native constraint op class.""" + + def __init__(self, + name: Union[str, StringAttr], + args: Sequence[Union[OpView, Operation, Value]] = [], + params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, + *, + loc=None, + ip=None): + name = _get_str_attr(name) + args = _get_values(args) + params = params if params is None else 
_get_array_attr(params) + super().__init__(name, args, params, loc=loc, ip=ip) + + +class ApplyNativeRewriteOp: + """Specialization for PDL apply native rewrite op class.""" + + def __init__(self, + results: Sequence[Type], + name: Union[str, StringAttr], + args: Sequence[Union[OpView, Operation, Value]] = [], + params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, + *, + loc=None, + ip=None): + name = _get_str_attr(name) + args = _get_values(args) + params = params if params is None else _get_array_attr(params) + super().__init__(results, name, args, params, loc=loc, ip=ip) + + +class AttributeOp: + """Specialization for PDL attribute op class.""" + + def __init__(self, + type: Optional[Union[OpView, Operation, Value]] = None, + value: Optional[Attribute] = None, + *, + loc=None, + ip=None): + type = type if type is None else _get_value(type) + result = pdl.AttributeType.get() + super().__init__(result, type, value, loc=loc, ip=ip) + + +class EraseOp: + """Specialization for PDL erase op class.""" + + def __init__(self, + operation: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None): + operation = _get_value(operation) + super().__init__(operation, loc=loc, ip=ip) + + +class OperandOp: + """Specialization for PDL operand op class.""" + + def __init__(self, + type: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None): + type = type if type is None else _get_value(type) + result = pdl.ValueType.get() + super().__init__(result, type, loc=loc, ip=ip) + + +class OperandsOp: + """Specialization for PDL operands op class.""" + + def __init__(self, + types: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None): + types = types if types is None else _get_value(types) + result = pdl.RangeType.get(pdl.ValueType.get()) + super().__init__(result, types, loc=loc, ip=ip) + + +class OperationOp: + """Specialization for PDL operand op class.""" + + def __init__(self, + name: Optional[Union[str, 
StringAttr]] = None, + args: Sequence[Union[OpView, Operation, Value]] = [], + attributes: Mapping[str, Union[OpView, Operation, Value]] = {}, + types: Sequence[Union[OpView, Operation, Value]] = [], + *, + loc=None, + ip=None): + name = name if name is None else _get_str_attr(name) + args = _get_values(args) + attributeNames = [] + attributeValues = [] + for attrName, attrValue in attributes.items(): + attributeNames.append(StringAttr.get(attrName)) + attributeValues.append(_get_value(attrValue)) + attributeNames = ArrayAttr.get(attributeNames) + types = _get_values(types) + result = pdl.OperationType.get() + super().__init__(result, name, args, attributeValues, attributeNames, types, loc=loc, ip=ip) + + +class PatternOp: + """Specialization for PDL pattern op class.""" + + def __init__(self, + benefit: Union[IntegerAttr, int], + name: Optional[Union[StringAttr, str]] = None, + *, + loc=None, + ip=None): + """Creates an PDL `pattern` operation.""" + name_attr = None if name is None else _get_str_attr(name) + benefit_attr = _get_int_attr(16, benefit) + super().__init__(benefit_attr, name_attr, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def body(self): + """Return the body (block) of the pattern.""" + return self.regions[0].blocks[0] + + +class ReplaceOp: + """Specialization for PDL replace op class.""" + + def __init__(self, + op: Union[OpView, Operation, Value], + *, + with_op: Optional[Union[OpView, Operation, Value]] = None, + with_values: Sequence[Union[OpView, Operation, Value]] = [], + loc=None, + ip=None): + op = _get_value(op) + with_op = with_op if with_op is None else _get_value(with_op) + with_values = _get_values(with_values) + super().__init__(op, with_op, with_values, loc=loc, ip=ip) + + +class ResultOp: + """Specialization for PDL result op class.""" + + def __init__(self, + parent: Union[OpView, Operation, Value], + index: Union[IntegerAttr, int], + *, + loc=None, + ip=None): + index = _get_int_attr(32, index) + parent = 
_get_value(parent) + result = pdl.ValueType.get() + super().__init__(result, parent, index, loc=loc, ip=ip) + + +class ResultsOp: + """Specialization for PDL results op class.""" + + def __init__(self, + result: Type, + parent: Union[OpView, Operation, Value], + index: Optional[Union[IntegerAttr, int]] = None, + *, + loc=None, + ip=None): + parent = _get_value(parent) + index = index if index is None else _get_int_attr(32, index) + super().__init__(result, parent, index, loc=loc, ip=ip) + + +class RewriteOp: + """Specialization for PDL rewrite op class.""" + + def __init__(self, + root: Optional[Union[OpView, Operation, Value]] = None, + name: Optional[Union[StringAttr, str]] = None, + args: Sequence[Union[OpView, Operation, Value]] = [], + params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, + *, + loc=None, + ip=None): + root = root if root is None else _get_value(root) + name = name if name is None else _get_str_attr(name) + args = _get_values(args) + params = params if params is None else _get_array_attr(params) + super().__init__(root, name, args, params, loc=loc, ip=ip) + + def add_body(self): + """Add body (block) to the rewrite.""" + self.regions[0].blocks.append() + return self.body + + @property + def body(self): + """Return the body (block) of the rewrite.""" + return self.regions[0].blocks[0] + + +class TypeOp: + """Specialization for PDL type op class.""" + + def __init__(self, + type: Optional[Union[TypeAttr, Type]] = None, + *, + loc=None, + ip=None): + type = type if type is None else _get_type_attr(type) + result = pdl.TypeType.get() + super().__init__(result, type, loc=loc, ip=ip) + + +class TypesOp: + """Specialization for PDL types op class.""" + + def __init__(self, + types: Sequence[Union[TypeAttr, Type]] = [], + *, + loc=None, + ip=None): + types = _get_array_attr([_get_type_attr(ty) for ty in types]) + types = None if not types else types + result = pdl.RangeType.get(pdl.TypeType.get()) + super().__init__(result, types, loc=loc, 
ip=ip) diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py new file mode 100644 index 000000000..dda2b7d65 --- /dev/null +++ b/mlir/python/mlir/dialects/pdl.py @@ -0,0 +1,6 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._pdl_ops_gen import * +from .._mlir_libs._mlirDialectsPDL import * From 3cd3e91c33b2b0ade9e6f6485aa177b2bbcc4961 Mon Sep 17 00:00:00 2001 From: Denys Shabalin Date: Wed, 19 Jan 2022 13:42:14 +0100 Subject: [PATCH 220/915] [mlir] Fix PDL python bindings build Fixes incorrect build definition for the bindings for the PDL dialect. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D117657 --- mlir/python/CMakeLists.txt | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 77d6b0832..59ddd83fb 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -123,14 +123,15 @@ declare_mlir_python_sources( dialects/quant.py _mlir_libs/_mlir/dialects/quant.pyi) -declare_mlir_python_sources( - MLIRPythonSources.Dialects.pdl +declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/PDLOps.td SOURCES dialects/pdl.py dialects/_pdl_ops_ext.py - _mlir_libs/_mlir/dialects/pdl.pyi) + _mlir_libs/_mlir/dialects/pdl.pyi + DIALECT_NAME pdl) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects From dde381b594b00c8cc30ae9d9767489fc7f162748 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Wed, 19 Jan 2022 12:21:42 +0100 Subject: [PATCH 221/915] [mlir] Rework subclass construction in PybindAdaptors.h The constructor function was being defined without indicating its "__init__" name, which made it interpret it as a regular fuction rather than a constructor. 
When overload resolution failed, Pybind would attempt to print the arguments actually passed to the function, including "self", which is not initialized since the constructor couldn't be called. This would result in "__repr__" being called with "self" referencing an uninitialized MLIR C API object, which in turn would cause undefined behavior when attempting to print in C++. Even if the correct name is provided, the mechanism used by PybindAdaptors.h to bind constructors directly as "__init__" functions taking "self" is deprecated by Pybind. The new mechanism does not seem to have access to a fully-constructed "self" object (i.e., the constructor in C++ takes a `pybind11::detail::value_and_holder` that cannot be forwarded back to Python). Instead, redefine "__new__" to perform the required checks (there are no additional initialization needed for attributes and types as they are all wrappers around a C++ pointer). "__new__" can call its equivalent on a superclass without needing "self". Bump pybind11 dependency to 3.8.0, which is the first version that allows one to redefine "__new__". Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D117646 --- .../mlir/Bindings/Python/PybindAdaptors.h | 65 ++++++++++--------- mlir/python/mlir/dialects/python_test.py | 2 +- mlir/python/requirements.txt | 3 +- 3 files changed, 38 insertions(+), 32 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 0340e9cc4..73cc7e441 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -314,31 +314,34 @@ class mlir_attribute_subclass : public pure_subclass { /// as the mlir.ir class (otherwise, it will trigger a recursive /// initialization). 
mlir_attribute_subclass(py::handle scope, const char *typeClassName, - IsAFunctionTy isaFunction, - const py::object &superClass) - : pure_subclass(scope, typeClassName, superClass) { - // Casting constructor. Note that defining an __init__ method is special - // and not yet generalized on pure_subclass (it requires a somewhat - // different cpp_function and other requirements on chaining to super - // __init__ make it more awkward to do generally). + IsAFunctionTy isaFunction, const py::object &superCls) + : pure_subclass(scope, typeClassName, superCls) { + // Casting constructor. Note that it hard, if not impossible, to properly + // call chain to parent `__init__` in pybind11 due to its special handling + // for init functions that don't have a fully constructed self-reference, + // which makes it impossible to forward it to `__init__` of a superclass. + // Instead, provide a custom `__new__` and call that of a superclass, which + // eventually calls `__init__` of the superclass. Since attribute subclasses + // have no additional members, we can just return the instance thus created + // without amending it. std::string captureTypeName( typeClassName); // As string in case if typeClassName is not static. 
- py::cpp_function initCf( - [superClass, isaFunction, captureTypeName](py::object self, - py::object otherType) { - MlirAttribute rawAttribute = py::cast(otherType); + py::cpp_function newCf( + [superCls, isaFunction, captureTypeName](py::object cls, + py::object otherAttribute) { + MlirAttribute rawAttribute = py::cast(otherAttribute); if (!isaFunction(rawAttribute)) { - auto origRepr = py::repr(otherType).cast(); + auto origRepr = py::repr(otherAttribute).cast(); throw std::invalid_argument( (llvm::Twine("Cannot cast attribute to ") + captureTypeName + " (from " + origRepr + ")") .str()); } - superClass.attr("__init__")(self, otherType); + py::object self = superCls.attr("__new__")(cls, otherAttribute); + return self; }, - py::arg("cast_from_type"), py::is_method(py::none()), - "Casts the passed type to this specific sub-type."); - thisClass.attr("__init__") = initCf; + py::name("__new__"), py::arg("cls"), py::arg("cast_from_attr")); + thisClass.attr("__new__") = newCf; // 'isinstance' method. def_staticmethod( @@ -366,17 +369,21 @@ class mlir_type_subclass : public pure_subclass { /// as the mlir.ir class (otherwise, it will trigger a recursive /// initialization). mlir_type_subclass(py::handle scope, const char *typeClassName, - IsAFunctionTy isaFunction, const py::object &superClass) - : pure_subclass(scope, typeClassName, superClass) { - // Casting constructor. Note that defining an __init__ method is special - // and not yet generalized on pure_subclass (it requires a somewhat - // different cpp_function and other requirements on chaining to super - // __init__ make it more awkward to do generally). + IsAFunctionTy isaFunction, const py::object &superCls) + : pure_subclass(scope, typeClassName, superCls) { + // Casting constructor. 
Note that it hard, if not impossible, to properly + // call chain to parent `__init__` in pybind11 due to its special handling + // for init functions that don't have a fully constructed self-reference, + // which makes it impossible to forward it to `__init__` of a superclass. + // Instead, provide a custom `__new__` and call that of a superclass, which + // eventually calls `__init__` of the superclass. Since attribute subclasses + // have no additional members, we can just return the instance thus created + // without amending it. std::string captureTypeName( typeClassName); // As string in case if typeClassName is not static. - py::cpp_function initCf( - [superClass, isaFunction, captureTypeName](py::object self, - py::object otherType) { + py::cpp_function newCf( + [superCls, isaFunction, captureTypeName](py::object cls, + py::object otherType) { MlirType rawType = py::cast(otherType); if (!isaFunction(rawType)) { auto origRepr = py::repr(otherType).cast(); @@ -385,11 +392,11 @@ class mlir_type_subclass : public pure_subclass { origRepr + ")") .str()); } - superClass.attr("__init__")(self, otherType); + py::object self = superCls.attr("__new__")(cls, otherType); + return self; }, - py::arg("cast_from_type"), py::is_method(py::none()), - "Casts the passed type to this specific sub-type."); - thisClass.attr("__init__") = initCf; + py::name("__new__"), py::arg("cls"), py::arg("cast_from_type")); + thisClass.attr("__new__") = newCf; // 'isinstance' method. 
def_staticmethod( diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py index 82c01d5a0..9f560c205 100644 --- a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._python_test_ops_gen import * - +from .._mlir_libs._mlirPythonTest import TestAttr, TestType def register_python_test_dialect(context, load=True): from .._mlir_libs import _mlirPythonTest diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index f76dcf676..0cc86af2c 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,4 +1,3 @@ numpy -# Version 2.7.0 excluded: https://github.com/pybind/pybind11/issues/3136 -pybind11>=2.6.0,!=2.7.0 +pybind11>=2.8.0 PyYAML From c4ad3d01763c0527e15d7bf92ed8ccb72630ffea Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 18 Jan 2022 18:28:51 -0800 Subject: [PATCH 222/915] [mlir] Make locations required when adding/creating block arguments BlockArguments gained the ability to have locations attached a while ago, but they have always been optional. This goes against the core tenant of MLIR where location information is a requirement, so this commit updates the API to require locations. Fixes #53279 Differential Revision: https://reviews.llvm.org/D117633 --- mlir/include/mlir-c/IR.h | 6 +++-- mlir/lib/Bindings/Python/IRCore.cpp | 34 ++++++++++++++++++++++++----- mlir/lib/CAPI/Dialect/Linalg.cpp | 7 ++++-- mlir/lib/CAPI/IR/IR.cpp | 10 +++++---- 4 files changed, 44 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 1d884e634..e5c20ae70 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -531,7 +531,8 @@ MLIR_CAPI_EXPORTED MlirRegion mlirRegionGetNextInOperation(MlirRegion region); /// Creates a new empty block with the given argument types and transfers /// ownership to the caller. 
MLIR_CAPI_EXPORTED MlirBlock mlirBlockCreate(intptr_t nArgs, - MlirType const *args); + MlirType const *args, + MlirLocation const *locs); /// Takes a block owned by the caller and destroys it. MLIR_CAPI_EXPORTED void mlirBlockDestroy(MlirBlock block); @@ -590,7 +591,8 @@ MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumArguments(MlirBlock block); /// Appends an argument of the specified type to the block. Returns the newly /// added argument. MLIR_CAPI_EXPORTED MlirValue mlirBlockAddArgument(MlirBlock block, - MlirType type); + MlirType type, + MlirLocation loc); /// Returns `pos`-th argument of the block. MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 729431b98..621c09502 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -325,12 +325,18 @@ class PyBlockList { PyBlock appendBlock(const py::args &pyArgTypes) { operation->checkValid(); llvm::SmallVector argTypes; + llvm::SmallVector argLocs; argTypes.reserve(pyArgTypes.size()); + argLocs.reserve(pyArgTypes.size()); for (auto &pyArg : pyArgTypes) { argTypes.push_back(pyArg.cast()); + // TODO: Pass in a proper location here. + argLocs.push_back( + mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); } - MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); + MlirBlock block = + mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); mlirRegionAppendOwnedBlock(region, block); return PyBlock(operation, block); } @@ -2717,12 +2723,18 @@ void mlir::python::populateIRCore(py::module &m) { [](PyRegion &parent, py::list pyArgTypes) { parent.checkValid(); llvm::SmallVector argTypes; + llvm::SmallVector argLocs; argTypes.reserve(pyArgTypes.size()); + argLocs.reserve(pyArgTypes.size()); for (auto &pyArg : pyArgTypes) { argTypes.push_back(pyArg.cast()); + // TODO: Pass in a proper location here. 
+ argLocs.push_back( + mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); } - MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); + MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), + argLocs.data()); mlirRegionInsertOwnedBlock(parent, 0, block); return PyBlock(parent.getParentOperation(), block); }, @@ -2734,12 +2746,18 @@ void mlir::python::populateIRCore(py::module &m) { [](PyBlock &self, py::args pyArgTypes) { self.checkValid(); llvm::SmallVector argTypes; + llvm::SmallVector argLocs; argTypes.reserve(pyArgTypes.size()); + argLocs.reserve(pyArgTypes.size()); for (auto &pyArg : pyArgTypes) { argTypes.push_back(pyArg.cast()); + // TODO: Pass in a proper location here. + argLocs.push_back( + mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); } - MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); + MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), + argLocs.data()); MlirRegion region = mlirBlockGetParentRegion(self.get()); mlirRegionInsertOwnedBlockBefore(region, self.get(), block); return PyBlock(self.getParentOperation(), block); @@ -2751,12 +2769,18 @@ void mlir::python::populateIRCore(py::module &m) { [](PyBlock &self, py::args pyArgTypes) { self.checkValid(); llvm::SmallVector argTypes; + llvm::SmallVector argLocs; argTypes.reserve(pyArgTypes.size()); + argLocs.reserve(pyArgTypes.size()); for (auto &pyArg : pyArgTypes) { argTypes.push_back(pyArg.cast()); - } - MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data()); + // TODO: Pass in a proper location here. 
+ argLocs.push_back( + mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); + } + MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), + argLocs.data()); MlirRegion region = mlirBlockGetParentRegion(self.get()); mlirRegionInsertOwnedBlockAfter(region, self.get(), block); return PyBlock(self.getParentOperation(), block); diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 6c5ba9a88..8862b6b15 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -28,12 +28,15 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) { "Expected Linalg op with 0 blocks"); SmallVector argTypes; - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) + SmallVector argLocs; + for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { argTypes.push_back(getElementTypeOrSelf(opOperand->get().getType())); + argLocs.push_back(opOperand->get().getLoc()); + } ImplicitLocOpBuilder b(op->getLoc(), op->getContext()); Region ®ion = op->getRegion(0); - Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes); + Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes, argLocs); b.setInsertionPointToStart(body); fun(b, *body); } diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 13b7673e2..9b60b11fe 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -535,10 +535,11 @@ void mlirRegionDestroy(MlirRegion region) { // Block API. 
//===----------------------------------------------------------------------===// -MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args) { +MlirBlock mlirBlockCreate(intptr_t nArgs, MlirType const *args, + MlirLocation const *locs) { Block *b = new Block; for (intptr_t i = 0; i < nArgs; ++i) - b->addArgument(unwrap(args[i])); + b->addArgument(unwrap(args[i]), unwrap(locs[i])); return wrap(b); } @@ -618,8 +619,9 @@ intptr_t mlirBlockGetNumArguments(MlirBlock block) { return static_cast(unwrap(block)->getNumArguments()); } -MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type) { - return wrap(unwrap(block)->addArgument(unwrap(type))); +MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, + MlirLocation loc) { + return wrap(unwrap(block)->addArgument(unwrap(type), unwrap(loc))); } MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { From 6db06b76e169e6fcc2bd793830266ff7d1a1e3e0 Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Fri, 21 Jan 2022 05:21:00 +0000 Subject: [PATCH 223/915] [mlir][python] 8b/16b DenseIntElements access This extends dense attribute element access to support 8b and 16b ints. Also extends the corresponding parts of the C api. 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D117731 --- mlir/include/mlir-c/BuiltinAttributes.h | 8 ++++++++ mlir/lib/Bindings/Python/IRAttributes.cpp | 12 ++++++++++++ mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 16 ++++++++++++++++ 3 files changed, 36 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 5839cd3d2..973b7e994 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -355,6 +355,10 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt8Get( MlirType shapedType, intptr_t numElements, const uint8_t *elements); MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt8Get( MlirType shapedType, intptr_t numElements, const int8_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt16Get( + MlirType shapedType, intptr_t numElements, const uint16_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt16Get( + MlirType shapedType, intptr_t numElements, const int16_t *elements); MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrUInt32Get( MlirType shapedType, intptr_t numElements, const uint32_t *elements); MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrInt32Get( @@ -416,6 +420,10 @@ MLIR_CAPI_EXPORTED int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED int16_t +mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED uint16_t +mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED uint32_t diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index fd44ffe6b..5d87641c3 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ 
b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -673,6 +673,12 @@ class PyDenseIntElementsAttribute if (width == 1) { return mlirDenseElementsAttrGetBoolValue(*this, pos); } + if (width == 8) { + return mlirDenseElementsAttrGetUInt8Value(*this, pos); + } + if (width == 16) { + return mlirDenseElementsAttrGetUInt16Value(*this, pos); + } if (width == 32) { return mlirDenseElementsAttrGetUInt32Value(*this, pos); } @@ -683,6 +689,12 @@ class PyDenseIntElementsAttribute if (width == 1) { return mlirDenseElementsAttrGetBoolValue(*this, pos); } + if (width == 8) { + return mlirDenseElementsAttrGetInt8Value(*this, pos); + } + if (width == 16) { + return mlirDenseElementsAttrGetInt16Value(*this, pos); + } if (width == 32) { return mlirDenseElementsAttrGetInt32Value(*this, pos); } diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index c20548bd4..7b718da88 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -426,6 +426,16 @@ MlirAttribute mlirDenseElementsAttrInt8Get(MlirType shapedType, const int8_t *elements) { return getDenseAttribute(shapedType, numElements, elements); } +MlirAttribute mlirDenseElementsAttrUInt16Get(MlirType shapedType, + intptr_t numElements, + const uint16_t *elements) { + return getDenseAttribute(shapedType, numElements, elements); +} +MlirAttribute mlirDenseElementsAttrInt16Get(MlirType shapedType, + intptr_t numElements, + const int16_t *elements) { + return getDenseAttribute(shapedType, numElements, elements); +} MlirAttribute mlirDenseElementsAttrUInt32Get(MlirType shapedType, intptr_t numElements, const uint32_t *elements) { @@ -530,6 +540,12 @@ int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) { uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) { return unwrap(attr).cast().getValues()[pos]; } +int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) { + return 
unwrap(attr).cast().getValues()[pos]; +} +uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) { + return unwrap(attr).cast().getValues()[pos]; +} int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) { return unwrap(attr).cast().getValues()[pos]; } From c27f53b272bbd10f1c13ec2d71f97bcf7b2ebfe6 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Thu, 13 Jan 2022 16:27:28 -0800 Subject: [PATCH 224/915] Upstream MLIR PyTACO implementation. Add TACO tests to test/Integration/Dialect/SparseTensor/taco. Add the MLIR PyTACO implementation as tools under the directory. Reviewed By: aartbik, mehdi_amini Differential Revision: https://reviews.llvm.org/D117260 --- mlir/python/requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index 0cc86af2c..991e8eb24 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,3 +1,4 @@ numpy pybind11>=2.8.0 PyYAML +dataclasses From 0d7c67ff136f9183a7b730e66bfe4c8879cf4504 Mon Sep 17 00:00:00 2001 From: Denys Shabalin Date: Mon, 24 Jan 2022 12:44:47 +0100 Subject: [PATCH 225/915] [mlir] Fix broken __repr__ implementation in Linalg OpDSL Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D118027 --- mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index ddbebb29f..f6f3e0144 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -468,7 +468,7 @@ def visit_tensor_exprs(self, callback): self.arg.visit_tensor_exprs(callback) def __repr__(self): - return f"{repr(self.type_fn)}({type_var}, {self.arg})" + return f"{repr(self.type_fn)}({self.type_var}, {self.arg})" class const(TensorExpression): From 
c8d2717771e23659b4adf76422a6a650d1965ea8 Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Sat, 29 Jan 2022 18:41:10 -0800 Subject: [PATCH 226/915] Replace OwningModuleRef with OwningOpRef This addresses a TODO in BuiltinOps.h. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D118574 --- mlir/lib/CAPI/IR/IR.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 9b60b11fe..26af65ec7 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -177,7 +177,8 @@ MlirModule mlirModuleCreateEmpty(MlirLocation location) { } MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) { - OwningModuleRef owning = parseSourceString(unwrap(module), unwrap(context)); + OwningOpRef owning = + parseSourceString(unwrap(module), unwrap(context)); if (!owning) return MlirModule{nullptr}; return MlirModule{owning.release().getOperation()}; @@ -192,8 +193,9 @@ MlirBlock mlirModuleGetBody(MlirModule module) { } void mlirModuleDestroy(MlirModule module) { - // Transfer ownership to an OwningModuleRef so that its destructor is called. - OwningModuleRef(unwrap(module)); + // Transfer ownership to an OwningOpRef so that its destructor is + // called. + OwningOpRef(unwrap(module)); } MlirOperation mlirModuleGetOperation(MlirModule module) { From 8651d7e15fafa94d9aaf3001023624c6e345ad77 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 31 Jan 2022 19:10:51 +0900 Subject: [PATCH 227/915] [mlir][vector][NFC] Split into IR, Transforms and Utils This reduces the dependencies of the MLIRVector target and makes the dialect consistent with other dialects. 
Differential Revision: https://reviews.llvm.org/D118533 --- mlir/python/mlir/dialects/VectorOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/VectorOps.td b/mlir/python/mlir/dialects/VectorOps.td index b06668bdf..267c2b2a0 100644 --- a/mlir/python/mlir/dialects/VectorOps.td +++ b/mlir/python/mlir/dialects/VectorOps.td @@ -10,6 +10,6 @@ #define PYTHON_BINDINGS_VECTOR_OPS include "mlir/Bindings/Python/Attributes.td" -include "mlir/Dialect/Vector/VectorOps.td" +include "mlir/Dialect/Vector/IR/VectorOps.td" #endif From 98392663c13dbcd705decc80a16a5ecb659de226 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Wed, 19 Jan 2022 13:42:49 +0100 Subject: [PATCH 228/915] [mlir] Better error message in PybindAdaptors.h When attempting to cast a pybind11 handle to an MLIR C API object through capsules, the binding code would attempt to directly access the "_CAPIPtr" attribute on the object, leading to a rather obscure AttributeError when the attribute was missing, e.g., on non-MLIR types. Check for its presence and throw a TypeError instead. Depends On D117646 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D117658 --- mlir/include/mlir/Bindings/Python/PybindAdaptors.h | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 73cc7e441..9d5a512a4 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -42,13 +42,19 @@ struct type_caster> : optional_caster> {}; /// an explicit Capsule (which can happen when two C APIs are communicating /// directly via Python) or indirectly by querying the MLIR_PYTHON_CAPI_PTR_ATTR /// attribute (through which supported MLIR Python API objects export their -/// contained API pointer as a capsule). 
This is intended to be used from -/// type casters, which are invoked with a raw handle (unowned). The returned -/// object's lifetime may not extend beyond the apiObject handle without -/// explicitly having its refcount increased (i.e. on return). +/// contained API pointer as a capsule). Throws a type error if the object is +/// neither. This is intended to be used from type casters, which are invoked +/// with a raw handle (unowned). The returned object's lifetime may not extend +/// beyond the apiObject handle without explicitly having its refcount increased +/// (i.e. on return). static py::object mlirApiObjectToCapsule(py::handle apiObject) { if (PyCapsule_CheckExact(apiObject.ptr())) return py::reinterpret_borrow(apiObject); + if (!py::hasattr(apiObject, MLIR_PYTHON_CAPI_PTR_ATTR)) { + auto repr = py::repr(apiObject).cast(); + throw py::type_error( + (llvm::Twine("Expected an MLIR object (got ") + repr + ").").str()); + } return apiObject.attr(MLIR_PYTHON_CAPI_PTR_ATTR); } From 192429a5df70def9a38ac623319f60b6e3710301 Mon Sep 17 00:00:00 2001 From: Daniel Resnick Date: Wed, 26 Jan 2022 17:13:24 -0700 Subject: [PATCH 229/915] [mlir][capi] Add DialectRegistry to MLIR C-API Exposes mlir::DialectRegistry to the C API as MlirDialectRegistry along with helper functions. A hook has been added to MlirDialectHandle that inserts the dialect into a registry. A future possible change is removing mlirDialectHandleRegisterDialect in favor of using mlirDialectHandleInsertDialect, which it is now implemented with. 
Differential Revision: https://reviews.llvm.org/D118293 --- mlir/include/mlir-c/IR.h | 22 ++++++++++++++++++++++ mlir/include/mlir-c/Registration.h | 5 +++++ mlir/include/mlir/CAPI/IR.h | 1 + mlir/include/mlir/CAPI/Registration.h | 16 ++++++++-------- mlir/lib/CAPI/IR/DialectHandle.cpp | 9 ++++++++- mlir/lib/CAPI/IR/IR.cpp | 17 +++++++++++++++++ 6 files changed, 61 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index e5c20ae70..d99955466 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -50,6 +50,7 @@ extern "C" { DEFINE_C_API_STRUCT(MlirContext, void); DEFINE_C_API_STRUCT(MlirDialect, void); +DEFINE_C_API_STRUCT(MlirDialectRegistry, void); DEFINE_C_API_STRUCT(MlirOperation, void); DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void); DEFINE_C_API_STRUCT(MlirBlock, void); @@ -108,6 +109,11 @@ mlirContextGetAllowUnregisteredDialects(MlirContext context); MLIR_CAPI_EXPORTED intptr_t mlirContextGetNumRegisteredDialects(MlirContext context); +/// Append the contents of the given dialect registry to the registry associated +/// with the context. +MLIR_CAPI_EXPORTED void +mlirContextAppendDialectRegistry(MlirContext ctx, MlirDialectRegistry registry); + /// Returns the number of dialects loaded by the context. MLIR_CAPI_EXPORTED intptr_t @@ -152,6 +158,22 @@ MLIR_CAPI_EXPORTED bool mlirDialectEqual(MlirDialect dialect1, /// Returns the namespace of the given dialect. MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect); +//===----------------------------------------------------------------------===// +// DialectRegistry API. +//===----------------------------------------------------------------------===// + +/// Creates a dialect registry and transfers its ownership to the caller. +MLIR_CAPI_EXPORTED MlirDialectRegistry mlirDialectRegistryCreate(); + +/// Checks if the dialect registry is null. 
+static inline bool mlirDialectRegistryIsNull(MlirDialectRegistry registry) { + return !registry.ptr; +} + +/// Takes a dialect registry owned by the caller and destroys it. +MLIR_CAPI_EXPORTED void +mlirDialectRegistryDestroy(MlirDialectRegistry registry); + //===----------------------------------------------------------------------===// // Location API. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir-c/Registration.h b/mlir/include/mlir-c/Registration.h index e8604329f..442449626 100644 --- a/mlir/include/mlir-c/Registration.h +++ b/mlir/include/mlir-c/Registration.h @@ -44,6 +44,11 @@ typedef struct MlirDialectHandle MlirDialectHandle; MLIR_CAPI_EXPORTED MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle); +/// Inserts the dialect associated with the provided dialect handle into the +/// provided dialect registry +MLIR_CAPI_EXPORTED void mlirDialectHandleInsertDialect(MlirDialectHandle, + MlirDialectRegistry); + /// Registers the dialect associated with the provided dialect handle. 
MLIR_CAPI_EXPORTED void mlirDialectHandleRegisterDialect(MlirDialectHandle, MlirContext); diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h index af7ae8977..06cf7762a 100644 --- a/mlir/include/mlir/CAPI/IR.h +++ b/mlir/include/mlir/CAPI/IR.h @@ -22,6 +22,7 @@ DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext) DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect) +DEFINE_C_API_PTR_METHODS(MlirDialectRegistry, mlir::DialectRegistry) DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation) DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block) DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags) diff --git a/mlir/include/mlir/CAPI/Registration.h b/mlir/include/mlir/CAPI/Registration.h index ac909d1dd..e57023d30 100644 --- a/mlir/include/mlir/CAPI/Registration.h +++ b/mlir/include/mlir/CAPI/Registration.h @@ -21,23 +21,23 @@ //===----------------------------------------------------------------------===// /// Hooks for dynamic discovery of dialects. -typedef void (*MlirContextRegisterDialectHook)(MlirContext context); +typedef void (*MlirDialectRegistryInsertDialectHook)( + MlirDialectRegistry registry); typedef MlirDialect (*MlirContextLoadDialectHook)(MlirContext context); typedef MlirStringRef (*MlirDialectGetNamespaceHook)(); /// Structure of dialect registration hooks. 
struct MlirDialectRegistrationHooks { - MlirContextRegisterDialectHook registerHook; + MlirDialectRegistryInsertDialectHook insertHook; MlirContextLoadDialectHook loadHook; MlirDialectGetNamespaceHook getNamespaceHook; }; typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks; #define MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Name, Namespace, ClassName) \ - static void mlirContextRegister##Name##Dialect(MlirContext context) { \ - mlir::DialectRegistry registry; \ - registry.insert(); \ - unwrap(context)->appendDialectRegistry(registry); \ + static void mlirDialectRegistryInsert##Name##Dialect( \ + MlirDialectRegistry registry) { \ + unwrap(registry)->insert(); \ } \ static MlirDialect mlirContextLoad##Name##Dialect(MlirContext context) { \ return wrap(unwrap(context)->getOrLoadDialect()); \ @@ -47,8 +47,8 @@ typedef struct MlirDialectRegistrationHooks MlirDialectRegistrationHooks; } \ MlirDialectHandle mlirGetDialectHandle__##Namespace##__() { \ static MlirDialectRegistrationHooks hooks = { \ - mlirContextRegister##Name##Dialect, mlirContextLoad##Name##Dialect, \ - mlir##Name##DialectGetNamespace}; \ + mlirDialectRegistryInsert##Name##Dialect, \ + mlirContextLoad##Name##Dialect, mlir##Name##DialectGetNamespace}; \ return MlirDialectHandle{&hooks}; \ } diff --git a/mlir/lib/CAPI/IR/DialectHandle.cpp b/mlir/lib/CAPI/IR/DialectHandle.cpp index fb972316e..19f64d948 100644 --- a/mlir/lib/CAPI/IR/DialectHandle.cpp +++ b/mlir/lib/CAPI/IR/DialectHandle.cpp @@ -17,9 +17,16 @@ MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle handle) { return unwrap(handle)->getNamespaceHook(); } +void mlirDialectHandleInsertDialect(MlirDialectHandle handle, + MlirDialectRegistry registry) { + unwrap(handle)->insertHook(registry); +} + void mlirDialectHandleRegisterDialect(MlirDialectHandle handle, MlirContext ctx) { - unwrap(handle)->registerHook(ctx); + mlir::DialectRegistry registry; + mlirDialectHandleInsertDialect(handle, wrap(®istry)); + 
unwrap(ctx)->appendDialectRegistry(registry); } MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle handle, diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 26af65ec7..c067b202b 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -53,6 +53,11 @@ intptr_t mlirContextGetNumRegisteredDialects(MlirContext context) { return static_cast(unwrap(context)->getAvailableDialects().size()); } +void mlirContextAppendDialectRegistry(MlirContext ctx, + MlirDialectRegistry registry) { + unwrap(ctx)->appendDialectRegistry(*unwrap(registry)); +} + // TODO: expose a cheaper way than constructing + sorting a vector only to take // its size. intptr_t mlirContextGetNumLoadedDialects(MlirContext context) { @@ -88,6 +93,18 @@ MlirStringRef mlirDialectGetNamespace(MlirDialect dialect) { return wrap(unwrap(dialect)->getNamespace()); } +//===----------------------------------------------------------------------===// +// DialectRegistry API. +//===----------------------------------------------------------------------===// + +MlirDialectRegistry mlirDialectRegistryCreate() { + return wrap(new DialectRegistry()); +} + +void mlirDialectRegistryDestroy(MlirDialectRegistry registry) { + delete unwrap(registry); +} + //===----------------------------------------------------------------------===// // Printing flags API. //===----------------------------------------------------------------------===// From cfdfee2692cf129b361546e7451817f931495ed9 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sat, 5 Feb 2022 21:50:27 -0800 Subject: [PATCH 230/915] [mlir] Fixup python bindings after splitting cf ops from std. 
--- mlir/python/CMakeLists.txt | 8 ++++++++ mlir/python/mlir/dialects/ControlFlowOps.td | 15 +++++++++++++++ mlir/python/mlir/dialects/cf.py | 5 +++++ 3 files changed, 28 insertions(+) create mode 100644 mlir/python/mlir/dialects/ControlFlowOps.td create mode 100644 mlir/python/mlir/dialects/cf.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 59ddd83fb..bf379e9b2 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -72,6 +72,14 @@ declare_mlir_dialect_python_bindings( dialects/_builtin_ops_ext.py DIALECT_NAME builtin) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/ControlFlowOps.td + SOURCES + dialects/cf.py + DIALECT_NAME cf) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/ControlFlowOps.td b/mlir/python/mlir/dialects/ControlFlowOps.td new file mode 100644 index 000000000..1bb4d41f2 --- /dev/null +++ b/mlir/python/mlir/dialects/ControlFlowOps.td @@ -0,0 +1,15 @@ +//===-- ControlFlowOps.td - Python ControlFlowOps bindings -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_CONTROL_FLOW_OPS +#define PYTHON_BINDINGS_CONTROL_FLOW_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/cf.py b/mlir/python/mlir/dialects/cf.py new file mode 100644 index 000000000..c2e357a8e --- /dev/null +++ b/mlir/python/mlir/dialects/cf.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+from ._cf_ops_gen import *
From 73be0a6fc1c083dac8c26504081967f2bbfe644e Mon Sep 17 00:00:00 2001
From: gysit
Date: Fri, 11 Feb 2022 08:20:37 +0000
Subject: [PATCH 231/915] [mlir][OpDSL] Add support for basic rank
 polymorphism.

Previously, OpDSL did not support rank polymorphism, which required a separate
implementation of linalg.fill.

This revision extends OpDSL to support rank polymorphism for a limited class of
operations that access only scalars and tensors of rank zero. At operation
instantiation time, it scales these scalar computations to multi-dimensional
pointwise computations by replacing the empty indexing maps with identity index
maps. The revision does not change the DSL itself, instead it adapts the Python
emitter and the YAML generator to generate different indexing maps and
iterators depending on the rank of the first output.

Additionally, the revision introduces a `linalg.fill_tensor` operation that in
a future revision shall replace the current handwritten `linalg.fill`
operation. `linalg.fill_tensor` is thus only temporarily available and will be
renamed to `linalg.fill`.
Reviewed By: nicolasvasilache, stellaraccident

Differential Revision: https://reviews.llvm.org/D119003
---
 .../dialects/linalg/opdsl/lang/emitter.py     | 32 +++++++++++++------
 .../linalg/opdsl/ops/core_named_ops.py        | 11 +++++++
 2 files changed, 34 insertions(+), 9 deletions(-)

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
index 22568c8b6..643bcaa5c 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
@@ -14,6 +14,7 @@
 from .scalar_expr import *
 from .config import *
+from .comprehension import *
 import numpy as np

 __all__ = [
@@ -132,6 +133,25 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
   indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \
       prepare_common_structured_op(op_config, *ins, outs = outs, **attrs)

+  # An operation that accesses only scalars and scalar/rank zero tensors is
+  # rank polymorphic. We implement rank polymorphism by generating different
+  # indexing maps and iterators that match the rank of the first output tensor.
+  # An operation is rank polymorphic if the iteration domain has rank zero.
+ if not iterator_types_attr: + rank = ShapedType(outs[0].type).rank + iterator_types_attr = ArrayAttr.get([StringAttr.get("parallel")] * rank) + scalar_map = AffineMap.get(rank, 0, []) + tensor_map = AffineMap.get_identity(rank) + indexing_maps = [] + for arg_def in all_arg_defs: + if arg_def.operand_def.kind == OperandKind.Scalar: + indexing_maps.append(scalar_map) + if (arg_def.operand_def.kind == OperandKind.InputTensor or + arg_def.operand_def.kind == OperandKind.OutputTensor): + indexing_maps.append(tensor_map) + indexing_maps_attr = ArrayAttr.get( + [AffineMapAttr.get(am) for am in indexing_maps]) + generic_op = linalg.GenericOp( result_tensors=result_types, inputs=ins, @@ -172,19 +192,13 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str, raise NotImplementedError( f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}") + # Set the index attributes used to compute the indexing maps. named_op = getattr(linalg, op_class_name)(ins, outs, result_types) - linalg.fill_builtin_region(named_op.operation) - # Note: mlir-linalg-ods-yaml-gen.cpp uses a special linalg.memoized_indexing_maps - # attribute that the non-yaml path does not. The non-yaml path hardcodes the - # indexing_maps in C++ directly. - named_op.operation.attributes[ - "linalg.memoized_indexing_maps"] = indexing_maps_attr - # iterator_types are hardcoded in C++ both in the yaml and non-yaml path. - - # Additionally set all named attributes. 
for name, value in index_attributes.items(): named_op.operation.attributes[name] = value + linalg.fill_builtin_region(named_op.operation) + if len(result_types) == 1: return named_op.result else: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index d3651bd76..80a8fb6cc 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -627,6 +627,17 @@ def pooling_ndhwc_min( D.ow * S.SW + D.kw * S.DW, D.c])) +@linalg_structured_op +def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)): + """Fills the output tensor with the given value. + + Works for arbitrary ranked output tensors since the operation performs scalar + accesses only and is thus rank polymorphic. Numeric casting is performed on + the value operand, promoting it to the same data type as the output. + """ + O[None] = TypeFn.cast(U, value) + + @linalg_structured_op def fill_rng_2d( min=ScalarDef(F64), From 517013dd97fe1a5182b81d38391708ac3bbcaa0c Mon Sep 17 00:00:00 2001 From: gysit Date: Mon, 14 Feb 2022 10:55:23 +0000 Subject: [PATCH 232/915] [mlir][OpDSL] Consistently use the term op_def (NFC). ... and remove unused type aliases. 
Depends On D119003 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D119125 --- .../mlir/dialects/linalg/opdsl/dump_oplib.py | 4 ++-- .../mlir/dialects/linalg/opdsl/lang/affine.py | 3 --- .../linalg/opdsl/lang/comprehension.py | 3 --- .../mlir/dialects/linalg/opdsl/lang/config.py | 10 ++++---- .../mlir/dialects/linalg/opdsl/lang/dsl.py | 24 +++++++++---------- .../dialects/linalg/opdsl/lang/emitter.py | 1 + 6 files changed, 20 insertions(+), 25 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py b/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py index bacc0c302..5a695d621 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py @@ -73,10 +73,10 @@ def main(args): # TODO: This class layering is awkward. if isinstance(value, DefinedOpCallable): try: - linalg_config = LinalgOpConfig.from_linalg_op_def(value.model) + linalg_config = LinalgOpConfig.from_linalg_op_def(value.op_def) except Exception as e: raise ValueError( - f"Could not create LinalgOpConfig from {value.model}") from e + f"Could not create LinalgOpConfig from {value.op_def}") from e configs.extend(linalg_config) # Print. diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py index 9c1bb3342..038f06834 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py @@ -64,9 +64,6 @@ "SymbolDef", ] -# Type aliases. -SymbolPosMap = Dict[str, int] - class AffineBuildState: """Internal state for the AffineExprDef._create impls. 
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index f6f3e0144..ea25d85aa 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -17,9 +17,6 @@ from .types import * from .yaml_helper import * -# Type aliases. -AffineDimList = Dict[str, _ir.AffineExpr] - class TensorExpression: """An expression that can appear on the RHS of a comprehension.""" diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py index fec41decb..59a10998e 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -421,18 +421,18 @@ def to_yaml_custom_dict(self): @staticmethod def from_linalg_op_def( - tc_op_def: LinalgOpDef, + op_def: LinalgOpDef, context: Optional[_ir.Context] = None) -> Sequence["LinalgOpConfig"]: """Expands a LinalgOpDef into corresponding Linalg configured ops.""" # TODO: Many LinalgOpDef patterns need to expand to multiple generics. 
assert len( - tc_op_def.comprehensions) == 1, "Only one comprehension supported" + op_def.comprehensions) == 1, "Only one comprehension supported" return [ LinalgOpConfig( - tc_op_def.metadata, + op_def.metadata, structured_op=LinalgStructuredOpConfig( - tc_op_def.comprehensions[0], tc_op_def.domain, - tc_op_def.registered_operands.values(), context)), + op_def.comprehensions[0], op_def.domain, + op_def.registered_operands.values(), context)), ] def __repr__(self): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py index 459b1206a..22ed93490 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -22,12 +22,12 @@ @contextmanager -def bind_op_def(model: LinalgOpDef): +def bind_op_def(op_def: LinalgOpDef): if hasattr(_CONTEXT, "current_op_def"): raise ValueError("Cannot recursively define an operation") - _CONTEXT.current_op_def = model + _CONTEXT.current_op_def = op_def try: - yield model + yield op_def finally: del _CONTEXT.current_op_def @@ -53,9 +53,9 @@ def _prepare_structured_op_outs(outs: StructuredOpOuts) -> ValueList: class DefinedOpCallable: """Callable that wraps any defined op function.""" - def __init__(self, op_name: str, model: LinalgOpDef): + def __init__(self, op_name: str, op_def: LinalgOpDef): self.op_name = op_name - self.model = model + self.op_def = op_def def __call__(self, *ins: Union[ir.Operation, ir.OpView, ir.Value], outs: StructuredOpOuts, **kwargs): @@ -73,7 +73,7 @@ def __call__(self, *ins: Union[ir.Operation, ir.OpView, ir.Value], f" of type bool but got {type(emit_generic)}") op_configs = LinalgOpConfig.from_linalg_op_def( - self.model, context=ir.Context.current) + self.op_def, context=ir.Context.current) if len(op_configs) != 1: # TODO: Support composite ops. 
@@ -97,7 +97,7 @@ def __call__(self, *ins: Union[ir.Operation, ir.OpView, ir.Value], return emit_named_structured_op( op_config.structured_op, self.op_name, - self.model.metadata.cpp_class_name, + self.op_def.metadata.cpp_class_name, *in_values, outs=out_values, **kwargs) @@ -121,7 +121,7 @@ def linalg_structured_op(dsl_func=None, # Camel case it. op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op" - tc_model = LinalgOpDef( + op_def = LinalgOpDef( name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func)) # Extract arguments and TensorDefs from the signature. @@ -130,7 +130,7 @@ def linalg_structured_op(dsl_func=None, for param_name, param in sig.parameters.items(): param_default = param.default if isinstance(param_default, (TensorDef, ScalarDef, IndexAttrDef)): - tc_model.add_operand(param_name, param_default.operand_def) + op_def.add_operand(param_name, param_default.operand_def) else: raise ValueError( f"@linalg_structured_op function parameters must be defaulted as " @@ -138,13 +138,13 @@ def linalg_structured_op(dsl_func=None, f"Found {param_name}: {param_default}") dsl_func_args.append(param_default) - # Invoke the DSL func to finish populating the model. - with bind_op_def(tc_model): + # Invoke the DSL func to finish populating the op definition. + with bind_op_def(op_def): dsl_func(*dsl_func_args) # TODO: The returned callable should be an IR emitter but that is not # upstreamed yet. - return DefinedOpCallable(op_name, tc_model) + return DefinedOpCallable(op_name, op_def) def implements(*interfaces: OpInterfaceDef): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 643bcaa5c..e4695f0c9 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -23,6 +23,7 @@ "ValueList", ] +# Type aliases. 
ValueList = Union[Sequence[Value], OpResultList]
From fcfafaf56eac91bc99e3e9747d86c186ae9efca6 Mon Sep 17 00:00:00 2001
From: gysit
Date: Mon, 14 Feb 2022 12:12:15 +0000
Subject: [PATCH 233/915] [mlir][OpDSL] Add default value to index attributes.

Index attributes had no default value, which means the attribute values had to
be set on the operation. This revision adds a default parameter to
`IndexAttrDef`. After the change, every index attribute has to define a
default value. For example, we may define the following strides attribute:
```
strides=IndexAttrDef(S.SH, S.SW, default=[1, 1])
```
When using the operation the default stride is used if the strides attribute
is not set. The mechanism is implemented using `DefaultValuedAttr`.

Additionally, the revision uses the naming index attribute instead of
attribute more consistently, which is a preparation for follow up revisions
that will introduce function attributes.

Depends On D119125

Reviewed By: stellaraccident

Differential Revision: https://reviews.llvm.org/D119126
---
 .../linalg/opdsl/lang/comprehension.py        | 35 +++++----
 .../mlir/dialects/linalg/opdsl/lang/config.py | 40 +++++-----
 .../dialects/linalg/opdsl/lang/emitter.py     | 58 +++++++-------
 .../linalg/opdsl/ops/core_named_ops.py        | 76 +++++++++----------
 4 files changed, 112 insertions(+), 97 deletions(-)

diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index ea25d85aa..4513236b8 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -135,7 +135,7 @@ class OperandKind(Enum):
   InputTensor = 0
   Scalar = 1
   OutputTensor = 2
-  Attribute = 3
+  IndexAttr = 3


 class OperandDef:
@@ -147,16 +147,18 @@ class OperandDef:

   def __init__(self,
                kind: OperandKind,
-               type_var: TypeVar,
+               type_var: Optional[TypeVar] = None,
                size_exprs: Optional[Sequence[AffineExprDef]] = None,
-               index_dims: Optional[Sequence[DimDef]] = None):
-    if not isinstance(type_var,
TypeVar): + index_dims: Optional[Sequence[DimDef]] = None, + default_vals : Optional[Sequence[int]] = None): + if type_var and not isinstance(type_var, TypeVar): raise ValueError( f"OperandDef requires a TypeVar but got {repr(type_var)}") self.owner = None # type: Optional["LinalgOpDef"] self.type_var = type_var self.size_exprs = size_exprs self.index_dims = index_dims + self.default_vals = default_vals self.kind = kind self.name = None # type: Optional[str] self.registered_index = -1 # type: int @@ -174,7 +176,7 @@ def __hash__(self): def __repr__(self): return (f"{self.name}:OperandDef(kind={self.kind.name}, " f"type={repr(self.type_var)}, size_exprs={self.size_exprs}), " - f"index_dims={self.index_dims})") + f"index_dims={self.index_dims}, default_vals={self.default_vals})") class TensorDef: @@ -202,7 +204,7 @@ def __init__(self, f"got {index_dims}") kind = OperandKind.OutputTensor if output else OperandKind.InputTensor self.operand_def = OperandDef( - kind, type_var, size_exprs=shape, index_dims=index_dims) + kind, type_var=type_var, size_exprs=shape, index_dims=index_dims) def __getitem__(self, dims) -> TensorUse: assert self.operand_def.owner, "TensorDef is not attached to an op" @@ -246,7 +248,7 @@ class ScalarDef(TensorExpression): """ def __init__(self, type_var: TypeVar): - self.operand_def = OperandDef(OperandKind.Scalar, type_var) + self.operand_def = OperandDef(OperandKind.Scalar, type_var=type_var) @property def scalar_name(self) -> str: @@ -259,18 +261,25 @@ def to_scalar_expression(self) -> ScalarExpression: class IndexAttrDef: - """Index Attribute definition. + """Index attribute definition. Index attributes provide a way to define and set symbols that can be used in indexing expressions. Every attribute specifies a tuple of symbols that at - compile-time are replaced by integer values. + compile-time are replaced by integer values as well as their default values. 
""" - def __init__(self, *sizes: SymbolDef): + def __init__(self, *sizes: SymbolDef, default: Sequence[int]): if any(not isinstance(size, SymbolDef) for size in sizes): - raise ValueError(f"IndexAttrDef requires sizes of type SymbolDef but got " - f"{sizes}") - self.operand_def = OperandDef(OperandKind.Attribute, I64, size_exprs=sizes) + raise ValueError(f"IndexAttrDef requires sizes of type SymbolDef " + f"but got {sizes}") + if any(not isinstance(default_val, int) for default_val in default): + raise ValueError(f"IndexAttrDef requires default values of type int " + f"but got {default}") + if len(sizes) != len(default): + raise ValueError(f"IndexAttrDef expects {len(sizes)} default values " + f"but got {len(default)}") + self.operand_def = OperandDef( + OperandKind.IndexAttr, size_exprs=sizes, default_vals=default) class Comprehension: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py index 59a10998e..21741252f 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -45,10 +45,10 @@ class OperandDefConfig(YAMLObject): def __init__(self, operand_def: OperandDef, shape_map: Optional[_ir.AffineMap] = None, - attribute_map: Optional[_ir.AffineMap] = None): + index_attr_map: Optional[_ir.AffineMap] = None): self.operand_def = operand_def self.shape_map = shape_map # type: Optional[_ir.AffineMap] - self.attribute_map = attribute_map # type: Optional[_ir.AffineMap] + self.index_attr_map = index_attr_map # type: Optional[_ir.AffineMap] self.indexing_map = None # type: Optional[_ir.AffineMap] @property @@ -61,24 +61,28 @@ def type_var(self) -> TypeVar: @property def usage(self) -> str: - if self.operand_def.kind == OperandKind.Attribute: - return "IndexAttribute" + if self.operand_def.kind == OperandKind.IndexAttr: + return "IndexAttr" if self.operand_def.kind == OperandKind.OutputTensor: - return "OutputOperand" - return 
"InputOperand" + return "Output" + return "Input" def to_yaml_custom_dict(self): - self_dict = dict( - name=self.name, usage=self.usage, type_var=self.type_var.name) + self_dict = dict(name=self.name, usage=self.usage) + if self.type_var: + self_dict["type_var"] = self.type_var.name if self.shape_map: self_dict["shape_map"] = _serialize_affine_map(self.shape_map) - if self.attribute_map: - self_dict["attribute_map"] = _serialize_affine_map(self.attribute_map) + if self.index_attr_map: + self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map) + if self.operand_def.default_vals: + self_dict["default_vals"] = self.operand_def.default_vals return self_dict def __repr__(self): return (f"OperandDefConfig({self.operand_def}, " - f"shape_map={self.shape_map}, attribute_map={self.attribute_map}, " + f"shape_map={self.shape_map}, " + f"index_attr_map={self.index_attr_map}, " f"indexing_map={self.indexing_map})") @@ -162,7 +166,7 @@ def __init__(self, # Collect all attribute definitions. collected_attr_defs = list() for operand in registered_operands: - if operand.kind == OperandKind.Attribute: + if operand.kind == OperandKind.IndexAttr: collected_attr_defs.append(operand) # Collect all tensors with manual indexing annotation. @@ -210,9 +214,9 @@ def __init__(self, if operand_config.shape_map: operand_config.shape_map = self._normalize_affine_map( operand_config.shape_map, with_dims=False) - if operand_config.attribute_map: - operand_config.attribute_map = self._normalize_affine_map( - operand_config.attribute_map, with_dims=False) + if operand_config.index_attr_map: + operand_config.index_attr_map = self._normalize_affine_map( + operand_config.index_attr_map, with_dims=False) # Now for each write use, propagate the indexing maps from the use to the # tensor, ensuring that there are not conflicts. @@ -245,7 +249,7 @@ def __init__(self, # Check all registered tensor and scalar operands have an indexing map. 
for operand in registered_operands: - if operand.kind == OperandKind.Attribute: + if operand.kind == OperandKind.IndexAttr: continue if not (operand in self.operands and self.operands[operand].indexing_map): raise ValueError(f"Failed to compute an indexing map for operand " @@ -319,9 +323,9 @@ def add_operand(self, operand_def: OperandDef): assert local_state.local_dim_count == 0 affine_map = _ir.AffineMap.get( dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs) - if operand_def.kind == OperandKind.Attribute: + if operand_def.kind == OperandKind.IndexAttr: self.operands[operand_def] = OperandDefConfig( - operand_def, attribute_map=affine_map) + operand_def, index_attr_map=affine_map) else: self.operands[operand_def] = OperandDefConfig( operand_def, shape_map=affine_map) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index e4695f0c9..3d3b1889b 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -39,15 +39,14 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, outs: ValueList, **attrs: Sequence[int]): all_arg_defs = op_config.ordered_operands - in_arg_defs = [arg for arg in all_arg_defs if arg.usage == "InputOperand"] - out_arg_defs = [arg for arg in all_arg_defs if arg.usage == "OutputOperand"] - attr_arg_defs = [arg for arg in all_arg_defs if arg.usage == "IndexAttribute"] + in_arg_defs = [d for d in all_arg_defs if d.usage == "Input"] + out_arg_defs = [d for d in all_arg_defs if d.usage == "Output"] + index_attr_arg_defs = [d for d in all_arg_defs if d.usage == "IndexAttr"] # Verify outs is a sequence or a list of results. 
if not isinstance(outs, (Sequence, OpResultList)): - raise ValueError( - f"Expected named argument outs to have type Sequence or OpResultLis but got {type(outs)}" - ) + raise ValueError(f"Expected named argument outs to have type Sequence or " + f"OpResultLis but got {type(outs)}") # Arity validation. if len(ins) != len(in_arg_defs): @@ -60,18 +59,19 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, # Compute a replacement list for all attribute symbols. expressions = [] # type: Sequence[AffineExpr] replacements = [] # type: Sequence[AffineExpr] - for attr in attr_arg_defs: - if attr.name not in attrs: - raise ValueError(f"Expected named argument for the attribute {attr.name}") - attribute_values = attrs.get(attr.name) - if not all(isinstance(value, int) for value in attribute_values): - raise ValueError(f"Attribute {attr.name} needs to be of type " - f"Sequence[int] but got {type(attribute_values)}") - results = attr.attribute_map.results # type: AffineExprList - if len(attribute_values) != len(results): - raise ValueError(f"Attribute {attr.name} has length {len(results)} " - f"but got {len(attribute_values)} values") - for expr, value in zip(results, attribute_values): + for index_attr in index_attr_arg_defs: + index_attr_vals = index_attr.operand_def.default_vals + if index_attr.name in attrs: + index_attr_vals = attrs.get(index_attr.name) + assert index_attr_vals, "Index attribute has no value" + if not all(isinstance(value, int) for value in index_attr_vals): + raise ValueError(f"Attribute {index_attr.name} needs to be of type " + f"Sequence[int] but got {type(index_attr_vals)}") + results = index_attr.index_attr_map.results # type: AffineExprList + if len(index_attr_vals) != len(results): + raise ValueError(f"Attribute {index_attr.name} has length {len(results)} " + f"but got {len(index_attr_vals)} values") + for expr, value in zip(results, index_attr_vals): expressions.append(expr) replacements.append(AffineConstantExpr.get(value)) @@ 
-116,22 +116,24 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, iterator_types_attr = ArrayAttr.get( [StringAttr.get(s) for s in op_config.iterator_types]) - # Compute a dictionary storing all index attributes. - index_attributes = {} # type: Dict[str, DenseElementAttr] - for attr in attr_arg_defs: - attribute_values = attrs.get(attr.name) - array = np.array(attribute_values, dtype=np.int64) - index_attributes[attr.name] = DenseElementsAttr.get(array) + # Compute the index attributes used when emitting a named structured op. + index_attrs = {} # type: Dict[str, DenseElementAttr] + for index_attr in index_attr_arg_defs: + index_attr_vals = attrs.get(index_attr.name) + # Only forward attributes set to a non-default value. + if index_attr_vals: + array = np.array(index_attr_vals, dtype=np.int64) + index_attrs[index_attr.name] = DenseElementsAttr.get(array) return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, indexing_maps_attr, iterator_types_attr, - index_attributes, block_arg_types) + index_attrs, block_arg_types) def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, outs: ValueList, **attrs: Sequence[int]): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \ + indexing_maps_attr, iterator_types_attr, index_attrs, block_arg_types = \ prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) # An operation that accesses only scalars and scalar/rank zero tensors is @@ -182,7 +184,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str, op_class_name: str, *ins: Value, outs: ValueList, **attrs: Sequence[int]): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - indexing_maps_attr, iterator_types_attr, index_attributes, block_arg_types = \ + indexing_maps_attr, iterator_types_attr, index_attrs, block_arg_types = \ 
prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) # If we get here, there must exist a builtin class `op_class_name`. @@ -195,7 +197,7 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str, # Set the index attributes used to compute the indexing maps. named_op = getattr(linalg, op_class_name)(ins, outs, result_types) - for name, value in index_attributes.items(): + for name, value in index_attrs.items(): named_op.operation.attributes[name] = value linalg.fill_builtin_region(named_op.operation) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 80a8fb6cc..25bd0c3ab 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -224,8 +224,8 @@ def conv_1d_nwc_wcf( I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KW, S.C, S.F), O=TensorDef(U, S.N, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SW), - dilations=IndexAttrDef(S.DW)): + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1])): """Performs 1-D convolution. Numeric casting is performed on the operands to the inner multiply, promoting @@ -244,8 +244,8 @@ def conv_2d_nhwc_hwcf( S.C), K=TensorDef(T2, S.KH, S.KW, S.C, S.F), O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs 2-D convolution. 
Layout: @@ -270,8 +270,8 @@ def conv_2d_nhwc_hwcf_q( IZp=ScalarDef(I32), KZp=ScalarDef(I32), O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs 2-D convolution with zero point offsets. Layout: @@ -297,8 +297,8 @@ def conv_2d_nchw_fchw( S.OW * S.SW + S.KW * S.DW), K=TensorDef(T2, S.F, S.C, S.KH, S.KW), O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs 2-D convolution. Layout: @@ -321,8 +321,8 @@ def conv_3d_ndhwc_dhwcf( S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F), O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SD, S.SH, S.SW), - dilations=IndexAttrDef(S.DD, S.DH, S.DW)): + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])): """Performs 3-D convolution. Numeric casting is performed on the operands to the inner multiply, promoting @@ -341,8 +341,8 @@ def depthwise_conv_1d_nwc_wc( I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC), K=TensorDef(T2, S.KW, S.IC), O=TensorDef(U, S.N, S.OW, S.IC, output=True), - strides=IndexAttrDef(S.SW), - dilations=IndexAttrDef(S.DW)): + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1])): """Performs depth-wise 1-D convolution. 
Numeric casting is performed on the operands to the inner multiply, promoting @@ -362,8 +362,8 @@ def depthwise_conv_2d_nhwc_hwc( S.IC), K=TensorDef(T2, S.KH, S.KW, S.IC), O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs depth-wise 2-D convolution. Numeric casting is performed on the operands to the inner multiply, promoting @@ -385,8 +385,8 @@ def depthwise_conv_2d_nhwc_hwc_q( IZp=ScalarDef(I32), KZp=ScalarDef(I32), O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs depth-wise 2-D convolution. Numeric casting is performed on the operands to the inner multiply, promoting @@ -407,8 +407,8 @@ def depthwise_conv_2d_nhwc_hwcm( S.IC), K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs depth-wise 2-D convolution. Numeric casting is performed on the operands to the inner multiply, promoting @@ -429,8 +429,8 @@ def depthwise_conv_2d_nhwc_hwcm_q( IZp=ScalarDef(I32), KZp=ScalarDef(I32), O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs depth-wise 2-D convolution. 
Numeric casting is performed on the operands to the inner multiply, promoting @@ -451,8 +451,8 @@ def pooling_nhwc_sum( S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs sum pooling. Numeric casting is performed on the input operand, promoting it to the same @@ -470,8 +470,8 @@ def pooling_nhwc_max( S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs max pooling. Numeric casting is performed on the input operand, promoting it to the same @@ -490,8 +490,8 @@ def pooling_nhwc_max_unsigned( S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs unsigned max pooling. Numeric casting is performed on the input operand, promoting it to the same @@ -510,8 +510,8 @@ def pooling_nchw_max( S.OW * S.SW + S.KW * S.DW), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs max pooling. 
Numeric casting is performed on the input operand, promoting it to the same @@ -531,8 +531,8 @@ def pooling_nhwc_min( S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs min pooling. Numeric casting is performed on the input operand, promoting it to the same @@ -551,8 +551,8 @@ def pooling_nhwc_min_unsigned( S.C), K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW), - dilations=IndexAttrDef(S.DH, S.DW)): + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs unsigned min pooling. Numeric casting is performed on the input operand, promoting it to the same @@ -571,8 +571,8 @@ def pooling_ndhwc_sum( S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SD, S.SH, S.SW), - dilations=IndexAttrDef(S.DD, S.DH, S.DW)): + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])): """Performs 3D sum pooling. Numeric casting is performed on the input operand, promoting it to the same @@ -591,8 +591,8 @@ def pooling_ndhwc_max( S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SD, S.SH, S.SW), - dilations=IndexAttrDef(S.DD, S.DH, S.DW)): + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])): """Performs 3D max pooling. 
Numeric casting is performed on the input operand, promoting it to the same @@ -612,8 +612,8 @@ def pooling_ndhwc_min( S.OW * S.SW + S.KW * S.DW, S.C), K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SD, S.SH, S.SW), - dilations=IndexAttrDef(S.DD, S.DH, S.DW)): + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])): """Performs 3D min pooling. Numeric casting is performed on the input operand, promoting it to the same From 263000c72f54712ccdc9c9cb16ec4ad0b763fe99 Mon Sep 17 00:00:00 2001 From: gysit Date: Mon, 14 Feb 2022 12:55:48 +0000 Subject: [PATCH 234/915] [mlir][OpDSL] Restructure comprehension.py (NFC). Group and reorder the classed defined by comprehension.py and add type annotations. Depends On D119126 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D119692 --- .../linalg/opdsl/lang/comprehension.py | 509 +++++++++--------- 1 file changed, 263 insertions(+), 246 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 4513236b8..300ea0833 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -8,7 +8,7 @@ represent actual op definitions (i.e. YAML). """ -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple from enum import Enum from ..... import ir as _ir @@ -17,6 +17,10 @@ from .types import * from .yaml_helper import * +############################################################################### +# Tensor expression nodes. 
+############################################################################### + class TensorExpression: """An expression that can appear on the RHS of a comprehension.""" @@ -24,19 +28,18 @@ class TensorExpression: def to_scalar_expression(self) -> ScalarExpression: raise NotImplementedError() - def visit_tensor_exprs(self, callback): + def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): """Visits all tensor expression reachable by the expression.""" callback(self) def collect_dim_uses(self, uses: Set["DimDef"]): """Collects all DimDefs reachable through this expression.""" - results = set() - def visit_dim_def(dim_def): + def visit_dim_def(dim_def: AffineExprDef): if isinstance(dim_def, DimDef): uses.add(dim_def) - def visit_affine_exprs(expr): + def visit_affine_exprs(expr: "TensorExpression"): if isinstance(expr, TensorUse): for ind in expr.indices: ind.visit_affine_exprs(visit_dim_def) @@ -49,7 +52,7 @@ def visit_affine_exprs(expr): def collect_tensor_uses(self, uses: Set["TensorUse"]): """Collects all TensorUses reachable through this expression.""" - def visit_tensor_use(expr): + def visit_tensor_use(expr: "TensorExpression"): if isinstance(expr, TensorUse): uses.add(expr) @@ -58,7 +61,7 @@ def visit_tensor_use(expr): def collect_indices(self, indices: Set["index"]): """Collects all index accesses reachable through this expression.""" - def visit_index(expr): + def visit_index(expr: "TensorExpression"): if isinstance(expr, index): indices.add(expr) @@ -67,7 +70,7 @@ def visit_index(expr): def collect_scalar_uses(self, uses: Set["ScalarDef"]): """Collects all ScalarDefs reachable through this expression.""" - def visit_scalar_def(expr): + def visit_scalar_def(expr: "TensorExpression"): if isinstance(expr, ScalarDef): uses.add(expr) @@ -111,26 +114,261 @@ def tensor_name(self) -> str: assert name is not None, "TensorDef not attached" return name - def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn": - return 
ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs) - def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]: - """For implicit reductions, computes default reduction dims. - - Assumes that the rhs is the expression being reduced and self is being - reduced into. Any indices referenced on the rhs and not in self are - considered reduction dims and will be ordered as encountered on the rhs. - """ + # Computes the reduction dims for implicit reductions. Assumes that the rhs + # is the expression being reduced and self is being reduced into. Any + # indices referenced on the rhs and not in self are considered reduction + # dims and will be ordered as encountered on the rhs. rhs_dims = set() lhs_dims = set() rhs.collect_dim_uses(rhs_dims) self.collect_dim_uses(lhs_dims) return rhs_dims - lhs_dims + def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn": + return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs) + def __repr__(self): return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]" +class TensorArithFn(TensorExpression): + """Application of an arithmetic function.""" + + def __init__(self, arith_fn: "ArithFnType", args: Sequence[TensorExpression]): + self.arith_fn = arith_fn + self.args = tuple(args) + + def to_scalar_expression(self) -> ScalarExpression: + return ScalarArithFn(self.arith_fn.fn_name, + *[arg.to_scalar_expression() for arg in self.args + ]).expr() + + def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): + super().visit_tensor_exprs(callback) + for arg in self.args: + arg.visit_tensor_exprs(callback) + + def __repr__(self): + return f"{repr(self.arith_fn)}({', '.join(repr(a) for a in self.args)})" + + +class TensorTypeFn(TensorExpression): + """Application of a type conversion function.""" + + def __init__(self, type_fn: "TypeFn", type_var: TypeVar, + arg: TensorExpression): + self.type_fn = type_fn + self.type_var = type_var + self.arg = arg + + def 
to_scalar_expression(self) -> ScalarExpression: + return ScalarTypeFn(self.type_fn.fn_name, self.type_var, + self.arg.to_scalar_expression()).expr() + + def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): + super().visit_tensor_exprs(callback) + self.arg.visit_tensor_exprs(callback) + + def __repr__(self): + return f"{repr(self.type_fn)}({self.type_var}, {self.arg})" + + +class TensorReduceFn(TensorExpression): + """Application of a reduction function. + + This captures the lhs (initial value) separately from the rhs. + """ + + def __init__(self, reduce_use: "ReduceFnUse", + args: Sequence[TensorExpression]): + self.reduce_use = reduce_use + self.lhs = None # type: Optional[TensorUse] + self.args = tuple(args) + + def to_scalar_expression(self) -> ScalarExpression: + if self.lhs is None: + raise ValueError(f"Cannot scalarize a TensorReduceFn that has not been " + f"bound to its lhs: {self}") + full_args = [self.lhs.to_scalar_expression() + ] + [arg.to_scalar_expression() for arg in self.args] + return ScalarArithFn(self.reduce_use.arith_fn.fn_name, *full_args).expr() + + def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): + for arg in self.args: + arg.visit_tensor_exprs(callback) + + def __repr__(self): + return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})" + + +class const(TensorExpression): + """Returns the given constant floating point or integer value.""" + + def __init__(self, value: Any): + with _ir.Context(): + if isinstance(value, float): + self.value = str(_ir.FloatAttr.get_f64(float(value))) + elif isinstance(value, int): + self.value = str( + _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value))) + else: + raise ValueError(f"const requires int or float but got {type(value)}") + + def to_scalar_expression(self) -> ScalarExpression: + return ScalarConst(self.value).expr() + + def __repr__(self): + return f"const({self.value})" + + +class index(TensorExpression): + 
"""Returns the iteration index for a given dimension name. + + Resolves the given dimension name to obtain its position in the iteration + domain of the operation. + """ + + def __init__(self, dim: DimDef): + self.dim_def = dim + self.dim = -1 + + def resolve_dimension_name(self, affine_state: AffineBuildState): + self.dim = affine_state.get_dim(self.dim_def.dimname) + + def to_scalar_expression(self) -> ScalarExpression: + assert self.dim != -1, "Dimension name not resolved" + return ScalarIndex(self.dim).expr() + + def __repr__(self): + return f"index({repr(self.dim)})" + + +############################################################################### +# Function types and function definitions. +############################################################################### + + +class TypeFnType: + """Type conversion function. + + A type conversion function takes a target type and a tensor expression and + returns the casted tensor expression. + """ + + def __init__(self, fn_name: str): + self.fn_name = fn_name + + def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TypeFnType": + return TensorTypeFn(self, type_var, arg) + + def __repr__(self): + return f"{self.fn_name}" + + +class TypeFn: + """Type conversion function namespace. + + As the integer types are signless, signedness is implement by different cast + functions that treat integers as signed (`cast`) or unsigned + (`cast_unsigned`) values. + + Examples: + - cast(I32 -> I64) -> `arith.ExtSIOp` + - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp` + """ + cast = TypeFnType("cast") + cast_unsigned = TypeFnType("cast_unsigned") + + +class ArithFnType: + """Arithmetic function. + + An arithmetic function takes one ore more tensor expressions and returns the + function evaluation result. 
+ """ + + def __init__(self, fn_name: str): + self.fn_name = fn_name + + def __call__(self, *args) -> "TensorArithFn": + return TensorArithFn(self, args) + + def __repr__(self): + return f"{self.fn_name}" + + +class ArithFn: + """Arithmetic function namespace. + + As the integer types are signless, signedness is implement by different + functions that treat integers as signed or unsigned values. + + Examples: + - max -> `arith.MaxSIOp` + - max_unsinged -> `arith.MaxUIOp` + """ + add = ArithFnType("add") + exp = ArithFnType("exp") + log = ArithFnType("log") + mul = ArithFnType("mul") + max = ArithFnType("max") + min = ArithFnType("min") + sub = ArithFnType("sub") + max_unsigned = ArithFnType("max_unsigned") + min_unsigned = ArithFnType("min_unsigned") + + +class ReduceFnUse: + """Reduction function use. + + A reduction use specifies the reduction function and dimensions. + """ + + def __init__(self, arith_fn: ArithFnType, *reduce_dims: DimDef): + self.arith_fn = arith_fn + self.reduce_dims = reduce_dims + + def __call__(self, *args: TensorExpression): + return TensorReduceFn(self, args) + + def __repr__(self): + return (f"reduce_{self.arith_fn.fn_name}" + f"({', '.join(repr(d) for d in self.reduce_dims)})") + + +class ReduceFnType: + """Reduction function. + + An arithmetic function that reduces its RHS into its LHS. 
+ """ + + def __init__(self, arith_fn: ArithFnType): + if not isinstance(arith_fn, ArithFnType): + raise ValueError(f"Reduce expected a ArithFnType but got {arith_fn}") + self.arith_fn = arith_fn + + def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: + return ReduceFnUse(self.arith_fn, *reduce_dims) + + def __repr__(self): + return (f"reduce_{self.arith_fn.fn_name}") + + +class ReduceFn: + add = ReduceFnType(ArithFn.add) + mul = ReduceFnType(ArithFn.mul) + max = ReduceFnType(ArithFn.max) + min = ReduceFnType(ArithFn.min) + max_unsigned = ReduceFnType(ArithFn.max_unsigned) + min_unsigned = ReduceFnType(ArithFn.min_unsigned) + + +############################################################################### +# Operand definitions. +############################################################################### + + class OperandKind(Enum): InputTensor = 0 Scalar = 1 @@ -150,7 +388,7 @@ def __init__(self, type_var: Optional[TypeVar] = None, size_exprs: Optional[Sequence[AffineExprDef]] = None, index_dims: Optional[Sequence[DimDef]] = None, - default_vals : Optional[Sequence[int]] = None): + default_vals: Optional[Sequence[int]] = None): if type_var and not isinstance(type_var, TypeVar): raise ValueError( f"OperandDef requires a TypeVar but got {repr(type_var)}") @@ -206,7 +444,7 @@ def __init__(self, self.operand_def = OperandDef( kind, type_var=type_var, size_exprs=shape, index_dims=index_dims) - def __getitem__(self, dims) -> TensorUse: + def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse: assert self.operand_def.owner, "TensorDef is not attached to an op" state = AffineBuildState( global_state=self.operand_def.owner._affine_state, @@ -225,7 +463,7 @@ def __getitem__(self, dims) -> TensorUse: exprs.append(expr_def) return TensorUse(self.operand_def, exprs) - def __setitem__(self, dims, value): + def __setitem__(self, dims: Sequence[AffineExprDef], value: TensorExpression): """Creates a new 1:1 comprehension by binding this tensor to an 
expression. Note that due to the way assignment works in Python, we have to capture @@ -282,6 +520,11 @@ def __init__(self, *sizes: SymbolDef, default: Sequence[int]): OperandKind.IndexAttr, size_exprs=sizes, default_vals=default) +############################################################################### +# Operation definition. +############################################################################### + + class Comprehension: """Represents a single comprehension.""" @@ -320,232 +563,6 @@ def __repr__(self): return f"{defs_repr} = {values_repr}" -class TypeFnType: - """Type conversion function. - - A type conversion function takes a target type and a tensor expression and - returns the casted tensor expression. - """ - - def __init__(self, fn_name: str): - self.fn_name = fn_name - - def __call__(self, type_var: TypeVar, - arg: TensorExpression) -> "TensorTypeFn": - return TensorTypeFn(self, type_var, arg) - - def __repr__(self): - return f"{self.fn_name}" - - -class TypeFn: - """Type conversion function namespace. - - As the integer types are signless, signedness is implement by different cast - functions that treat integers as signed (`cast`) or unsigned - (`cast_unsigned`) values. - - Examples: - - cast(I32 -> I64) -> `arith.ExtSIOp` - - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp` - """ - cast = TypeFnType("cast") - cast_unsigned = TypeFnType("cast_unsigned") - - -class ArithFnType: - """Arithmetic function. - - An arithmetic function takes one ore more tensor expressions and returns the - function evaluation result. - """ - - def __init__(self, fn_name: str): - self.fn_name = fn_name - - def __call__(self, *args) -> "TensorArithFn": - return TensorArithFn(self, args) - - def __repr__(self): - return f"{self.fn_name}" - - -class ArithFn: - """Arithmetic function namespace. - - As the integer types are signless, signedness is implement by different - functions that treat integers as signed or unsigned values. 
- - Examples: - - max -> `arith.MaxSIOp` - - max_unsinged -> `arith.MaxUIOp` - """ - add = ArithFnType("add") - exp = ArithFnType("exp") - log = ArithFnType("log") - mul = ArithFnType("mul") - max = ArithFnType("max") - min = ArithFnType("min") - sub = ArithFnType("sub") - max_unsigned = ArithFnType("max_unsigned") - min_unsigned = ArithFnType("min_unsigned") - - -class ReduceFnUse: - """Reduction function use. - - A reduction use specifies the reduction function and dimensions. - """ - - def __init__(self, arith_fn: ArithFnType, *reduce_dims: DimDef): - self.arith_fn = arith_fn - self.reduce_dims = reduce_dims - - def __call__(self, *args: TensorExpression): - return TensorReduceFn(self, args) - - def __repr__(self): - return (f"reduce_{self.arith_fn.fn_name}" - f"({', '.join(repr(d) for d in self.reduce_dims)})") - - -class ReduceFnType: - """Reduction function. - - An arithmetic function that reduces its RHS into its LHS. - """ - - def __init__(self, arith_fn: ArithFnType): - if not isinstance(arith_fn, ArithFnType): - raise ValueError(f"Reduce expected a ArithFnType but got {arith_fn}") - self.arith_fn = arith_fn - - def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: - return ReduceFnUse(self.arith_fn, *reduce_dims) - - def __repr__(self): - return (f"reduce_{self.arith_fn.fn_name}") - - -class ReduceFn: - add = ReduceFnType(ArithFn.add) - mul = ReduceFnType(ArithFn.mul) - max = ReduceFnType(ArithFn.max) - min = ReduceFnType(ArithFn.min) - max_unsigned = ReduceFnType(ArithFn.max_unsigned) - min_unsigned = ReduceFnType(ArithFn.min_unsigned) - - -class TensorArithFn(TensorExpression): - """Application of an arithmetic function.""" - - def __init__(self, arith_fn: ArithFnType, args: Sequence[TensorExpression]): - self.arith_fn = arith_fn - self.args = tuple(args) - - def to_scalar_expression(self) -> ScalarExpression: - return ScalarArithFn(self.arith_fn.fn_name, - *[arg.to_scalar_expression() for arg in self.args - ]).expr() - - def 
visit_tensor_exprs(self, callback): - super().visit_tensor_exprs(callback) - for arg in self.args: - arg.visit_tensor_exprs(callback) - - def __repr__(self): - return f"{repr(self.arith_fn)}({', '.join(repr(a) for a in self.args)})" - - -class TensorTypeFn(TensorExpression): - """Application of a type conversion function.""" - - def __init__(self, type_fn: TypeFn, type_var: TypeVar, arg: TensorExpression): - self.type_fn = type_fn - self.type_var = type_var - self.arg = arg - - def to_scalar_expression(self) -> ScalarExpression: - return ScalarTypeFn(self.type_fn.fn_name, self.type_var, - self.arg.to_scalar_expression()).expr() - - def visit_tensor_exprs(self, callback): - super().visit_tensor_exprs(callback) - self.arg.visit_tensor_exprs(callback) - - def __repr__(self): - return f"{repr(self.type_fn)}({self.type_var}, {self.arg})" - - -class const(TensorExpression): - """Returns the given constant floating point or integer value.""" - - def __init__(self, value: Any): - with _ir.Context(): - if isinstance(value, float): - self.value = str(_ir.FloatAttr.get_f64(float(value))) - elif isinstance(value, int): - self.value = str( - _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value))) - else: - raise ValueError(f"const requires int or float but got {type(value)}") - - def to_scalar_expression(self) -> ScalarExpression: - return ScalarConst(self.value).expr() - - def __repr__(self): - return f"const({self.value})" - - -class index(TensorExpression): - """Returns the iteration index for a given dimension name. - - Resolves the given dimension name to obtain its position in the iteration - domain of the operation. 
- """ - - def __init__(self, dim: DimDef): - self.dim_def = dim - self.dim = -1 - - def resolve_dimension_name(self, affine_state: AffineBuildState): - self.dim = affine_state.get_dim(self.dim_def.dimname) - - def to_scalar_expression(self) -> ScalarExpression: - assert self.dim != -1, "Dimension name not resolved" - return ScalarIndex(self.dim).expr() - - def __repr__(self): - return f"index({repr(self.dim)})" - - -class TensorReduceFn(TensorExpression): - """Application of a reduction function. - - This captures the lhs (initial value) separately from the rhs. - """ - - def __init__(self, reduce_use: ReduceFnUse, args: Sequence[TensorExpression]): - self.reduce_use = reduce_use - self.lhs = None # type: Optional[TensorUse] - self.args = tuple(args) - - def to_scalar_expression(self) -> ScalarExpression: - if self.lhs is None: - raise ValueError(f"Cannot scalarize a TensorReduceFn that has not been " - f"bound to its lhs: {self}") - full_args = [self.lhs.to_scalar_expression() - ] + [arg.to_scalar_expression() for arg in self.args] - return ScalarArithFn(self.reduce_use.arith_fn.fn_name, *full_args).expr() - - def visit_tensor_exprs(self, callback): - for arg in self.args: - arg.visit_tensor_exprs(callback) - - def __repr__(self): - return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})" - - class OpInterfaceDef: """An interface that an op implements.""" From 0415c838d9958f02e4e52dbe90093c1da7715aea Mon Sep 17 00:00:00 2001 From: gysit Date: Mon, 14 Feb 2022 13:02:11 +0000 Subject: [PATCH 235/915] [mlir][linalg] Add attributes to region builder (NFC). Adapt the region builder signature to hand in the attributes of the created ops. The revision is a preparation step the support named ops that need access to the operation attributes during op creation. 
Depends On D119692 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D119693 --- mlir/lib/CAPI/Dialect/Linalg.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 8862b6b15..bfb3313d1 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -38,7 +38,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) { Region ®ion = op->getRegion(0); Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes, argLocs); b.setInsertionPointToStart(body); - fun(b, *body); + fun(b, *body, op->getAttrs()); } MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect) From 3251b8e1db60c7f50a24a4923bb5d581c3fc85c8 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sun, 13 Feb 2022 22:49:28 -0800 Subject: [PATCH 236/915] [mlir][python] Directly implement sequence protocol on Sliceable. * While annoying, this is the only way to get C++ exception handling out of the happy path for normal iteration. * Implements sq_length and sq_item for the sequence protocol (used for iteration, including list() construction). * Implements mp_subscript for general use (i.e. foo[1] and foo[1:1]). * For constructing a `list(op.results)`, this reduces the time from ~4-5us to ~1.5us on my machine (give or take measurement overhead) and eliminates C++ exceptions, which is a worthy goal in itself. * Compared to a baseline of similar construction of a three-integer list, which takes 450ns (might just be measuring function call overhead). 
* See issue discussed on the pybind side: https://github.com/pybind/pybind11/issues/2842 Differential Revision: https://reviews.llvm.org/D119691 --- mlir/lib/Bindings/Python/PybindUtils.h | 102 ++++++++++++++++++------- 1 file changed, 74 insertions(+), 28 deletions(-) diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h index 75a72371e..e791ba8e2 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -207,6 +207,8 @@ struct PySinglePartStringAccumulator { /// constructs a new instance of the derived pseudo-container with the /// given slice parameters (to be forwarded to the Sliceable constructor). /// +/// The getNumElements() and getElement(intptr_t) callbacks must not throw. +/// /// A derived class may additionally define: /// - a `static void bindDerived(ClassTy &)` method to bind additional methods /// the python class. @@ -215,49 +217,53 @@ class Sliceable { protected: using ClassTy = pybind11::class_; + // Transforms `index` into a legal value to access the underlying sequence. + // Returns <0 on failure. intptr_t wrapIndex(intptr_t index) { if (index < 0) index = length + index; - if (index < 0 || index >= length) { - throw python::SetPyError(PyExc_IndexError, - "attempt to access out of bounds"); - } + if (index < 0 || index >= length) + return -1; return index; } -public: - explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step) - : startIndex(startIndex), length(length), step(step) { - assert(length >= 0 && "expected non-negative slice length"); - } - - /// Returns the length of the slice. - intptr_t dunderLen() const { return length; } - /// Returns the element at the given slice index. Supports negative indices - /// by taking elements in inverse order. Throws if the index is out of bounds. - ElementTy dunderGetItem(intptr_t index) { + /// by taking elements in inverse order. Returns a nullptr object if out + /// of bounds. 
+ pybind11::object getItem(intptr_t index) { // Negative indices mean we count from the end. index = wrapIndex(index); + if (index < 0) { + PyErr_SetString(PyExc_IndexError, "index out of range"); + return {}; + } // Compute the linear index given the current slice properties. int linearIndex = index * step + startIndex; assert(linearIndex >= 0 && linearIndex < static_cast(this)->getNumElements() && "linear index out of bounds, the slice is ill-formed"); - return static_cast(this)->getElement(linearIndex); + return pybind11::cast( + static_cast(this)->getElement(linearIndex)); } /// Returns a new instance of the pseudo-container restricted to the given - /// slice. - Derived dunderGetItemSlice(pybind11::slice slice) { + /// slice. Returns a nullptr object on failure. + pybind11::object getItemSlice(PyObject *slice) { ssize_t start, stop, extraStep, sliceLength; - if (!slice.compute(dunderLen(), &start, &stop, &extraStep, &sliceLength)) { - throw python::SetPyError(PyExc_IndexError, - "attempt to access out of bounds"); + if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep, + &sliceLength) != 0) { + PyErr_SetString(PyExc_IndexError, "index out of range"); + return {}; } - return static_cast(this)->slice(startIndex + start * step, - sliceLength, step * extraStep); + return pybind11::cast(static_cast(this)->slice( + startIndex + start * step, sliceLength, step * extraStep)); + } + +public: + explicit Sliceable(intptr_t startIndex, intptr_t length, intptr_t step) + : startIndex(startIndex), length(length), step(step) { + assert(length >= 0 && "expected non-negative slice length"); } /// Returns a new vector (mapped to Python list) containing elements from two @@ -267,10 +273,10 @@ class Sliceable { std::vector elements; elements.reserve(length + other.length); for (intptr_t i = 0; i < length; ++i) { - elements.push_back(dunderGetItem(i)); + elements.push_back(static_cast(this)->getElement(i)); } for (intptr_t i = 0; i < other.length; ++i) { - 
elements.push_back(other.dunderGetItem(i)); + elements.push_back(static_cast(this)->getElement(i)); } return elements; } @@ -279,11 +285,51 @@ class Sliceable { static void bind(pybind11::module &m) { auto clazz = pybind11::class_(m, Derived::pyClassName, pybind11::module_local()) - .def("__len__", &Sliceable::dunderLen) - .def("__getitem__", &Sliceable::dunderGetItem) - .def("__getitem__", &Sliceable::dunderGetItemSlice) .def("__add__", &Sliceable::dunderAdd); Derived::bindDerived(clazz); + + // Manually implement the sequence protocol via the C API. We do this + // because it is approx 4x faster than via pybind11, largely because that + // formulation requires a C++ exception to be thrown to detect end of + // sequence. + // Since we are in a C-context, any C++ exception that happens here + // will terminate the program. There is nothing in this implementation + // that should throw in a non-terminal way, so we forgo further + // exception marshalling. + // See: https://github.com/pybind/pybind11/issues/2842 + auto heap_type = reinterpret_cast(clazz.ptr()); + assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE && + "must be heap type"); + heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t { + auto self = pybind11::cast(rawSelf); + return self->length; + }; + // sq_item is called as part of the sequence protocol for iteration, + // list construction, etc. + heap_type->as_sequence.sq_item = + +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * { + auto self = pybind11::cast(rawSelf); + return self->getItem(index).release().ptr(); + }; + // mp_subscript is used for both slices and integer lookups. + heap_type->as_mapping.mp_subscript = + +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * { + auto self = pybind11::cast(rawSelf); + Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError); + if (!PyErr_Occurred()) { + // Integer indexing. 
+ return self->getItem(index).release().ptr(); + } + PyErr_Clear(); + + // Assume slice-based indexing. + if (PySlice_Check(rawSubscript)) { + return self->getItemSlice(rawSubscript).release().ptr(); + } + + PyErr_SetString(PyExc_ValueError, "expected integer or slice"); + return nullptr; + }; } /// Hook for derived classes willing to bind more methods. From 6a308b4c8dcfefdc4cfc975f7f64f2a9e559da85 Mon Sep 17 00:00:00 2001 From: Lorenzo Chelini Date: Thu, 17 Feb 2022 10:06:16 +0100 Subject: [PATCH 237/915] [MLIR][PDL] Fix typo (NFC) --- mlir/include/mlir-c/Dialect/PDL.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/Dialect/PDL.h b/mlir/include/mlir-c/Dialect/PDL.h index 1b1528999..8bd7976e2 100644 --- a/mlir/include/mlir-c/Dialect/PDL.h +++ b/mlir/include/mlir-c/Dialect/PDL.h @@ -71,4 +71,4 @@ MLIR_CAPI_EXPORTED MlirType mlirPDLValueTypeGet(MlirContext ctx); } #endif -#endif // MLIR_C_DIALECT_QUANT_H +#endif // MLIR_C_DIALECT_PDL_H From 23cffc016219a06edc8fd91a75036953dca43ec3 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Wed, 19 Jan 2022 13:43:24 +0100 Subject: [PATCH 238/915] [mlir] Annotate methods on a correct class in PybindAdaptors.h The `.def` and `.def_property_readonly` functions in PybindAdaptors.h should construct the functions as method of the current class rather than as method of pybind11:none(), which is an object and not even a class. 
Depends On D117658 Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D117659 --- mlir/include/mlir/Bindings/Python/PybindAdaptors.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 9d5a512a4..661ed48f9 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -249,7 +249,7 @@ class pure_subclass { template pure_subclass &def(const char *name, Func &&f, const Extra &... extra) { py::cpp_function cf( - std::forward(f), py::name(name), py::is_method(py::none()), + std::forward(f), py::name(name), py::is_method(thisClass), py::sibling(py::getattr(thisClass, name, py::none())), extra...); thisClass.attr(cf.name()) = cf; return *this; @@ -259,7 +259,7 @@ class pure_subclass { pure_subclass &def_property_readonly(const char *name, Func &&f, const Extra &... extra) { py::cpp_function cf( - std::forward(f), py::name(name), py::is_method(py::none()), + std::forward(f), py::name(name), py::is_method(thisClass), py::sibling(py::getattr(thisClass, name, py::none())), extra...); auto builtinProperty = py::reinterpret_borrow((PyObject *)&PyProperty_Type); From 87d1a3dbaf98cd4b921aefced832d2c2a92e726d Mon Sep 17 00:00:00 2001 From: fuzzypixelz Date: Mon, 21 Feb 2022 07:53:27 -0800 Subject: [PATCH 239/915] [MLIR] replace C++ function type defintion in the C API's Interfaces.h Clearly this something of a typo, and it obviously doesn't even compile. 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D120247 --- mlir/include/mlir-c/Interfaces.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/Interfaces.h b/mlir/include/mlir-c/Interfaces.h index 7ab6b8af3..233f828b9 100644 --- a/mlir/include/mlir-c/Interfaces.h +++ b/mlir/include/mlir-c/Interfaces.h @@ -48,7 +48,7 @@ MLIR_CAPI_EXPORTED MlirTypeID mlirInferTypeOpInterfaceTypeID(); /// transferring ownership to the caller. The first argument is the number of /// consecutive elements pointed to by the second argument. The third argument /// is an opaque pointer forwarded to the callback by the caller. -using MlirTypesCallback = void (*)(intptr_t, MlirType *, void *); +typedef void (*MlirTypesCallback)(intptr_t, MlirType *, void *); /// Infers the return types of the operation identified by its canonical given /// the arguments that will be supplied to its generic builder. Calls `callback` From 460e87d89243be9f9a6949e6c4e1054c08f588c8 Mon Sep 17 00:00:00 2001 From: Emilio Cota Date: Tue, 22 Feb 2022 22:27:54 -0500 Subject: [PATCH 240/915] [mlir][NFC] Use options struct in ExecutionEngine::create Its number of optional parameters has grown too large, which makes adding new optional parameters quite a chore. Fix this by using an options struct. 
Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D120380 --- mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index 604cc4522..4c5553087 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -50,9 +50,11 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, auto llvmOptLevel = static_cast(optLevel); auto transformer = mlir::makeLLVMPassesTransformer( /*passes=*/{}, llvmOptLevel, /*targetMachine=*/tmOrError->get()); - auto jitOrError = - ExecutionEngine::create(unwrap(op), /*llvmModuleBuilder=*/{}, transformer, - llvmOptLevel, libPaths); + ExecutionEngineOptions jitOptions; + jitOptions.transformer = transformer; + jitOptions.jitCodeGenOptLevel = llvmOptLevel; + jitOptions.sharedLibPaths = libPaths; + auto jitOrError = ExecutionEngine::create(unwrap(op), jitOptions); if (!jitOrError) { consumeError(jitOrError.takeError()); return MlirExecutionEngine{nullptr}; From 7595641154f963679707237259b0e05aa1220660 Mon Sep 17 00:00:00 2001 From: rkayaith Date: Thu, 24 Feb 2022 10:21:40 +0100 Subject: [PATCH 241/915] [mlir][python] Support more types in IntegerAttr.value Previously only accessing values for `index` and signless int types would work; signed and unsigned ints would hit an assert in `IntegerAttr::getInt`. This exposes `IntegerAttr::get{S,U}Int` to the C API and calls the appropriate function from the python bindings. 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D120194 --- mlir/include/mlir-c/BuiltinAttributes.h | 10 +++++++++- mlir/lib/Bindings/Python/IRAttributes.cpp | 9 +++++++-- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 8 ++++++++ 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 973b7e994..bb4431f7b 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -125,9 +125,17 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value); /// Returns the value stored in the given integer attribute, assuming the value -/// fits into a 64-bit integer. +/// is of signless type and fits into a signed 64-bit integer. MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr); +/// Returns the value stored in the given integer attribute, assuming the value +/// is of signed type and fits into a signed 64-bit integer. +MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr); + +/// Returns the value stored in the given integer attribute, assuming the value +/// is of unsigned type and fits into an unsigned 64-bit integer. +MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr); + //===----------------------------------------------------------------------===// // Bool attribute. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 5d87641c3..bef3b95a2 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -258,8 +258,13 @@ class PyIntegerAttribute : public PyConcreteAttribute { "Gets an uniqued integer attribute associated to a type"); c.def_property_readonly( "value", - [](PyIntegerAttribute &self) { - return mlirIntegerAttrGetValueInt(self); + [](PyIntegerAttribute &self) -> py::int_ { + MlirType type = mlirAttributeGetType(self); + if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) + return mlirIntegerAttrGetValueInt(self); + if (mlirIntegerTypeIsSigned(type)) + return mlirIntegerAttrGetValueSInt(self); + return mlirIntegerAttrGetValueUInt(self); }, "Returns the value of the integer attribute"); } diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 7b718da88..9ea277b74 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -129,6 +129,14 @@ int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) { return unwrap(attr).cast().getInt(); } +int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr) { + return unwrap(attr).cast().getSInt(); +} + +uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) { + return unwrap(attr).cast().getUInt(); +} + //===----------------------------------------------------------------------===// // Bool attribute. //===----------------------------------------------------------------------===// From a0fcc77f1264eebe437b36ed29453e92e7c64a4a Mon Sep 17 00:00:00 2001 From: gysit Date: Fri, 25 Feb 2022 08:12:34 +0000 Subject: [PATCH 242/915] [mlir][OpDSL] Add type function attributes. Previously, OpDSL operation used hardcoded type conversion operations (cast or cast_unsigned). 
Supporting signed and unsigned casts thus meant implementing two different operations. Type function attributes allow us to define a single operation that has a cast type function attribute which at operation instantiation time may be set to cast or cast_unsigned. We may for example, defina a matmul operation with a cast argument: ``` @linalg_structured_op def matmul(A=TensorDef(T1, S.M, S.K), B=TensorDef(T2, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True), cast=TypeFnAttrDef(default=TypeFn.cast)): C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) ``` When instantiating the operation the attribute may be set to the desired cast function: ``` linalg.matmul(lhs, rhs, outs=[out], cast=TypeFn.cast_unsigned) ``` The revsion introduces a enum in the Linalg dialect that maps one-by-one to the type functions defined by OpDSL. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D119718 --- .../linalg/opdsl/lang/comprehension.py | 108 +++++++++++++----- .../mlir/dialects/linalg/opdsl/lang/config.py | 34 +++--- .../mlir/dialects/linalg/opdsl/lang/dsl.py | 3 +- .../dialects/linalg/opdsl/lang/emitter.py | 67 ++++++++--- .../dialects/linalg/opdsl/lang/scalar_expr.py | 23 ++-- .../linalg/opdsl/ops/core_named_ops.py | 5 +- 6 files changed, 165 insertions(+), 75 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 300ea0833..68c08809e 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -111,7 +111,7 @@ def to_scalar_expression(self) -> ScalarExpression: @property def tensor_name(self) -> str: name = self.operand_def.name - assert name is not None, "TensorDef not attached" + assert name is not None, "TensorDef not registered with an op" return name def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]: @@ -129,7 +129,8 @@ def __iadd__(self, rhs: 
TensorExpression) -> "TensorReduceFn": return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs) def __repr__(self): - return f"{self.tensor_name}[{', '.join([repr(i) for i in self.indices])}]" + return (f"{self.operand_def.name}" + f"[{', '.join([repr(i) for i in self.indices])}]") class TensorArithFn(TensorExpression): @@ -156,14 +157,22 @@ def __repr__(self): class TensorTypeFn(TensorExpression): """Application of a type conversion function.""" - def __init__(self, type_fn: "TypeFn", type_var: TypeVar, + def __init__(self, type_fn: Optional["TypeFn"], + operand_def: Optional["OperandDef"], type_var: TypeVar, arg: TensorExpression): + if bool(type_fn) + bool(operand_def) != 1: + raise ValueError("Either 'type_fn' or 'operand_def' must be specified") self.type_fn = type_fn + self.operand_def = operand_def self.type_var = type_var self.arg = arg def to_scalar_expression(self) -> ScalarExpression: - return ScalarTypeFn(self.type_fn.fn_name, self.type_var, + if self.operand_def: + assert self.operand_def.name, "TypeFnAttr not registered with an op" + fn_name = self.type_fn.fn_name if self.type_fn else None + attr_name = self.operand_def.name if self.operand_def else None + return ScalarTypeFn(fn_name, attr_name, self.type_var, self.arg.to_scalar_expression()).expr() def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): @@ -171,7 +180,8 @@ def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): self.arg.visit_tensor_exprs(callback) def __repr__(self): - return f"{repr(self.type_fn)}({self.type_var}, {self.arg})" + return (f"{repr(self.type_fn)}[{repr(self.operand_def)}]" + f"({self.type_var}, {self.arg})") class TensorReduceFn(TensorExpression): @@ -260,7 +270,7 @@ def __init__(self, fn_name: str): self.fn_name = fn_name def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TypeFnType": - return TensorTypeFn(self, type_var, arg) + return TensorTypeFn(self, None, type_var, arg) def __repr__(self): 
return f"{self.fn_name}" @@ -370,10 +380,11 @@ class ReduceFn: class OperandKind(Enum): - InputTensor = 0 - Scalar = 1 - OutputTensor = 2 - IndexAttr = 3 + INPUT_TENSOR = 0 + SCALAR = 1 + OUTPUT_TENSOR = 2 + INDEX_ATTR = 3 + TYPE_FN_ATTR = 4 class OperandDef: @@ -388,7 +399,8 @@ def __init__(self, type_var: Optional[TypeVar] = None, size_exprs: Optional[Sequence[AffineExprDef]] = None, index_dims: Optional[Sequence[DimDef]] = None, - default_vals: Optional[Sequence[int]] = None): + default_indices: Optional[Sequence[int]] = None, + default_fn: Optional[str] = None): if type_var and not isinstance(type_var, TypeVar): raise ValueError( f"OperandDef requires a TypeVar but got {repr(type_var)}") @@ -396,25 +408,40 @@ def __init__(self, self.type_var = type_var self.size_exprs = size_exprs self.index_dims = index_dims - self.default_vals = default_vals + self.default_indices = default_indices + self.default_fn = default_fn self.kind = kind self.name = None # type: Optional[str] self.registered_index = -1 # type: int def attach(self, index: int, name: str, owner: "LinalgOpDef"): if self.owner: - raise ValueError(f"OperandDef already registered with op: {self}") + raise ValueError(f"OperandDef already registered with an op: {self}") self.registered_index = index self.name = name self.owner = owner + def is_input(self) -> bool: + return (self.kind == OperandKind.SCALAR or + self.kind == OperandKind.INPUT_TENSOR) + + def is_tensor(self) -> bool: + return (self.kind == OperandKind.INPUT_TENSOR or + self.kind == OperandKind.OUTPUT_TENSOR) + + def is_attribute(self) -> bool: + return (self.kind == OperandKind.INDEX_ATTR or + self.kind == OperandKind.TYPE_FN_ATTR) + def __hash__(self): return hash(id(self)) def __repr__(self): return (f"{self.name}:OperandDef(kind={self.kind.name}, " - f"type={repr(self.type_var)}, size_exprs={self.size_exprs}), " - f"index_dims={self.index_dims}, default_vals={self.default_vals})") + f"type={repr(self.type_var)}, size_exprs={self.size_exprs}, 
" + f"index_dims={self.index_dims}, " + f"default_indices={self.default_indices}, " + f"default_fn={self.default_fn})") class TensorDef: @@ -440,12 +467,12 @@ def __init__(self, if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims): raise ValueError(f"TensorDef requires index dims of type DimDef but " f"got {index_dims}") - kind = OperandKind.OutputTensor if output else OperandKind.InputTensor + kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR self.operand_def = OperandDef( kind, type_var=type_var, size_exprs=shape, index_dims=index_dims) def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse: - assert self.operand_def.owner, "TensorDef is not attached to an op" + assert self.operand_def.owner, "TensorDef is not registered with an op" state = AffineBuildState( global_state=self.operand_def.owner._affine_state, allow_new_symbols=False) @@ -486,12 +513,12 @@ class ScalarDef(TensorExpression): """ def __init__(self, type_var: TypeVar): - self.operand_def = OperandDef(OperandKind.Scalar, type_var=type_var) + self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var) @property def scalar_name(self) -> str: name = self.operand_def.name - assert name is not None, "ScalarDef not attached" + assert name is not None, "ScalarDef not registered with an op" return name def to_scalar_expression(self) -> ScalarExpression: @@ -517,7 +544,26 @@ def __init__(self, *sizes: SymbolDef, default: Sequence[int]): raise ValueError(f"IndexAttrDef expects {len(sizes)} default values " f"but got {len(default)}") self.operand_def = OperandDef( - OperandKind.IndexAttr, size_exprs=sizes, default_vals=default) + OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default) + + +class TypeFnAttrDef: + """Type conversion function attribute definition. + + Type conversion function attributes provide a way to make type conversions + parameterizable. 
Every attribute specifies a default type conversion function + that may be overwritten at operation instantiation time. + """ + + def __init__(self, default: "TypeFnType"): + if not isinstance(default, TypeFnType): + raise ValueError(f"TypeFnAttrDef requires default of type TypeFnType " + f"but got {default}") + self.operand_def = OperandDef( + OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name) + + def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorTypeFn: + return TensorTypeFn(None, self.operand_def, type_var, arg) ############################################################################### @@ -615,17 +661,21 @@ def add_operand(self, name: str, operand: OperandDef): if name in self.registered_operands: raise ValueError(f"The operand {name} is already registered " f"to {self.registered_operands['name']}") + structured_op_methods = [ + "inputs", "outputs", "result_tensors", "region", "iterator_types", + "indexing_maps", "getRegionBuilder", "getLibraryCallName" + ] + if operand.is_attribute() and name in structured_op_methods: + raise ValueError(f"The attribute name {name} conflicts with a structured " + f"op method name") # Ensure output tensors are registered after input tensors and scalars and # attributes are registered after all other operand types. 
- registered_kinds = [ - operand.kind.value for operand in self.registered_operands.values() - ] - if registered_kinds: - maximum = max(registered_kinds) - if maximum > operand.kind.value and maximum > OperandKind.Scalar.value: - raise ValueError( - f"The operand {name} of kind {operand.kind.name} is registered " - f"after an operand of kind {OperandKind(maximum).name}") + if operand.is_input() and any( + not op_def.is_input() for op_def in self.registered_operands.values()): + raise ValueError(f"Input {name} registered after an output or attribute") + if operand.kind == OperandKind.OUTPUT_TENSOR and any( + op_def.is_attribute() for op_def in self.registered_operands.values()): + raise ValueError(f"Output {name} registered after an attribute") operand.attach(len(self.registered_operands), name, self) self.registered_operands[name] = operand diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py index 21741252f..12b168de1 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -56,27 +56,25 @@ def name(self) -> str: return self.operand_def.name @property - def type_var(self) -> TypeVar: - return self.operand_def.type_var + def kind(self) -> OperandKind: + return self.operand_def.kind @property - def usage(self) -> str: - if self.operand_def.kind == OperandKind.IndexAttr: - return "IndexAttr" - if self.operand_def.kind == OperandKind.OutputTensor: - return "Output" - return "Input" + def type_var(self) -> TypeVar: + return self.operand_def.type_var def to_yaml_custom_dict(self): - self_dict = dict(name=self.name, usage=self.usage) + self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower()) if self.type_var: self_dict["type_var"] = self.type_var.name if self.shape_map: self_dict["shape_map"] = _serialize_affine_map(self.shape_map) if self.index_attr_map: self_dict["index_attr_map"] = 
_serialize_affine_map(self.index_attr_map) - if self.operand_def.default_vals: - self_dict["default_vals"] = self.operand_def.default_vals + if self.operand_def.default_indices: + self_dict["default_indices"] = self.operand_def.default_indices + if self.operand_def.default_fn: + self_dict["default_fn"] = self.operand_def.default_fn return self_dict def __repr__(self): @@ -166,7 +164,7 @@ def __init__(self, # Collect all attribute definitions. collected_attr_defs = list() for operand in registered_operands: - if operand.kind == OperandKind.IndexAttr: + if operand.is_attribute(): collected_attr_defs.append(operand) # Collect all tensors with manual indexing annotation. @@ -244,12 +242,12 @@ def __init__(self, # Set the indexing map of all scalar uses to the empty map. for operand_config in self.operands.values(): - if operand_config.operand_def.kind == OperandKind.Scalar: + if operand_config.operand_def.kind == OperandKind.SCALAR: operand_config.indexing_map = self._get_scalar_map() # Check all registered tensor and scalar operands have an indexing map. 
for operand in registered_operands: - if operand.kind == OperandKind.IndexAttr: + if operand.is_attribute(): continue if not (operand in self.operands and self.operands[operand].indexing_map): raise ValueError(f"Failed to compute an indexing map for operand " @@ -311,7 +309,8 @@ def get_type(symbolic_name, position): def add_operand(self, operand_def: OperandDef): if operand_def in self.operands: return - if operand_def.kind == OperandKind.Scalar: + if (operand_def.kind == OperandKind.SCALAR or + operand_def.kind == OperandKind.TYPE_FN_ATTR): self.operands[operand_def] = OperandDefConfig(operand_def) return with self.context: @@ -323,7 +322,7 @@ def add_operand(self, operand_def: OperandDef): assert local_state.local_dim_count == 0 affine_map = _ir.AffineMap.get( dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs) - if operand_def.kind == OperandKind.IndexAttr: + if operand_def.kind == OperandKind.INDEX_ATTR: self.operands[operand_def] = OperandDefConfig( operand_def, index_attr_map=affine_map) else: @@ -429,8 +428,7 @@ def from_linalg_op_def( context: Optional[_ir.Context] = None) -> Sequence["LinalgOpConfig"]: """Expands a LinalgOpDef into corresponding Linalg configured ops.""" # TODO: Many LinalgOpDef patterns need to expand to multiple generics. 
- assert len( - op_def.comprehensions) == 1, "Only one comprehension supported" + assert len(op_def.comprehensions) == 1, "Only one comprehension supported" return [ LinalgOpConfig( op_def.metadata, diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py index 22ed93490..99ce71366 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -129,7 +129,8 @@ def linalg_structured_op(dsl_func=None, sig = inspect.signature(dsl_func) for param_name, param in sig.parameters.items(): param_default = param.default - if isinstance(param_default, (TensorDef, ScalarDef, IndexAttrDef)): + if isinstance(param_default, + (TensorDef, ScalarDef, IndexAttrDef, TypeFnAttrDef)): op_def.add_operand(param_name, param_default.operand_def) else: raise ValueError( diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 3d3b1889b..fc8c13bfe 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -37,11 +37,21 @@ def isa(cls: Type, ty: Type): def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, outs: ValueList, - **attrs: Sequence[int]): + **attrs: Union[Sequence[int], TypeFnType]): all_arg_defs = op_config.ordered_operands - in_arg_defs = [d for d in all_arg_defs if d.usage == "Input"] - out_arg_defs = [d for d in all_arg_defs if d.usage == "Output"] - index_attr_arg_defs = [d for d in all_arg_defs if d.usage == "IndexAttr"] + in_arg_defs = [ + d for d in all_arg_defs + if d.kind == OperandKind.SCALAR or d.kind == OperandKind.INPUT_TENSOR + ] + out_arg_defs = [ + d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR + ] + index_attr_arg_defs = [ + d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR + ] + type_fn_attr_arg_defs = [ + d for d in all_arg_defs if d.kind == 
OperandKind.TYPE_FN_ATTR + ] # Verify outs is a sequence or a list of results. if not isinstance(outs, (Sequence, OpResultList)): @@ -56,11 +66,11 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, raise ValueError(f"Expected {len(out_arg_defs)} outputs but got " f"{len(outs)} for {op_config}") - # Compute a replacement list for all attribute symbols. + # Compute a replacement list for all index attribute symbols. expressions = [] # type: Sequence[AffineExpr] replacements = [] # type: Sequence[AffineExpr] for index_attr in index_attr_arg_defs: - index_attr_vals = index_attr.operand_def.default_vals + index_attr_vals = index_attr.operand_def.default_indices if index_attr.name in attrs: index_attr_vals = attrs.get(index_attr.name) assert index_attr_vals, "Index attribute has no value" @@ -125,15 +135,29 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, array = np.array(index_attr_vals, dtype=np.int64) index_attrs[index_attr.name] = DenseElementsAttr.get(array) + # Compute the type function attribute mapping. 
+ type_fn_attr_mapping = {} + for type_fn_attr in type_fn_attr_arg_defs: + attr_val = type_fn_attr.operand_def.default_fn + if type_fn_attr.name in attrs: + type_fn = attrs.get(type_fn_attr.name) + if not isinstance(type_fn, TypeFnType): + raise ValueError(f"Attribute {type_fn_attr.name} needs to be of type " + f"TypeFnType but got {type(attr_val)}") + attr_val = type_fn.fn_name + assert attr_val, "Type function attribute has no value" + type_fn_attr_mapping[type_fn_attr.name] = attr_val + return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, - type_mapping, indexing_maps_attr, iterator_types_attr, - index_attrs, block_arg_types) + type_mapping, indexing_maps_attr, iterator_types_attr, index_attrs, + type_fn_attr_mapping, block_arg_types) def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, outs: ValueList, **attrs: Sequence[int]): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - indexing_maps_attr, iterator_types_attr, index_attrs, block_arg_types = \ + indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \ + block_arg_types = \ prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) # An operation that accesses only scalars and scalar/rank zero tensors is @@ -147,10 +171,9 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, tensor_map = AffineMap.get_identity(rank) indexing_maps = [] for arg_def in all_arg_defs: - if arg_def.operand_def.kind == OperandKind.Scalar: + if arg_def.operand_def.kind == OperandKind.SCALAR: indexing_maps.append(scalar_map) - if (arg_def.operand_def.kind == OperandKind.InputTensor or - arg_def.operand_def.kind == OperandKind.OutputTensor): + if arg_def.operand_def.is_tensor(): indexing_maps.append(tensor_map) indexing_maps_attr = ArrayAttr.get( [AffineMapAttr.get(am) for am in indexing_maps]) @@ -169,7 +192,8 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, block = 
generic_op.regions[0].blocks.append(*block_arg_types) block_arg_mapping = dict(zip(block_arg_names, block.arguments)) with InsertionPoint(block): - body_builder = _BodyBuilder(type_mapping, block_arg_mapping) + body_builder = _BodyBuilder(type_mapping, block_arg_mapping, + type_fn_attr_mapping) for assignment in op_config.assignments: body_builder.assign(assignment) body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs)) @@ -184,7 +208,8 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str, op_class_name: str, *ins: Value, outs: ValueList, **attrs: Sequence[int]): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - indexing_maps_attr, iterator_types_attr, index_attrs, block_arg_types = \ + indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \ + block_arg_types = \ prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) # If we get here, there must exist a builtin class `op_class_name`. @@ -200,6 +225,11 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str, for name, value in index_attrs.items(): named_op.operation.attributes[name] = value + # Set the type function attributes. 
+ for name, value in type_fn_attr_mapping.items(): + named_op.operation.attributes[name] = Attribute.parse( + f"#linalg.type_fn<{value}>") + linalg.fill_builtin_region(named_op.operation) if len(result_types) == 1: @@ -212,9 +242,11 @@ class _BodyBuilder: """Constructs a structured op body by evaluating assignments.""" def __init__(self, type_mapping: Dict[str, Type], - block_arg_mapping: Dict[str, Value]): + block_arg_mapping: Dict[str, Value], + type_fn_attr_mapping: Dict[str, str]): self.type_mapping = type_mapping self.block_arg_mapping = block_arg_mapping + self.type_fn_attr_mapping = type_fn_attr_mapping self.yield_mapping = dict() # type: Dict[str, Value] def assign(self, assignment: ScalarAssign): @@ -245,7 +277,10 @@ def expression(self, expr: ScalarExpression) -> Value: ] return fn(*operand_values) elif expr.type_fn: - fn = self._get_function(f"_typefn_{expr.type_fn.fn_name}") + fn_name = expr.type_fn.fn_name + if expr.type_fn.attr_name: + fn_name = self.type_fn_attr_mapping[expr.type_fn.attr_name] + fn = self._get_function(f"_typefn_{fn_name}") operand = self.expression(expr.type_fn.operand) return fn(expr.type_fn.type_var.name, operand) raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py index 2a30e6e78..af21b40cf 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py @@ -46,9 +46,10 @@ def __repr__(self): class ScalarTypeFn: """A type of ScalarExpression that applies a type conversion function.""" - def __init__(self, fn_name: str, type_var: TypeVar, - operand: "ScalarExpression"): + def __init__(self, fn_name: Optional[str], attr_name: Optional[str], + type_var: TypeVar, operand: "ScalarExpression"): self.fn_name = fn_name + self.attr_name = attr_name self.type_var = type_var self.operand = operand @@ -56,7 +57,8 @@ 
def expr(self) -> "ScalarExpression": return ScalarExpression(type_fn=self) def __repr__(self): - return f"ScalarTypeFn<{self.fn_name}>({self.type_var}, {self.operand})" + return (f"ScalarTypeFn<{self.fn_name}[{self.attr_name}]>" + f"({self.type_var}, {self.operand})") class ScalarArg: @@ -138,12 +140,15 @@ def to_yaml_custom_dict(self): # Note that even though operands must be arity 1, we write it the # same way as for apply because it allows handling code to be more # generic vs having a special form. - return dict( - type_fn=dict( - fn_name=self.type_fn.fn_name, - type_var=self.type_fn.type_var.name, - operands=[self.type_fn.operand], - )) + type_fn_dict = dict( + type_var=self.type_fn.type_var.name, + operands=[self.type_fn.operand], + ) + if self.type_fn.fn_name: + type_fn_dict["fn_name"] = self.type_fn.fn_name + if self.type_fn.attr_name: + type_fn_dict["attr_name"] = self.type_fn.attr_name + return dict(type_fn=type_fn_dict) elif self.scalar_arg: return dict(scalar_arg=self.scalar_arg.arg) elif self.scalar_const: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 25bd0c3ab..db63a0705 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -10,7 +10,8 @@ def matmul( A=TensorDef(T1, S.M, S.K), B=TensorDef(T2, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True)): + C=TensorDef(U, S.M, S.N, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast)): """Performs a matrix multiplication of two 2D inputs. 
Numeric casting is performed on the operands to the inner multiply, promoting @@ -18,7 +19,7 @@ def matmul( """ domain(D.m, D.n, D.k) implements(ContractionOpInterface) - C[D.m, D.n] += TypeFn.cast(U, A[D.m, D.k]) * TypeFn.cast(U, B[D.k, D.n]) + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) @linalg_structured_op From 95ff45491868682743fdd0750f64ded06bb2dfdc Mon Sep 17 00:00:00 2001 From: gysit Date: Fri, 25 Feb 2022 15:04:38 +0000 Subject: [PATCH 243/915] [mlir][OpDSL] Refactor function handling. Prepare the OpDSL function handling to introduce more function classes. A follow up commit will split ArithFn into UnaryFn and BinaryFn. This revision prepares the split by adding a function kind enum to handle different function types using a single class on the various levels of the stack (for example, there is now one TensorFn and one ScalarFn). Depends On D119718 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D120108 --- .../linalg/opdsl/lang/comprehension.py | 77 +++++++---------- .../dialects/linalg/opdsl/lang/emitter.py | 18 ++-- .../dialects/linalg/opdsl/lang/scalar_expr.py | 86 +++++++------------ 3 files changed, 73 insertions(+), 108 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 68c08809e..d26aa0770 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -133,55 +133,36 @@ def __repr__(self): f"[{', '.join([repr(i) for i in self.indices])}]") -class TensorArithFn(TensorExpression): - """Application of an arithmetic function.""" +class TensorFn(TensorExpression): + """Application of a tensor function.""" - def __init__(self, arith_fn: "ArithFnType", args: Sequence[TensorExpression]): - self.arith_fn = arith_fn - self.args = tuple(args) - - def to_scalar_expression(self) -> ScalarExpression: - return 
ScalarArithFn(self.arith_fn.fn_name, - *[arg.to_scalar_expression() for arg in self.args - ]).expr() - - def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): - super().visit_tensor_exprs(callback) - for arg in self.args: - arg.visit_tensor_exprs(callback) - - def __repr__(self): - return f"{repr(self.arith_fn)}({', '.join(repr(a) for a in self.args)})" - - -class TensorTypeFn(TensorExpression): - """Application of a type conversion function.""" - - def __init__(self, type_fn: Optional["TypeFn"], - operand_def: Optional["OperandDef"], type_var: TypeVar, - arg: TensorExpression): - if bool(type_fn) + bool(operand_def) != 1: - raise ValueError("Either 'type_fn' or 'operand_def' must be specified") - self.type_fn = type_fn + def __init__(self, kind: "FunctionKind", name: Optional[str], + operand_def: Optional["OperandDef"], type_var: Optional[TypeVar], + args: Sequence[TensorExpression]): + if bool(name) + bool(operand_def) != 1: + raise ValueError("One of 'name', 'operand_def' must be specified") + self.name = name + self.kind = kind self.operand_def = operand_def self.type_var = type_var - self.arg = arg + self.args = args def to_scalar_expression(self) -> ScalarExpression: if self.operand_def: - assert self.operand_def.name, "TypeFnAttr not registered with an op" - fn_name = self.type_fn.fn_name if self.type_fn else None + assert self.operand_def.name, "TensorFn not registered with an op" attr_name = self.operand_def.name if self.operand_def else None - return ScalarTypeFn(fn_name, attr_name, self.type_var, - self.arg.to_scalar_expression()).expr() + args = [arg.to_scalar_expression() for arg in self.args] + return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr() def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): super().visit_tensor_exprs(callback) - self.arg.visit_tensor_exprs(callback) + for arg in self.args: + arg.visit_tensor_exprs(callback) def __repr__(self): - return 
(f"{repr(self.type_fn)}[{repr(self.operand_def)}]" - f"({self.type_var}, {self.arg})") + name = self.operand_def.name if self.operand_def else self.name + return (f"{self.kind.name}.{name}(type_var={self.type_var}, " + f"args={', '.join(repr(a) for a in self.args)})") class TensorReduceFn(TensorExpression): @@ -194,7 +175,7 @@ def __init__(self, reduce_use: "ReduceFnUse", args: Sequence[TensorExpression]): self.reduce_use = reduce_use self.lhs = None # type: Optional[TensorUse] - self.args = tuple(args) + self.args = args def to_scalar_expression(self) -> ScalarExpression: if self.lhs is None: @@ -202,7 +183,8 @@ def to_scalar_expression(self) -> ScalarExpression: f"bound to its lhs: {self}") full_args = [self.lhs.to_scalar_expression() ] + [arg.to_scalar_expression() for arg in self.args] - return ScalarArithFn(self.reduce_use.arith_fn.fn_name, *full_args).expr() + return ScalarFn(FunctionKind.ARITH, self.reduce_use.arith_fn.fn_name, None, + None, full_args).expr() def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): for arg in self.args: @@ -259,6 +241,11 @@ def __repr__(self): ############################################################################### +class FunctionKind(Enum): + ARITH = 0 + TYPE = 1 + + class TypeFnType: """Type conversion function. 
@@ -269,8 +256,8 @@ class TypeFnType: def __init__(self, fn_name: str): self.fn_name = fn_name - def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TypeFnType": - return TensorTypeFn(self, None, type_var, arg) + def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg]) def __repr__(self): return f"{self.fn_name}" @@ -301,8 +288,8 @@ class ArithFnType: def __init__(self, fn_name: str): self.fn_name = fn_name - def __call__(self, *args) -> "TensorArithFn": - return TensorArithFn(self, args) + def __call__(self, *args) -> "TensorFn": + return TensorFn(FunctionKind.ARITH, self.fn_name, None, None, args) def __repr__(self): return f"{self.fn_name}" @@ -562,8 +549,8 @@ def __init__(self, default: "TypeFnType"): self.operand_def = OperandDef( OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name) - def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorTypeFn: - return TensorTypeFn(None, self.operand_def, type_var, arg) + def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn: + return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg]) ############################################################################### diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index fc8c13bfe..07050f56f 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -270,19 +270,19 @@ def expression(self, expr: ScalarExpression) -> Value: dim_attr = IntegerAttr.get( IntegerType.get_signless(64), expr.scalar_index.dim) return linalg.IndexOp(dim_attr).result - elif expr.arith_fn: - fn = self._get_function(f"_arithfn_{expr.arith_fn.fn_name}") + elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.ARITH: + fn = self._get_function(f"_arithfn_{expr.scalar_fn.fn_name}") operand_values = [ - 
self.expression(operand) for operand in expr.arith_fn.operands + self.expression(operand) for operand in expr.scalar_fn.operands ] return fn(*operand_values) - elif expr.type_fn: - fn_name = expr.type_fn.fn_name - if expr.type_fn.attr_name: - fn_name = self.type_fn_attr_mapping[expr.type_fn.attr_name] + elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.TYPE: + fn_name = expr.scalar_fn.fn_name + if expr.scalar_fn.attr_name: + fn_name = self.type_fn_attr_mapping[expr.scalar_fn.attr_name] fn = self._get_function(f"_typefn_{fn_name}") - operand = self.expression(expr.type_fn.operand) - return fn(expr.type_fn.type_var.name, operand) + operand_value = self.expression(expr.scalar_fn.operands[0]) + return fn(expr.scalar_fn.type_var.name, operand_value) raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") def yield_outputs(self, *output_names: str): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py index af21b40cf..aa894dc10 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py @@ -15,13 +15,13 @@ from typing import Optional, Sequence -from .yaml_helper import * +from .comprehension import * from .types import * +from .yaml_helper import * __all__ = [ "ScalarAssign", - "ScalarArithFn", - "ScalarTypeFn", + "ScalarFn", "ScalarArg", "ScalarConst", "ScalarIndex", @@ -29,36 +29,27 @@ ] -class ScalarArithFn: - """A type of ScalarExpression that applies an arithmetic function.""" - - def __init__(self, fn_name: str, *operands: "ScalarExpression"): - self.fn_name = fn_name - self.operands = operands - - def expr(self) -> "ScalarExpression": - return ScalarExpression(arith_fn=self) - - def __repr__(self): - return f"ScalarArithFn<{self.fn_name}>({', '.join(self.operands)})" - - -class ScalarTypeFn: - """A type of ScalarExpression that applies a type conversion function.""" +class 
ScalarFn: + """A type of ScalarExpression that applies a function.""" - def __init__(self, fn_name: Optional[str], attr_name: Optional[str], - type_var: TypeVar, operand: "ScalarExpression"): + def __init__(self, kind: "FunctionKind", fn_name: Optional[str], + attr_name: Optional[str], type_var: Optional["TypeVar"], + operands: Sequence["ScalarExpression"]): + if bool(fn_name) + bool(attr_name) != 1: + raise ValueError("One of 'fn_name', 'attr_name' must be specified") + self.kind = kind self.fn_name = fn_name self.attr_name = attr_name self.type_var = type_var - self.operand = operand + self.operands = operands def expr(self) -> "ScalarExpression": - return ScalarExpression(type_fn=self) + return ScalarExpression(scalar_fn=self) def __repr__(self): - return (f"ScalarTypeFn<{self.fn_name}[{self.attr_name}]>" - f"({self.type_var}, {self.operand})") + name = self.fn_name if self.fn_name else self.attr_name + return (f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, " + f"operands=[{', '.join(self.operands)}])") class ScalarArg: @@ -104,51 +95,38 @@ class ScalarExpression(YAMLObject): """An expression on scalar values. 
Can be one of: - - ScalarArithFn - - ScalarTypeFn + - ScalarFn - ScalarArg - ScalarConst - ScalarIndex - - ScalarSymbolicCast """ yaml_tag = "!ScalarExpression" def __init__(self, - arith_fn: Optional[ScalarArithFn] = None, - type_fn: Optional[ScalarTypeFn] = None, + scalar_fn: Optional[ScalarFn] = None, scalar_arg: Optional[ScalarArg] = None, scalar_const: Optional[ScalarConst] = None, scalar_index: Optional[ScalarIndex] = None): - if (bool(arith_fn) + bool(type_fn) + bool(scalar_arg) + bool(scalar_const) + + if (bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) + bool(scalar_index)) != 1: - raise ValueError("One of 'arith_fn', 'type_fn', 'scalar_arg', " - "'scalar_const', 'scalar_index', must be specified") - self.arith_fn = arith_fn - self.type_fn = type_fn + raise ValueError("One of 'scalar_fn', 'scalar_arg', 'scalar_const', or " + "'scalar_index' must be specified") + self.scalar_fn = scalar_fn self.scalar_arg = scalar_arg self.scalar_const = scalar_const self.scalar_index = scalar_index def to_yaml_custom_dict(self): - if self.arith_fn: - return dict( - arith_fn=dict( - fn_name=self.arith_fn.fn_name, - operands=list(self.arith_fn.operands), - )) - if self.type_fn: - # Note that even though operands must be arity 1, we write it the - # same way as for apply because it allows handling code to be more - # generic vs having a special form. 
- type_fn_dict = dict( - type_var=self.type_fn.type_var.name, - operands=[self.type_fn.operand], - ) - if self.type_fn.fn_name: - type_fn_dict["fn_name"] = self.type_fn.fn_name - if self.type_fn.attr_name: - type_fn_dict["attr_name"] = self.type_fn.attr_name - return dict(type_fn=type_fn_dict) + if self.scalar_fn: + scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower()) + if self.scalar_fn.fn_name: + scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name + if self.scalar_fn.attr_name: + scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name + if self.scalar_fn.type_var: + scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name + scalar_fn_dict["operands"] = list(self.scalar_fn.operands) + return dict(scalar_fn=scalar_fn_dict) elif self.scalar_arg: return dict(scalar_arg=self.scalar_arg.arg) elif self.scalar_const: From 018f4d555a693380574dcf6933ea99e83b373279 Mon Sep 17 00:00:00 2001 From: gysit Date: Fri, 25 Feb 2022 15:11:52 +0000 Subject: [PATCH 244/915] [mlir][OpDSL] Split arithmetic functions. Split arithmetic function into unary and binary functions. The revision prepares the introduction of unary and binary function attributes that work similar to type function attributes. 
Depends On D120108 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D120109 --- .../linalg/opdsl/lang/comprehension.py | 143 ++++++++++-------- .../dialects/linalg/opdsl/lang/emitter.py | 58 +++---- .../linalg/opdsl/ops/core_named_ops.py | 2 +- 3 files changed, 115 insertions(+), 88 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index d26aa0770..ef2ef3037 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -77,13 +77,13 @@ def visit_scalar_def(expr: "TensorExpression"): self.visit_tensor_exprs(visit_scalar_def) def __add__(self, rhs: "TensorExpression") -> "TensorExpression": - return ArithFn.add(self, rhs) + return BinaryFn.add(self, rhs) def __mul__(self, rhs) -> "TensorExpression": - return ArithFn.mul(self, rhs) + return BinaryFn.mul(self, rhs) def __sub__(self, rhs) -> "TensorExpression": - return ArithFn.sub(self, rhs) + return BinaryFn.sub(self, rhs) def __hash__(self): return hash(id(self)) @@ -126,7 +126,7 @@ def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]: return rhs_dims - lhs_dims def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn": - return ReduceFnUse(ArithFn.add, *self._compute_reduce_dims(rhs))(rhs) + return ReduceFnUse(BinaryFn.add, *self._compute_reduce_dims(rhs))(rhs) def __repr__(self): return (f"{self.operand_def.name}" @@ -183,8 +183,8 @@ def to_scalar_expression(self) -> ScalarExpression: f"bound to its lhs: {self}") full_args = [self.lhs.to_scalar_expression() ] + [arg.to_scalar_expression() for arg in self.args] - return ScalarFn(FunctionKind.ARITH, self.reduce_use.arith_fn.fn_name, None, - None, full_args).expr() + return ScalarFn(FunctionKind.BINARY, self.reduce_use.binary_fn.fn_name, + None, None, full_args).expr() def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): for 
arg in self.args: @@ -242,61 +242,54 @@ def __repr__(self): class FunctionKind(Enum): - ARITH = 0 - TYPE = 1 + UNARY = 0 + BINARY = 1 + TYPE = 2 -class TypeFnType: - """Type conversion function. +class UnaryFnType: + """Unary function. - A type conversion function takes a target type and a tensor expression and - returns the casted tensor expression. + A unary function takes one tensor expression and returns the + function evaluation result. """ def __init__(self, fn_name: str): self.fn_name = fn_name - def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn": - return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg]) + def __call__(self, exp: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [exp]) def __repr__(self): return f"{self.fn_name}" -class TypeFn: - """Type conversion function namespace. - - As the integer types are signless, signedness is implement by different cast - functions that treat integers as signed (`cast`) or unsigned - (`cast_unsigned`) values. - - Examples: - - cast(I32 -> I64) -> `arith.ExtSIOp` - - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp` - """ - cast = TypeFnType("cast") - cast_unsigned = TypeFnType("cast_unsigned") +class UnaryFn: + """Unary function namespace.""" + exp = UnaryFnType("exp") + log = UnaryFnType("log") -class ArithFnType: - """Arithmetic function. +class BinaryFnType: + """Binary function. - An arithmetic function takes one ore more tensor expressions and returns the + A binary function takes two tensor expressions and returns the function evaluation result. 
""" def __init__(self, fn_name: str): self.fn_name = fn_name - def __call__(self, *args) -> "TensorFn": - return TensorFn(FunctionKind.ARITH, self.fn_name, None, None, args) + def __call__(self, arg0: TensorExpression, + arg1: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.BINARY, self.fn_name, None, None, [arg0, arg1]) def __repr__(self): return f"{self.fn_name}" -class ArithFn: - """Arithmetic function namespace. +class BinaryFn: + """Binary function namespace. As the integer types are signless, signedness is implement by different functions that treat integers as signed or unsigned values. @@ -305,15 +298,45 @@ class ArithFn: - max -> `arith.MaxSIOp` - max_unsinged -> `arith.MaxUIOp` """ - add = ArithFnType("add") - exp = ArithFnType("exp") - log = ArithFnType("log") - mul = ArithFnType("mul") - max = ArithFnType("max") - min = ArithFnType("min") - sub = ArithFnType("sub") - max_unsigned = ArithFnType("max_unsigned") - min_unsigned = ArithFnType("min_unsigned") + add = BinaryFnType("add") + mul = BinaryFnType("mul") + max = BinaryFnType("max") + min = BinaryFnType("min") + sub = BinaryFnType("sub") + max_unsigned = BinaryFnType("max_unsigned") + min_unsigned = BinaryFnType("min_unsigned") + + +class TypeFnType: + """Type conversion function. + + A type conversion function takes a target type and a tensor expression and + returns the casted tensor expression. + """ + + def __init__(self, fn_name: str): + self.fn_name = fn_name + + def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg]) + + def __repr__(self): + return f"{self.fn_name}" + + +class TypeFn: + """Type conversion function namespace. + + As the integer types are signless, signedness is implement by different cast + functions that treat integers as signed (`cast`) or unsigned + (`cast_unsigned`) values. 
+ + Examples: + - cast(I32 -> I64) -> `arith.ExtSIOp` + - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp` + """ + cast = TypeFnType("cast") + cast_unsigned = TypeFnType("cast_unsigned") class ReduceFnUse: @@ -322,43 +345,43 @@ class ReduceFnUse: A reduction use specifies the reduction function and dimensions. """ - def __init__(self, arith_fn: ArithFnType, *reduce_dims: DimDef): - self.arith_fn = arith_fn + def __init__(self, binary_fn: BinaryFnType, *reduce_dims: DimDef): + self.binary_fn = binary_fn self.reduce_dims = reduce_dims - def __call__(self, *args: TensorExpression): + def __call__(self, *args: TensorExpression) -> "TensorReduceFn": return TensorReduceFn(self, args) def __repr__(self): - return (f"reduce_{self.arith_fn.fn_name}" + return (f"reduce_{self.binary_fn.fn_name}" f"({', '.join(repr(d) for d in self.reduce_dims)})") class ReduceFnType: """Reduction function. - An arithmetic function that reduces its RHS into its LHS. + A binary function that reduces its RHS into its LHS. 
""" - def __init__(self, arith_fn: ArithFnType): - if not isinstance(arith_fn, ArithFnType): - raise ValueError(f"Reduce expected a ArithFnType but got {arith_fn}") - self.arith_fn = arith_fn + def __init__(self, binary_fn: BinaryFnType): + if not isinstance(binary_fn, BinaryFnType): + raise ValueError(f"Reduce expected a BinaryFnType but got {binary_fn}") + self.binary_fn = binary_fn def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: - return ReduceFnUse(self.arith_fn, *reduce_dims) + return ReduceFnUse(self.binary_fn, *reduce_dims) def __repr__(self): - return (f"reduce_{self.arith_fn.fn_name}") + return (f"reduce_{self.binary_fn.fn_name}") class ReduceFn: - add = ReduceFnType(ArithFn.add) - mul = ReduceFnType(ArithFn.mul) - max = ReduceFnType(ArithFn.max) - min = ReduceFnType(ArithFn.min) - max_unsigned = ReduceFnType(ArithFn.max_unsigned) - min_unsigned = ReduceFnType(ArithFn.min_unsigned) + add = ReduceFnType(BinaryFn.add) + mul = ReduceFnType(BinaryFn.mul) + max = ReduceFnType(BinaryFn.max) + min = ReduceFnType(BinaryFn.min) + max_unsigned = ReduceFnType(BinaryFn.max_unsigned) + min_unsigned = ReduceFnType(BinaryFn.min_unsigned) ############################################################################### diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 07050f56f..df4ab2249 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -270,17 +270,19 @@ def expression(self, expr: ScalarExpression) -> Value: dim_attr = IntegerAttr.get( IntegerType.get_signless(64), expr.scalar_index.dim) return linalg.IndexOp(dim_attr).result - elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.ARITH: - fn = self._get_function(f"_arithfn_{expr.scalar_fn.fn_name}") + elif expr.scalar_fn and expr.scalar_fn.kind is not FunctionKind.TYPE: + kind = expr.scalar_fn.kind.name.lower() + fn = 
self._get_function(f"_{kind}_{expr.scalar_fn.fn_name}") operand_values = [ self.expression(operand) for operand in expr.scalar_fn.operands ] return fn(*operand_values) - elif expr.scalar_fn and expr.scalar_fn.kind == FunctionKind.TYPE: + elif expr.scalar_fn and expr.scalar_fn.kind is FunctionKind.TYPE: + kind = expr.scalar_fn.kind.name.lower() fn_name = expr.scalar_fn.fn_name if expr.scalar_fn.attr_name: fn_name = self.type_fn_attr_mapping[expr.scalar_fn.attr_name] - fn = self._get_function(f"_typefn_{fn_name}") + fn = self._get_function(f"_{kind}_{fn_name}") operand_value = self.expression(expr.scalar_fn.operands[0]) return fn(expr.scalar_fn.type_var.name, operand_value) raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") @@ -356,70 +358,72 @@ def _cast_to_floating_point(self, to_type: Type, operand: Value, raise ValueError(f"Unable to cast body expression from {operand_type} to " f"{to_type}") - def _typefn_cast(self, type_var_name: str, operand: Value) -> Value: + def _type_cast(self, type_var_name: str, operand: Value) -> Value: return self._cast(type_var_name, operand, False) - def _typefn_cast_unsigned(self, type_var_name: str, operand: Value) -> Value: + def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value: return self._cast(type_var_name, operand, True) - def _arithfn_add(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): - return arith.AddFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.AddIOp(lhs, rhs).result - raise NotImplementedError("Unsupported 'add' operand: {lhs}") - - def _arithfn_exp(self, x: Value) -> Value: + def _unary_exp(self, x: Value) -> Value: if _is_floating_point_type(x.type): return math.ExpOp(x).result raise NotImplementedError("Unsupported 'exp' operand: {x}") - def _arithfn_log(self, x: Value) -> Value: + def _unary_log(self, x: Value) -> Value: if _is_floating_point_type(x.type): return math.LogOp(x).result 
raise NotImplementedError("Unsupported 'log' operand: {x}") - def _arithfn_sub(self, lhs: Value, rhs: Value) -> Value: + def _binary_add(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.AddFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.AddIOp(lhs, rhs).result + raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}") + + def _binary_sub(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.SubFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.SubIOp(lhs, rhs).result - raise NotImplementedError("Unsupported 'sub' operand: {lhs}") + raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}") - def _arithfn_mul(self, lhs: Value, rhs: Value) -> Value: + def _binary_mul(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MulFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MulIOp(lhs, rhs).result - raise NotImplementedError("Unsupported 'mul' operand: {lhs}") + raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}") - def _arithfn_max(self, lhs: Value, rhs: Value) -> Value: + def _binary_max(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MaxFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MaxSIOp(lhs, rhs).result - raise NotImplementedError("Unsupported 'max' operand: {lhs}") + raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}") - def _arithfn_max_unsigned(self, lhs: Value, rhs: Value) -> Value: + def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MaxFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MaxUIOp(lhs, rhs).result - raise NotImplementedError("Unsupported 
'max_unsigned' operand: {lhs}") + raise NotImplementedError( + "Unsupported 'max_unsigned' operands: {lhs}, {rhs}") - def _arithfn_min(self, lhs: Value, rhs: Value) -> Value: + def _binary_min(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MinFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MinSIOp(lhs, rhs).result - raise NotImplementedError("Unsupported 'min' operand: {lhs}") + raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}") - def _arithfn_min_unsigned(self, lhs: Value, rhs: Value) -> Value: + def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MinFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MinUIOp(lhs, rhs).result - raise NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}") + raise NotImplementedError( + "Unsupported 'min_unsigned' operands: {lhs}, {rhs}") def _infer_structured_outs( diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index db63a0705..340f4db44 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -677,4 +677,4 @@ def soft_plus_2d( """ domain(D.m, D.n) O[D.m, D.n] = \ - ArithFn.log(TypeFn.cast(U, const(1.0)) + ArithFn.exp(TypeFn.cast(U, I[D.m, D.n]))) + UnaryFn.log(TypeFn.cast(U, const(1.0)) + UnaryFn.exp(TypeFn.cast(U, I[D.m, D.n]))) From 71102530a2170d5c775bf8e688006d2081019103 Mon Sep 17 00:00:00 2001 From: gysit Date: Tue, 1 Mar 2022 07:40:06 +0000 Subject: [PATCH 245/915] [mlir][OpDSL] Add arithmetic function attributes. The revision extends OpDSL with unary and binary function attributes. A function attribute, makes the operations used in the body of a structured operation configurable. 
For example, a pooling operation may take an aggregation function attribute that specifies if the op shall implement a min or a max pooling. The goal of this revision is to define less and more flexible operations. We may thus for example define an element wise op: ``` linalg.elem(lhs, rhs, outs=[out], op=BinaryFn.mul) ``` If the op argument is not set the default operation is used. Depends On D120109 Reviewed By: nicolasvasilache, aartbik Differential Revision: https://reviews.llvm.org/D120110 --- .../linalg/opdsl/lang/comprehension.py | 80 +++++++++++++++--- .../mlir/dialects/linalg/opdsl/lang/config.py | 4 +- .../mlir/dialects/linalg/opdsl/lang/dsl.py | 3 +- .../dialects/linalg/opdsl/lang/emitter.py | 82 +++++++++++-------- .../linalg/opdsl/ops/core_named_ops.py | 29 +++++++ 5 files changed, 149 insertions(+), 49 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index ef2ef3037..f6bf0ff9a 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -126,7 +126,7 @@ def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]: return rhs_dims - lhs_dims def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn": - return ReduceFnUse(BinaryFn.add, *self._compute_reduce_dims(rhs))(rhs) + return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs) def __repr__(self): return (f"{self.operand_def.name}" @@ -183,8 +183,14 @@ def to_scalar_expression(self) -> ScalarExpression: f"bound to its lhs: {self}") full_args = [self.lhs.to_scalar_expression() ] + [arg.to_scalar_expression() for arg in self.args] - return ScalarFn(FunctionKind.BINARY, self.reduce_use.binary_fn.fn_name, - None, None, full_args).expr() + fn_name = None + attr_name = None + if self.reduce_use.binary_fn: + fn_name = self.reduce_use.binary_fn.fn_name + if self.reduce_use.binary_attr: + attr_name 
= self.reduce_use.binary_attr.operand_def.name + return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None, + full_args).expr() def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): for arg in self.args: @@ -257,8 +263,8 @@ class UnaryFnType: def __init__(self, fn_name: str): self.fn_name = fn_name - def __call__(self, exp: TensorExpression) -> "TensorFn": - return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [exp]) + def __call__(self, arg: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg]) def __repr__(self): return f"{self.fn_name}" @@ -345,16 +351,21 @@ class ReduceFnUse: A reduction use specifies the reduction function and dimensions. """ - def __init__(self, binary_fn: BinaryFnType, *reduce_dims: DimDef): + def __init__(self, binary_fn: Optional[BinaryFnType], + binary_attr: Optional["BinaryFnAttrDef"], *reduce_dims: DimDef): + if bool(binary_fn) + bool(binary_attr) != 1: + raise ValueError("One of 'binary_fn', 'binary_attr' must be specified") self.binary_fn = binary_fn + self.binary_attr = binary_attr self.reduce_dims = reduce_dims def __call__(self, *args: TensorExpression) -> "TensorReduceFn": return TensorReduceFn(self, args) def __repr__(self): - return (f"reduce_{self.binary_fn.fn_name}" - f"({', '.join(repr(d) for d in self.reduce_dims)})") + fn = self.binary_fn if self.binary_fn else self.binary_attr + return ( + f"reduce_{repr(fn)}({', '.join(repr(d) for d in self.reduce_dims)})") class ReduceFnType: @@ -369,10 +380,10 @@ def __init__(self, binary_fn: BinaryFnType): self.binary_fn = binary_fn def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: - return ReduceFnUse(self.binary_fn, *reduce_dims) + return ReduceFnUse(self.binary_fn, None, *reduce_dims) def __repr__(self): - return (f"reduce_{self.binary_fn.fn_name}") + return f"reduce_{repr(self.binary_fn)}" class ReduceFn: @@ -394,7 +405,9 @@ class OperandKind(Enum): SCALAR = 1 OUTPUT_TENSOR = 2 
INDEX_ATTR = 3 - TYPE_FN_ATTR = 4 + UNARY_FN_ATTR = 4 + BINARY_FN_ATTR = 5 + TYPE_FN_ATTR = 6 class OperandDef: @@ -441,6 +454,8 @@ def is_tensor(self) -> bool: def is_attribute(self) -> bool: return (self.kind == OperandKind.INDEX_ATTR or + self.kind == OperandKind.UNARY_FN_ATTR or + self.kind == OperandKind.BINARY_FN_ATTR or self.kind == OperandKind.TYPE_FN_ATTR) def __hash__(self): @@ -557,6 +572,49 @@ def __init__(self, *sizes: SymbolDef, default: Sequence[int]): OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default) +class UnaryFnAttrDef: + """Unary function attribute definition. + + Unary function attributes provide a way to make the arithmetic computation + parametrizable. Every attribute specifies a default unary function + that may be overwritten at operation instantiation time. + """ + + def __init__(self, default: "UnaryFnType"): + if not isinstance(default, UnaryFnType): + raise ValueError(f"UnaryFnAttrDef requires default of type UnaryFnType " + f"but got {default}") + self.operand_def = OperandDef( + OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name) + + def __call__(self, arg: TensorExpression) -> TensorFn: + return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg]) + + +class BinaryFnAttrDef: + """Binary function attribute definition. + + Binary function attributes provide a way to make the arithmetic computation + parametrizable. Every attribute specifies a default binary function + that may be overwritten at operation instantiation time. 
+ """ + + def __init__(self, default: "BinaryFnType"): + if not isinstance(default, BinaryFnType): + raise ValueError(f"BinaryFnAttrDef requires default of type BinaryFnType " + f"but got {default}") + self.operand_def = OperandDef( + OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name) + + def __call__(self, arg0: TensorExpression, + arg1: TensorExpression) -> TensorFn: + return TensorFn(FunctionKind.BINARY, None, self.operand_def, None, + [arg0, arg1]) + + def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: + return ReduceFnUse(None, self, *reduce_dims) + + class TypeFnAttrDef: """Type conversion function attribute definition. diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py index 12b168de1..ed30b8e5f 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -309,8 +309,8 @@ def get_type(symbolic_name, position): def add_operand(self, operand_def: OperandDef): if operand_def in self.operands: return - if (operand_def.kind == OperandKind.SCALAR or - operand_def.kind == OperandKind.TYPE_FN_ATTR): + if not (operand_def.is_tensor() or + operand_def.kind == OperandKind.INDEX_ATTR): self.operands[operand_def] = OperandDefConfig(operand_def) return with self.context: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py index 99ce71366..bd9042ac0 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -130,7 +130,8 @@ def linalg_structured_op(dsl_func=None, for param_name, param in sig.parameters.items(): param_default = param.default if isinstance(param_default, - (TensorDef, ScalarDef, IndexAttrDef, TypeFnAttrDef)): + (TensorDef, ScalarDef, IndexAttrDef, UnaryFnAttrDef, + BinaryFnAttrDef, TypeFnAttrDef)): op_def.add_operand(param_name, param_default.operand_def) else: raise 
ValueError( diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index df4ab2249..79fc3f5a2 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -41,7 +41,7 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, all_arg_defs = op_config.ordered_operands in_arg_defs = [ d for d in all_arg_defs - if d.kind == OperandKind.SCALAR or d.kind == OperandKind.INPUT_TENSOR + if d.kind in [OperandKind.SCALAR, OperandKind.INPUT_TENSOR] ] out_arg_defs = [ d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR @@ -49,8 +49,11 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, index_attr_arg_defs = [ d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR ] - type_fn_attr_arg_defs = [ - d for d in all_arg_defs if d.kind == OperandKind.TYPE_FN_ATTR + fn_attr_arg_defs = [ + d for d in all_arg_defs if d.kind in [ + OperandKind.UNARY_FN_ATTR, OperandKind.BINARY_FN_ATTR, + OperandKind.TYPE_FN_ATTR + ] ] # Verify outs is a sequence or a list of results. @@ -135,28 +138,38 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, array = np.array(index_attr_vals, dtype=np.int64) index_attrs[index_attr.name] = DenseElementsAttr.get(array) - # Compute the type function attribute mapping. - type_fn_attr_mapping = {} - for type_fn_attr in type_fn_attr_arg_defs: - attr_val = type_fn_attr.operand_def.default_fn - if type_fn_attr.name in attrs: - type_fn = attrs.get(type_fn_attr.name) - if not isinstance(type_fn, TypeFnType): - raise ValueError(f"Attribute {type_fn_attr.name} needs to be of type " - f"TypeFnType but got {type(attr_val)}") - attr_val = type_fn.fn_name - assert attr_val, "Type function attribute has no value" - type_fn_attr_mapping[type_fn_attr.name] = attr_val + # Compute the function attribute mapping. 
+ fn_attr_mapping = {} + for fn_attr in fn_attr_arg_defs: + attr_val = fn_attr.operand_def.default_fn + attr_kind = fn_attr.kind + if fn_attr.name in attrs: + fn = attrs.get(fn_attr.name) + if attr_kind == OperandKind.UNARY_FN_ATTR: + if not isinstance(fn, UnaryFnType): + raise ValueError(f"Attribute {fn_attr.name} needs to be of type " + f"UnaryFnType but got {type(attr_val)}") + elif attr_kind == OperandKind.BINARY_FN_ATTR: + if not isinstance(fn, BinaryFnType): + raise ValueError(f"Attribute {fn_attr.name} needs to be of type " + f"BinaryFnType but got {type(attr_val)}") + else: + if not isinstance(fn, TypeFnType): + raise ValueError(f"Attribute {fn_attr.name} needs to be of type " + f"TypeFnType but got {type(attr_val)}") + attr_val = fn.fn_name + assert attr_val, "Function attribute has no value" + fn_attr_mapping[fn_attr.name] = (attr_val, attr_kind) return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, indexing_maps_attr, iterator_types_attr, index_attrs, - type_fn_attr_mapping, block_arg_types) + fn_attr_mapping, block_arg_types) def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, outs: ValueList, **attrs: Sequence[int]): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \ + indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \ block_arg_types = \ prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) @@ -193,7 +206,7 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, block_arg_mapping = dict(zip(block_arg_names, block.arguments)) with InsertionPoint(block): body_builder = _BodyBuilder(type_mapping, block_arg_mapping, - type_fn_attr_mapping) + fn_attr_mapping) for assignment in op_config.assignments: body_builder.assign(assignment) body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs)) @@ -208,7 +221,7 @@ def 
emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str, op_class_name: str, *ins: Value, outs: ValueList, **attrs: Sequence[int]): all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - indexing_maps_attr, iterator_types_attr, index_attrs, type_fn_attr_mapping, \ + indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \ block_arg_types = \ prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) @@ -225,10 +238,12 @@ def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str, for name, value in index_attrs.items(): named_op.operation.attributes[name] = value - # Set the type function attributes. - for name, value in type_fn_attr_mapping.items(): + # Compute the function attributes by combining operand kind and function name. + for name, (fn_name, kind) in fn_attr_mapping.items(): + assert kind.name.lower().endswith("_attr") + enum_name = kind.name.lower()[:-5] named_op.operation.attributes[name] = Attribute.parse( - f"#linalg.type_fn<{value}>") + f"#linalg.{enum_name}<{fn_name}>") linalg.fill_builtin_region(named_op.operation) @@ -242,11 +257,11 @@ class _BodyBuilder: """Constructs a structured op body by evaluating assignments.""" def __init__(self, type_mapping: Dict[str, Type], - block_arg_mapping: Dict[str, Value], - type_fn_attr_mapping: Dict[str, str]): + block_arg_mapping: Dict[str, Value], fn_attr_mapping: Dict[str, + str]): self.type_mapping = type_mapping self.block_arg_mapping = block_arg_mapping - self.type_fn_attr_mapping = type_fn_attr_mapping + self.fn_attr_mapping = fn_attr_mapping self.yield_mapping = dict() # type: Dict[str, Value] def assign(self, assignment: ScalarAssign): @@ -270,21 +285,18 @@ def expression(self, expr: ScalarExpression) -> Value: dim_attr = IntegerAttr.get( IntegerType.get_signless(64), expr.scalar_index.dim) return linalg.IndexOp(dim_attr).result - elif expr.scalar_fn and expr.scalar_fn.kind is not FunctionKind.TYPE: + elif 
expr.scalar_fn: kind = expr.scalar_fn.kind.name.lower() - fn = self._get_function(f"_{kind}_{expr.scalar_fn.fn_name}") + fn_name = expr.scalar_fn.fn_name + if expr.scalar_fn.attr_name: + fn_name, _ = self.fn_attr_mapping[expr.scalar_fn.attr_name] + fn = self._get_function(f"_{kind}_{fn_name}") operand_values = [ self.expression(operand) for operand in expr.scalar_fn.operands ] + if expr.scalar_fn.kind == FunctionKind.TYPE: + operand_values = [expr.scalar_fn.type_var.name] + operand_values return fn(*operand_values) - elif expr.scalar_fn and expr.scalar_fn.kind is FunctionKind.TYPE: - kind = expr.scalar_fn.kind.name.lower() - fn_name = expr.scalar_fn.fn_name - if expr.scalar_fn.attr_name: - fn_name = self.type_fn_attr_mapping[expr.scalar_fn.attr_name] - fn = self._get_function(f"_{kind}_{fn_name}") - operand_value = self.expression(expr.scalar_fn.operands[0]) - return fn(expr.scalar_fn.type_var.name, operand_value) raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") def yield_outputs(self, *output_names: str): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 340f4db44..b7a827bf6 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -6,6 +6,35 @@ Batch = S.Batch +@linalg_structured_op +def elemwise_unary( + I=TensorDef(T1), + O=TensorDef(U, output=True), + fun=UnaryFnAttrDef(default=UnaryFn.exp), + cast=TypeFnAttrDef(default=TypeFn.cast)): + """Applies the unary function fun elementwise. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + O[None] = fun(cast(U, I[None])) + + +@linalg_structured_op +def elemwise_binary( + lhs=TensorDef(T1), + rhs=TensorDef(T2), + O=TensorDef(U, output=True), + fun=BinaryFnAttrDef(default=BinaryFn.add), + cast=TypeFnAttrDef(default=TypeFn.cast)): + """Applies the binary function fun elementwise. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None])) + + @linalg_structured_op def matmul( A=TensorDef(T1, S.M, S.K), From 6dc59456ae9cc67ca8129adfab52e95c3284ecc6 Mon Sep 17 00:00:00 2001 From: gysit Date: Tue, 1 Mar 2022 08:10:51 +0000 Subject: [PATCH 246/915] [mlir][OpDSL] Rename function to make signedness explicit (NFC). The revision renames the following OpDSL functions: ``` TypeFn.cast -> TypeFn.cast_signed BinaryFn.min -> BinaryFn.min_signed BinaryFn.max -> BinaryFn.max_signed ``` The corresponding enum values on the C++ side are renamed accordingly: ``` #linalg.type_fn -> #linalg.type_fn #linalg.binary_fn -> #linalg.binary_fn #linalg.binary_fn -> #linalg.binary_fn ``` Depends On D120110 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D120562 --- .../linalg/opdsl/lang/comprehension.py | 16 +- .../dialects/linalg/opdsl/lang/emitter.py | 6 +- .../linalg/opdsl/ops/core_named_ops.py | 144 +++++++++--------- 3 files changed, 86 insertions(+), 80 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index f6bf0ff9a..7de0a76e8 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -305,10 +305,10 @@ class BinaryFn: - max_unsinged -> `arith.MaxUIOp` """ add = BinaryFnType("add") - mul = BinaryFnType("mul") - max = BinaryFnType("max") - min = BinaryFnType("min") sub = BinaryFnType("sub") + mul = BinaryFnType("mul") + max_signed = 
BinaryFnType("max_signed") + min_signed = BinaryFnType("min_signed") max_unsigned = BinaryFnType("max_unsigned") min_unsigned = BinaryFnType("min_unsigned") @@ -334,14 +334,14 @@ class TypeFn: """Type conversion function namespace. As the integer types are signless, signedness is implement by different cast - functions that treat integers as signed (`cast`) or unsigned + functions that treat integers as signed (`cast_signed`) or unsigned (`cast_unsigned`) values. Examples: - - cast(I32 -> I64) -> `arith.ExtSIOp` + - cast_signed(I32 -> I64) -> `arith.ExtSIOp` - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp` """ - cast = TypeFnType("cast") + cast_signed = TypeFnType("cast_signed") cast_unsigned = TypeFnType("cast_unsigned") @@ -389,8 +389,8 @@ def __repr__(self): class ReduceFn: add = ReduceFnType(BinaryFn.add) mul = ReduceFnType(BinaryFn.mul) - max = ReduceFnType(BinaryFn.max) - min = ReduceFnType(BinaryFn.min) + max_signed = ReduceFnType(BinaryFn.max_signed) + min_signed = ReduceFnType(BinaryFn.min_signed) max_unsigned = ReduceFnType(BinaryFn.max_unsigned) min_unsigned = ReduceFnType(BinaryFn.min_unsigned) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 79fc3f5a2..453a3e80c 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -370,7 +370,7 @@ def _cast_to_floating_point(self, to_type: Type, operand: Value, raise ValueError(f"Unable to cast body expression from {operand_type} to " f"{to_type}") - def _type_cast(self, type_var_name: str, operand: Value) -> Value: + def _type_cast_signed(self, type_var_name: str, operand: Value) -> Value: return self._cast(type_var_name, operand, False) def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value: @@ -407,7 +407,7 @@ def _binary_mul(self, lhs: Value, rhs: Value) -> Value: return arith.MulIOp(lhs, rhs).result raise NotImplementedError("Unsupported 
'mul' operands: {lhs}, {rhs}") - def _binary_max(self, lhs: Value, rhs: Value) -> Value: + def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MaxFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): @@ -422,7 +422,7 @@ def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value: raise NotImplementedError( "Unsupported 'max_unsigned' operands: {lhs}, {rhs}") - def _binary_min(self, lhs: Value, rhs: Value) -> Value: + def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): return arith.MinFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index b7a827bf6..0ef40613a 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -11,7 +11,7 @@ def elemwise_unary( I=TensorDef(T1), O=TensorDef(U, output=True), fun=UnaryFnAttrDef(default=UnaryFn.exp), - cast=TypeFnAttrDef(default=TypeFn.cast)): + cast=TypeFnAttrDef(default=TypeFn.cast_signed)): """Applies the unary function fun elementwise. Numeric casting is performed on the input operand, promoting it to the same @@ -26,7 +26,7 @@ def elemwise_binary( rhs=TensorDef(T2), O=TensorDef(U, output=True), fun=BinaryFnAttrDef(default=BinaryFn.add), - cast=TypeFnAttrDef(default=TypeFn.cast)): + cast=TypeFnAttrDef(default=TypeFn.cast_signed)): """Applies the binary function fun elementwise. Numeric casting is performed on the input operand, promoting it to the same @@ -40,7 +40,7 @@ def matmul( A=TensorDef(T1, S.M, S.K), B=TensorDef(T2, S.K, S.N), C=TensorDef(U, S.M, S.N, output=True), - cast=TypeFnAttrDef(default=TypeFn.cast)): + cast=TypeFnAttrDef(default=TypeFn.cast_signed)): """Performs a matrix multiplication of two 2D inputs. 
Numeric casting is performed on the operands to the inner multiply, promoting @@ -82,8 +82,9 @@ def quantized_matmul( matmul. """ domain(D.m, D.n, D.k) - C[D.m, D.n] += (TypeFn.cast(U, A[D.m, D.k]) - TypeFn.cast(U, AZp)) * ( - TypeFn.cast(U, B[D.k, D.n]) - TypeFn.cast(U, BZp)) + C[D.m, D.n] += ( + TypeFn.cast_signed(U, A[D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * ( + TypeFn.cast_signed(U, B[D.k, D.n]) - TypeFn.cast_signed(U, BZp)) @linalg_structured_op @@ -103,8 +104,8 @@ def mmt4d( """ domain(D.m, D.n, D.k, D.m0, D.n0, D.k0) implements(ContractionOpInterface) - accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast( - TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * TypeFn.cast( + accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed( + TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * TypeFn.cast_signed( TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) @@ -121,7 +122,8 @@ def batch_matmul( domain(D.b, D.m, D.n, D.k) implements(ContractionOpInterface) C[D.b, D.m, - D.n] += TypeFn.cast(U, A[D.b, D.m, D.k]) * TypeFn.cast(U, B[D.b, D.k, D.n]) + D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( + U, B[D.b, D.k, D.n]) @linalg_structured_op @@ -139,9 +141,9 @@ def quantized_batch_matmul( matmul. 
""" domain(D.b, D.m, D.n, D.k) - C[D.b, D.m, - D.n] += (TypeFn.cast(U, A[D.b, D.m, D.k]) - TypeFn.cast(U, AZp)) * ( - TypeFn.cast(U, B[D.b, D.k, D.n]) - TypeFn.cast(U, BZp)) + C[D.b, D.m, D.n] += ( + TypeFn.cast_signed(U, A[D.b, D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * ( + TypeFn.cast_signed(U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp)) @linalg_structured_op @@ -156,7 +158,7 @@ def matvec( """ domain(D.m, D.n) implements(ContractionOpInterface) - x[D.m] += TypeFn.cast(U, A[D.m, D.n]) * TypeFn.cast(U, y[D.n]) + x[D.m] += TypeFn.cast_signed(U, A[D.m, D.n]) * TypeFn.cast_signed(U, y[D.n]) @linalg_structured_op @@ -171,7 +173,7 @@ def vecmat( """ domain(D.n, D.m) implements(ContractionOpInterface) - x[D.n] += TypeFn.cast(U, y[D.m]) * TypeFn.cast(U, A[D.m, D.n]) + x[D.n] += TypeFn.cast_signed(U, y[D.m]) * TypeFn.cast_signed(U, A[D.m, D.n]) @linalg_structured_op @@ -186,7 +188,8 @@ def batch_matvec( """ domain(D.b, D.m, D.k) implements(ContractionOpInterface) - C[D.b, D.m] += TypeFn.cast(U, A[D.b, D.m, D.k]) * TypeFn.cast(U, B[D.b, D.k]) + C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( + U, B[D.b, D.k]) @linalg_structured_op @@ -198,7 +201,7 @@ def dot( them to the same data type as the accumulator/output. 
""" implements(ContractionOpInterface) - C[None] += TypeFn.cast(U, A[D.m]) * TypeFn.cast(U, B[D.m]) + C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m]) @linalg_structured_op @@ -213,7 +216,8 @@ def conv_1d( """ implements(ConvolutionOpInterface) domain(D.ow, D.kw) - O[D.ow] += TypeFn.cast(U, I[D.ow + D.kw]) * TypeFn.cast(U, K[D.kw]) + O[D.ow] += TypeFn.cast_signed(U, I[D.ow + D.kw]) * TypeFn.cast_signed( + U, K[D.kw]) @linalg_structured_op @@ -228,8 +232,8 @@ def conv_2d( """ implements(ConvolutionOpInterface) domain(D.oh, D.ow, D.kh, D.kw) - O[D.oh, D.ow] += TypeFn.cast(U, I[D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast( - U, K[D.kh, D.kw]) + O[D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed(U, K[D.kh, D.kw]) @linalg_structured_op @@ -244,9 +248,9 @@ def conv_3d( """ implements(ConvolutionOpInterface) domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw) - O[D.od, D.oh, - D.ow] += TypeFn.cast(U, I[D.od + D.kd, D.oh + D.kh, D.ow + - D.kw]) * TypeFn.cast(U, K[D.kd, D.kh, D.kw]) + O[D.od, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed( + U, K[D.kd, D.kh, D.kw]) @linalg_structured_op @@ -264,8 +268,8 @@ def conv_1d_nwc_wcf( implements(ConvolutionOpInterface) domain(D.n, D.ow, D.f, D.kw, D.c) O[D.n, D.ow, - D.f] += TypeFn.cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW, - D.c]) * TypeFn.cast(U, K[D.kw, D.c, D.f]) + D.f] += TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, + D.c]) * TypeFn.cast_signed(U, K[D.kw, D.c, D.f]) @linalg_structured_op @@ -287,9 +291,9 @@ def conv_2d_nhwc_hwcf( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.f] += TypeFn.cast( + O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.c]) * TypeFn.cast(U, K[D.kh, D.kw, D.c, D.f]) + D.c]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) @linalg_structured_op @@ -315,10 +319,11 
@@ def conv_2d_nhwc_hwcf_q( implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) O[D.n, D.oh, D.ow, - D.f] += (TypeFn.cast( + D.f] += (TypeFn.cast_signed( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) - - TypeFn.cast(U, IZp)) * ( - TypeFn.cast(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast(U, KZp)) + TypeFn.cast_signed(U, IZp)) * ( + TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) - + TypeFn.cast_signed(U, KZp)) @linalg_structured_op @@ -340,9 +345,9 @@ def conv_2d_nchw_fchw( """ implements(ConvolutionOpInterface) domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.f, D.oh, D.ow] += TypeFn.cast( - U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast(U, K[D.f, D.c, D.kh, D.kw]) + O[D.n, D.f, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + + D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) @linalg_structured_op @@ -360,9 +365,9 @@ def conv_3d_ndhwc_dhwcf( """ implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) - O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast( + O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast( + D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed( U, K[D.kd, D.kh, D.kw, D.c, D.f]) @@ -382,8 +387,8 @@ def depthwise_conv_1d_nwc_wc( implements(ConvolutionOpInterface) domain(D.n, D.ow, D.ic, D.kw) O[D.n, D.ow, D.ic] += \ - TypeFn.cast(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \ - TypeFn.cast(U, K[D.kw, D.ic]) + TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \ + TypeFn.cast_signed(U, K[D.kw, D.ic]) @linalg_structured_op @@ -402,9 +407,9 @@ def depthwise_conv_2d_nhwc_hwc( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast( + O[D.n, D.oh, D.ow, D.ic] += 
TypeFn.cast_signed( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.ic]) * TypeFn.cast(U, K[D.kh, D.kw, D.ic]) + D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) @linalg_structured_op @@ -424,11 +429,11 @@ def depthwise_conv_2d_nhwc_hwc_q( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) - O[D.n, D.oh, D.ow, - D.ic] += ((TypeFn.cast( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) - - TypeFn.cast(U, IZp)) * - (TypeFn.cast(U, K[D.kh, D.kw, D.ic]) - TypeFn.cast(U, KZp))) + O[D.n, D.oh, D.ow, D.ic] += ((TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) - + TypeFn.cast_signed(U, IZp)) * + (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) - + TypeFn.cast_signed(U, KZp))) @linalg_structured_op @@ -446,9 +451,9 @@ def depthwise_conv_2d_nhwc_hwcm( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast( + O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.ic]) * TypeFn.cast(U, K[D.kh, D.kw, D.ic, D.cm]) + D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) @linalg_structured_op @@ -469,10 +474,11 @@ def depthwise_conv_2d_nhwc_hwcm_q( implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) O[D.n, D.oh, D.ow, D.ic, - D.cm] += ((TypeFn.cast( + D.cm] += ((TypeFn.cast_signed( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) - - TypeFn.cast(U, IZp)) * - (TypeFn.cast(U, K[D.kh, D.kw, D.ic, D.cm]) - TypeFn.cast(U, KZp))) + TypeFn.cast_signed(U, IZp)) * + (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) - + TypeFn.cast_signed(U, KZp))) @linalg_structured_op @@ -490,7 +496,7 @@ def pooling_nhwc_sum( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.c] += TypeFn.cast( + O[D.n, D.oh, D.ow, D.c] += 
TypeFn.cast_signed( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) @@ -509,8 +515,8 @@ def pooling_nhwc_max( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw]( - TypeFn.cast( + O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw]( + TypeFn.cast_signed( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -549,8 +555,8 @@ def pooling_nchw_max( """ implements(ConvolutionOpInterface) domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) - O[D.n, D.c, D.oh, D.ow] = ReduceFn.max[D.kh, D.kw]( - TypeFn.cast( + O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw]( + TypeFn.cast_signed( U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,])) @@ -570,8 +576,8 @@ def pooling_nhwc_min( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.c] = ReduceFn.min[D.kh, D.kw]( - TypeFn.cast( + O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw]( + TypeFn.cast_signed( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -610,7 +616,7 @@ def pooling_ndhwc_sum( """ implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) - O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast( + O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast_signed( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) @@ -630,8 +636,8 @@ def pooling_ndhwc_max( """ implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) - O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max[D.kd, D.kh, D.kw]( - TypeFn.cast( + O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw]( + TypeFn.cast_signed( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -651,8 +657,8 @@ def pooling_ndhwc_min( """ implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, 
D.kw) - O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min[D.kd, D.kh, D.kw]( - TypeFn.cast( + O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw]( + TypeFn.cast_signed( U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @@ -665,7 +671,7 @@ def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)): accesses only and is thus rank polymorphic. Numeric casting is performed on the value operand, promoting it to the same data type as the output. """ - O[None] = TypeFn.cast(U, value) + O[None] = TypeFn.cast_signed(U, value) @linalg_structured_op @@ -685,15 +691,15 @@ def fill_rng_2d( the range of the generated random numbers. """ domain(D.m, D.n) - multiplier = TypeFn.cast(I32, const(1103515245)) - increment = TypeFn.cast(I32, const(12345)) - rand1 = (TypeFn.cast(I32, index(D.m)) + seed) * multiplier + increment - rand2 = (TypeFn.cast(I32, index(D.n)) + rand1) * multiplier + increment - inv_range = TypeFn.cast(F64, const(2.3283064e-10)) - offset = TypeFn.cast(F64, const(2147483647)) + multiplier = TypeFn.cast_signed(I32, const(1103515245)) + increment = TypeFn.cast_signed(I32, const(12345)) + rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment + rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment + inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10)) + offset = TypeFn.cast_signed(F64, const(2147483647)) scaling = (max - min) * inv_range - O[D.m, D.n] = TypeFn.cast(T, - (offset + TypeFn.cast(F64, rand2)) * scaling + min) + O[D.m, D.n] = TypeFn.cast_signed( + T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min) @linalg_structured_op @@ -706,4 +712,4 @@ def soft_plus_2d( """ domain(D.m, D.n) O[D.m, D.n] = \ - UnaryFn.log(TypeFn.cast(U, const(1.0)) + UnaryFn.exp(TypeFn.cast(U, I[D.m, D.n]))) + UnaryFn.log(TypeFn.cast_signed(U, const(1.0)) + UnaryFn.exp(TypeFn.cast_signed(U, I[D.m, D.n]))) From b883512cbd1747e257d6435bc28aa06f9c03b580 Mon Sep 
17 00:00:00 2001 From: River Riddle Date: Sat, 26 Feb 2022 14:49:54 -0800 Subject: [PATCH 247/915] [mlir] Rename the Standard dialect to the Func dialect The last remaining operations in the standard dialect all revolve around FuncOp/function related constructs. This patch simply handles the initial renaming (which by itself is already huge), but there are a large number of cleanups unlocked/necessary afterwards: * Removing a bunch of unnecessary dependencies on Func * Cleaning up the From/ToStandard conversion passes * Preparing for the move of FuncOp to the Func dialect See the discussion at https://discourse.llvm.org/t/standard-dialect-the-final-chapter/6061 Differential Revision: https://reviews.llvm.org/D120624 --- .../mlir-c/Dialect/{Standard.h => Func.h} | 12 +++++------ mlir/include/mlir-c/Registration.h | 2 +- mlir/lib/CAPI/Dialect/CMakeLists.txt | 6 +++--- .../CAPI/Dialect/{Standard.cpp => Func.cpp} | 8 ++++---- mlir/python/CMakeLists.txt | 20 +++++++++---------- .../dialects/{StandardOps.td => FuncOps.td} | 12 +++++------ mlir/python/mlir/dialects/_builtin_ops_ext.py | 8 ++++---- .../{_std_ops_ext.py => _func_ops_ext.py} | 4 ++-- mlir/python/mlir/dialects/{std.py => func.py} | 2 +- .../dialects/linalg/opdsl/lang/emitter.py | 2 +- 10 files changed, 38 insertions(+), 38 deletions(-) rename mlir/include/mlir-c/Dialect/{Standard.h => Func.h} (72%) rename mlir/lib/CAPI/Dialect/{Standard.cpp => Func.cpp} (60%) rename mlir/python/mlir/dialects/{StandardOps.td => FuncOps.td} (70%) rename mlir/python/mlir/dialects/{_std_ops_ext.py => _func_ops_ext.py} (97%) rename mlir/python/mlir/dialects/{std.py => func.py} (87%) diff --git a/mlir/include/mlir-c/Dialect/Standard.h b/mlir/include/mlir-c/Dialect/Func.h similarity index 72% rename from mlir/include/mlir-c/Dialect/Standard.h rename to mlir/include/mlir-c/Dialect/Func.h index 200962177..4bdac4268 100644 --- a/mlir/include/mlir-c/Dialect/Standard.h +++ b/mlir/include/mlir-c/Dialect/Func.h @@ -1,4 +1,4 @@ -//===-- 
mlir-c/Dialect/Standard.h - C API for Standard dialect ----*- C -*-===// +//===-- mlir-c/Dialect/Func.h - C API for Func dialect ------------*- C -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM // Exceptions. @@ -8,15 +8,15 @@ //===----------------------------------------------------------------------===// // // This header declares the C interface for registering and accessing the -// Standard dialect. A dialect should be registered with a context to make it +// Func dialect. A dialect should be registered with a context to make it // available to users of the context. These users must load the dialect // before using any of its attributes, operations or types. Parser and pass // manager can load registered dialects automatically. // //===----------------------------------------------------------------------===// -#ifndef MLIR_C_DIALECT_STANDARD_H -#define MLIR_C_DIALECT_STANDARD_H +#ifndef MLIR_C_DIALECT_FUNC_H +#define MLIR_C_DIALECT_FUNC_H #include "mlir-c/Registration.h" @@ -24,10 +24,10 @@ extern "C" { #endif -MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Standard, std); +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Func, func); #ifdef __cplusplus } #endif -#endif // MLIR_C_DIALECT_STANDARD_H +#endif // MLIR_C_DIALECT_FUNC_H diff --git a/mlir/include/mlir-c/Registration.h b/mlir/include/mlir-c/Registration.h index 442449626..ab37866d9 100644 --- a/mlir/include/mlir-c/Registration.h +++ b/mlir/include/mlir-c/Registration.h @@ -20,7 +20,7 @@ extern "C" { // Dialect registration declarations. // Registration entry-points for each dialect are declared using the common // MLIR_DECLARE_DIALECT_REGISTRATION_CAPI macro, which takes the dialect -// API name (i.e. "Standard", "Tensor", "Linalg") and namespace (i.e. "std", +// API name (i.e. "Func", "Tensor", "Linalg") and namespace (i.e. "func", // "tensor", "linalg"). The following declarations are produced: // // /// Gets the above hook methods in struct form for a dialect by namespace. 
diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index f66f7b0b8..9822f059e 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -80,13 +80,13 @@ add_mlir_upstream_c_api_library(MLIRCAPISparseTensor MLIRSparseTensorTransforms ) -add_mlir_upstream_c_api_library(MLIRCAPIStandard - Standard.cpp +add_mlir_upstream_c_api_library(MLIRCAPIFunc + Func.cpp PARTIAL_SOURCES_INTENDED LINK_LIBS PUBLIC MLIRCAPIIR - MLIRStandard + MLIRFunc ) add_mlir_upstream_c_api_library(MLIRCAPITensor diff --git a/mlir/lib/CAPI/Dialect/Standard.cpp b/mlir/lib/CAPI/Dialect/Func.cpp similarity index 60% rename from mlir/lib/CAPI/Dialect/Standard.cpp rename to mlir/lib/CAPI/Dialect/Func.cpp index 57083a8a2..a49d2f425 100644 --- a/mlir/lib/CAPI/Dialect/Standard.cpp +++ b/mlir/lib/CAPI/Dialect/Func.cpp @@ -1,4 +1,4 @@ -//===- Standard.cpp - C Interface for Standard dialect --------------------===// +//===- Func.cpp - C Interface for Func dialect ----------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#include "mlir-c/Dialect/Standard.h" +#include "mlir-c/Dialect/Func.h" #include "mlir/CAPI/Registration.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" -MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Standard, std, mlir::StandardOpsDialect) +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Func, func, mlir::func::FuncDialect) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index bf379e9b2..1477ffea7 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -80,6 +80,15 @@ declare_mlir_dialect_python_bindings( dialects/cf.py DIALECT_NAME cf) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/FuncOps.td + SOURCES + dialects/func.py + dialects/_func_ops_ext.py + DIALECT_NAME func) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" @@ -164,15 +173,6 @@ declare_mlir_dialect_python_bindings( SOURCES dialects/sparse_tensor.py DIALECT_NAME sparse_tensor) -declare_mlir_dialect_python_bindings( - ADD_TO_PARENT MLIRPythonSources.Dialects - ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" - TD_FILE dialects/StandardOps.td - SOURCES - dialects/std.py - dialects/_std_ops_ext.py - DIALECT_NAME std) - declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" @@ -232,7 +232,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core MLIRCAPIRegistration # TODO: See about dis-aggregating # Dialects - MLIRCAPIStandard + MLIRCAPIFunc ) declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind diff --git a/mlir/python/mlir/dialects/StandardOps.td b/mlir/python/mlir/dialects/FuncOps.td similarity index 70% rename from mlir/python/mlir/dialects/StandardOps.td rename to 
mlir/python/mlir/dialects/FuncOps.td index 5b7caabc2..1728091f4 100644 --- a/mlir/python/mlir/dialects/StandardOps.td +++ b/mlir/python/mlir/dialects/FuncOps.td @@ -1,4 +1,4 @@ -//===-- StandardOps.td - Entry point for StandardOps bind --*- tablegen -*-===// +//===-- FuncOps.td - Entry point for Func bind -------------*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,15 +6,15 @@ // //===----------------------------------------------------------------------===// // -// This is the main file from which the Python bindings for the Standard -// dialect are generated. +// This is the main file from which the Python bindings for the Func dialect +// are generated. // //===----------------------------------------------------------------------===// -#ifndef PYTHON_BINDINGS_STANDARD_OPS -#define PYTHON_BINDINGS_STANDARD_OPS +#ifndef PYTHON_BINDINGS_FUNC +#define PYTHON_BINDINGS_FUNC include "mlir/Bindings/Python/Attributes.td" -include "mlir/Dialect/StandardOps/IR/Ops.td" +include "mlir/Dialect/Func/IR/FuncOps.td" #endif diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py index 78f8c95c4..a3a147469 100644 --- a/mlir/python/mlir/dialects/_builtin_ops_ext.py +++ b/mlir/python/mlir/dialects/_builtin_ops_ext.py @@ -162,7 +162,7 @@ def from_py_func(FuncOp, """ def decorator(f): - from . import std + from . import func # Introspect the callable for optional features. sig = inspect.signature(f) has_arg_func_op = False @@ -208,15 +208,15 @@ def decorator(f): return_values = return_values.results else: return_values = list(return_values) - std.ReturnOp(return_values) + func.ReturnOp(return_values) # Recompute the function type. 
return_types = [v.type for v in return_values] function_type = FunctionType.get(inputs=inputs, results=return_types) func_op.attributes["type"] = TypeAttr.get(function_type) def emit_call_op(*call_args): - call_op = std.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name), - call_args) + call_op = func.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name), + call_args) if return_types is None: return None elif len(return_types) == 1: diff --git a/mlir/python/mlir/dialects/_std_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py similarity index 97% rename from mlir/python/mlir/dialects/_std_ops_ext.py rename to mlir/python/mlir/dialects/_func_ops_ext.py index f4cb6186b..850562673 100644 --- a/mlir/python/mlir/dialects/_std_ops_ext.py +++ b/mlir/python/mlir/dialects/_func_ops_ext.py @@ -46,8 +46,8 @@ def __init__(self, For example f = builtin.FuncOp("foo", ...) - std.CallOp(f, [args]) - std.CallOp([result_types], "foo", [args]) + func.CallOp(f, [args]) + func.CallOp([result_types], "foo", [args]) In all cases, the location and insertion point may be specified as keyword arguments if not provided by the surrounding context managers. diff --git a/mlir/python/mlir/dialects/std.py b/mlir/python/mlir/dialects/func.py similarity index 87% rename from mlir/python/mlir/dialects/std.py rename to mlir/python/mlir/dialects/func.py index 8e55807a0..dc554c221 100644 --- a/mlir/python/mlir/dialects/std.py +++ b/mlir/python/mlir/dialects/func.py @@ -2,4 +2,4 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from ._std_ops_gen import * +from ._func_ops_gen import * diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 453a3e80c..ff5c405d7 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -6,8 +6,8 @@ from .....ir import * +from .... 
import func from .... import linalg -from .... import std from .... import math from .... import arith from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values From 999d9b7c97358bae59ddea6ce8509ca41c1ed6f3 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 4 Mar 2022 12:53:22 -0800 Subject: [PATCH 248/915] [mlir][NFC] Move Parser.h to Parser/ There is no reason for this file to be at the top-level, and its current placement predates the Parser/ folder's existence. Differential Revision: https://reviews.llvm.org/D121024 --- mlir/lib/CAPI/IR/IR.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index c067b202b..8efbbebdf 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -20,7 +20,7 @@ #include "mlir/IR/Types.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/InferTypeOpInterface.h" -#include "mlir/Parser.h" +#include "mlir/Parser/Parser.h" #include "llvm/Support/Debug.h" #include From 193c743a236286d383bd8036b65a5e5b1167d9cc Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Tue, 8 Mar 2022 12:59:42 +0100 Subject: [PATCH 249/915] Update more `parseSourceString()` call sites. Change to non-deprecated function template (see D121075). 
Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D121102 --- mlir/lib/CAPI/IR/IR.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 8efbbebdf..75ac93ffb 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -195,7 +195,7 @@ MlirModule mlirModuleCreateEmpty(MlirLocation location) { MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) { OwningOpRef owning = - parseSourceString(unwrap(module), unwrap(context)); + parseSourceString(unwrap(module), unwrap(context)); if (!owning) return MlirModule{nullptr}; return MlirModule{owning.release().getOperation()}; From 9386959115070ffd366c4daf6b20a3f21a8abb2d Mon Sep 17 00:00:00 2001 From: gysit Date: Tue, 8 Mar 2022 15:33:47 +0000 Subject: [PATCH 250/915] [mlir][linalg] Add a FillOpInterface. Add a FillOpInterface similar to the contraction and convolution op interfaces. The FillOpInterface is a preparation step to replace linalg.fill by its OpDSL version linalg.fill_tensor. The interface implements the `value()`, `output()`, and `result()` methods that by default are not available on linalg.fill_tensor. 
Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D120725 --- mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py | 1 + mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py | 1 + 2 files changed, 2 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 7de0a76e8..1de5449e2 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -686,6 +686,7 @@ def __init__(self, cpp_name: str): ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface") ConvolutionOpInterface = OpInterfaceDef("LinalgConvolutionOpInterface") +FillOpInterface = OpInterfaceDef("LinalgFillOpInterface") class OpMetadataDef(YAMLObject): diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 0ef40613a..7798d7f94 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -671,6 +671,7 @@ def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)): accesses only and is thus rank polymorphic. Numeric casting is performed on the value operand, promoting it to the same data type as the output. """ + implements(FillOpInterface) O[None] = TypeFn.cast_signed(U, value) From 5494379fca28a5a266cd4fc99e3eb571031ca531 Mon Sep 17 00:00:00 2001 From: gysit Date: Tue, 8 Mar 2022 15:56:40 +0000 Subject: [PATCH 251/915] [mlir][OpDSL] Add support for adding canonicalization patterns. Extend OpDSL with a `defines` method that can set the `hasCanonicalizer` flag for an OpDSL operation. If the flag is set via `defines(Canonicalizer)` the operation needs to implement the `getCanonicalizationPatterns` method. 
The revision specifies the flag for linalg.fill_tensor and adds an empty `FillTensorOp::getCanonicalizationPatterns` implementation. This revision is a preparation step to replace linalg.fill by its OpDSL counterpart linalg.fill_tensor. The two are only functionally equivalent if both specify the same canonicalization patterns. The revision is thus a prerequisite for the linalg.fill replacement. Depends On D120725 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D120726 --- .../linalg/opdsl/lang/comprehension.py | 13 ++++++++++++ .../mlir/dialects/linalg/opdsl/lang/dsl.py | 20 +++++++++++++------ .../linalg/opdsl/ops/core_named_ops.py | 1 + 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 1de5449e2..47083de62 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -689,6 +689,16 @@ def __init__(self, cpp_name: str): FillOpInterface = OpInterfaceDef("LinalgFillOpInterface") +class OpDefinitionDef: + """A method that an op implements.""" + + def __init__(self, def_name: str): + self.def_name = def_name + + +Canonicalizer = OpDefinitionDef("hasCanonicalizer") + + class OpMetadataDef(YAMLObject): """Metadata about the op (generally not behavior impacting).""" yaml_tag = "!LinalgOpMetadata" @@ -699,6 +709,7 @@ def __init__(self, name: str, cpp_class_name: Optional[str], self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name self.doc = doc self.implements = [] # type: List[OpInterfaceDef] + self.defines = [] # type: List[OpDefinitionsDef] def to_yaml_custom_dict(self): d = dict( @@ -708,6 +719,8 @@ def to_yaml_custom_dict(self): ) if self.implements: d["implements"] = [intr.cpp_name for intr in self.implements] + if self.defines: + d["defines"] = [defi.def_name for defi in self.defines] return 
d diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py index bd9042ac0..45b8d5ccd 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -149,13 +149,21 @@ def linalg_structured_op(dsl_func=None, return DefinedOpCallable(op_name, op_def) +def domain(*dimensions: DimDef): + if any(not isinstance(d, DimDef) for d in dimensions): + raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}") + current_op_def().domain.extend(dimensions) + + def implements(*interfaces: OpInterfaceDef): + if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces): + raise ValueError( + f"Expected interfaces of type OpInterfaceDef but got {interfaces}") current_op_def().metadata.implements.extend(interfaces) -def domain(*dimensions: DimDef): - if current_op_def().domain: - raise ValueError(f"Expected only one set of domain dimensions per operator") - if any(not isinstance(dim, DimDef) for dim in dimensions): - raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}") - current_op_def().domain.extend(dimensions) +def defines(*definitions: OpDefinitionDef): + if any(not isinstance(defi, OpDefinitionDef) for defi in definitions): + raise ValueError( + f"Expected definitions of type OpDefinitionDef but got {definitions}") + current_op_def().metadata.defines.extend(definitions) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 7798d7f94..39934131c 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -672,6 +672,7 @@ def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)): the value operand, promoting it to the same data type as the output. 
""" implements(FillOpInterface) + defines(Canonicalizer) O[None] = TypeFn.cast_signed(U, value) From ed6c4a478498831bdacbd058e9f2d74489354288 Mon Sep 17 00:00:00 2001 From: gysit Date: Tue, 8 Mar 2022 17:06:50 +0000 Subject: [PATCH 252/915] [mlir][OpDSL] Simplify index and constant tests. Simplify tests that use `linalg.fill_rng_2d` to focus on testing the `const` and `index` functions. Additionally, cleanup emit_misc.py to use simpler test functions and fix an error message in config.py. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D120734 --- mlir/python/mlir/dialects/linalg/opdsl/lang/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py index ed30b8e5f..2a0da6829 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -265,8 +265,8 @@ def __init__(self, for index in collected_indices: if index.dim_def.dimname not in self.affine_state.all_dims: raise ValueError( - f"The dimension {index.dim.dimname} is not part of the iteration " - f"domain {self.affine_state.all_dims}") + f"The dimension {index.dim_def.dimname} is not part of the " + f"iteration domain {self.affine_state.all_dims}") index.resolve_dimension_name(self.affine_state) # Generate the scalar assignments (used to build a body). From 6f5ce16fb5d505b863a7dabea35168d05ff70e57 Mon Sep 17 00:00:00 2001 From: gysit Date: Tue, 8 Mar 2022 17:20:01 +0000 Subject: [PATCH 253/915] [mlir][OpDSL] Remove unused SoftPlus2DOp operation. The revision removes the SoftPlus2DOp operation that previously served as a test operation. It has been replaced by the elemwise_unary operation, which is now used to test unary log and exp functions. 
Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D120794 --- .../dialects/linalg/opdsl/ops/core_named_ops.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 39934131c..2e1424a93 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -702,16 +702,3 @@ def fill_rng_2d( scaling = (max - min) * inv_range O[D.m, D.n] = TypeFn.cast_signed( T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min) - - -@linalg_structured_op -def soft_plus_2d( - I=TensorDef(T, S.M, S.N), O=TensorDef(U, S.M, S.N, output=True)): - """Implements the soft plus operator. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - domain(D.m, D.n) - O[D.m, D.n] = \ - UnaryFn.log(TypeFn.cast_signed(U, const(1.0)) + UnaryFn.exp(TypeFn.cast_signed(U, I[D.m, D.n]))) From d1377d10c83b02fb5a2ede83df5449875c3b4c00 Mon Sep 17 00:00:00 2001 From: gysit Date: Tue, 8 Mar 2022 17:30:06 +0000 Subject: [PATCH 254/915] [mlir][OpDSL] Support pointwise ops with rank zero inputs. Allow pointwise operations to take rank zero input tensors similarly to scalar inputs. Use an empty indexing map to broadcast rank zero tensors to the iteration domain of the operation. 
Depends On D120734 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D120807 --- mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index ff5c405d7..93baef14b 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -187,7 +187,11 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, if arg_def.operand_def.kind == OperandKind.SCALAR: indexing_maps.append(scalar_map) if arg_def.operand_def.is_tensor(): - indexing_maps.append(tensor_map) + idx = arg_def.operand_def.registered_index + if idx < len(ins) and ShapedType(ins[idx].type).rank == 0: + indexing_maps.append(scalar_map) + else: + indexing_maps.append(tensor_map) indexing_maps_attr = ArrayAttr.get( [AffineMapAttr.get(am) for am in indexing_maps]) From 0f02d5050de7801b40f2b72ca1220b8deada49e8 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Tue, 8 Mar 2022 12:40:12 -0500 Subject: [PATCH 255/915] [mlir][Linalg] Add a polymorphic linalg.copy operation With the recent improvements to OpDSL it is cheap to reintroduce a linalg.copy operation. This operation is needed in at least 2 cases: 1. for copies that may want to change the elemental type (e.g. cast, truncate, quantize, etc) 2. to specify new tensors that should bufferize to a copy operation. The linalg.generic form always folds away which is not always the right call. 
Differential Revision: https://reviews.llvm.org/D121230 --- .../linalg/opdsl/ops/core_named_ops.py | 550 ++++++++++-------- 1 file changed, 300 insertions(+), 250 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 2e1424a93..5774cbc6c 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -7,11 +7,22 @@ @linalg_structured_op -def elemwise_unary( - I=TensorDef(T1), - O=TensorDef(U, output=True), - fun=UnaryFnAttrDef(default=UnaryFn.exp), - cast=TypeFnAttrDef(default=TypeFn.cast_signed)): +def copy(I=TensorDef(T1), + O=TensorDef(U, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast_signed)): + """Copies the tensor elementwise. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + O[None] = cast(U, I[None]) + + +@linalg_structured_op +def elemwise_unary(I=TensorDef(T1), + O=TensorDef(U, output=True), + fun=UnaryFnAttrDef(default=UnaryFn.exp), + cast=TypeFnAttrDef(default=TypeFn.cast_signed)): """Applies the unary function fun elementwise. Numeric casting is performed on the input operand, promoting it to the same @@ -21,12 +32,11 @@ def elemwise_unary( @linalg_structured_op -def elemwise_binary( - lhs=TensorDef(T1), - rhs=TensorDef(T2), - O=TensorDef(U, output=True), - fun=BinaryFnAttrDef(default=BinaryFn.add), - cast=TypeFnAttrDef(default=TypeFn.cast_signed)): +def elemwise_binary(lhs=TensorDef(T1), + rhs=TensorDef(T2), + O=TensorDef(U, output=True), + fun=BinaryFnAttrDef(default=BinaryFn.add), + cast=TypeFnAttrDef(default=TypeFn.cast_signed)): """Applies the binary function fun elementwise. 
Numeric casting is performed on the input operand, promoting it to the same @@ -36,11 +46,10 @@ def elemwise_binary( @linalg_structured_op -def matmul( - A=TensorDef(T1, S.M, S.K), - B=TensorDef(T2, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True), - cast=TypeFnAttrDef(default=TypeFn.cast_signed)): +def matmul(A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast_signed)): """Performs a matrix multiplication of two 2D inputs. Numeric casting is performed on the operands to the inner multiply, promoting @@ -52,10 +61,9 @@ def matmul( @linalg_structured_op -def matmul_unsigned( - A=TensorDef(T1, S.M, S.K), - B=TensorDef(T2, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True)): +def matmul_unsigned(A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True)): """Performs an unsigned matrix multiplication of two 2D inputs. Numeric casting is performed on the operands to the inner multiply, promoting @@ -68,12 +76,11 @@ def matmul_unsigned( @linalg_structured_op -def quantized_matmul( - A=TensorDef(T1, S.M, S.K), - B=TensorDef(T2, S.K, S.N), - AZp=ScalarDef(I32), - BZp=ScalarDef(I32), - C=TensorDef(U, S.M, S.N, output=True)): +def quantized_matmul(A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + AZp=ScalarDef(I32), + BZp=ScalarDef(I32), + C=TensorDef(U, S.M, S.N, output=True)): """Performs a matrix multiplication of two 2D inputs. Numeric casting is performed on the operands to the inner multiply, promoting @@ -82,16 +89,16 @@ def quantized_matmul( matmul. 
""" domain(D.m, D.n, D.k) - C[D.m, D.n] += ( - TypeFn.cast_signed(U, A[D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * ( - TypeFn.cast_signed(U, B[D.k, D.n]) - TypeFn.cast_signed(U, BZp)) + C[D.m, + D.n] += (TypeFn.cast_signed(U, A[D.m, D.k]) - + TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed(U, B[D.k, D.n]) - + TypeFn.cast_signed(U, BZp)) @linalg_structured_op -def mmt4d( - lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), - rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0), - accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True)): +def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), + rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0), + accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True)): """Performs a matrix-matrix-transpose multiplication of two 4D inputs. Differences from linalg.matmul: @@ -110,10 +117,9 @@ def mmt4d( @linalg_structured_op -def batch_matmul( - A=TensorDef(T1, Batch, S.M, S.K), - B=TensorDef(T2, Batch, S.K, S.N), - C=TensorDef(U, Batch, S.M, S.N, output=True)): +def batch_matmul(A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K, S.N), + C=TensorDef(U, Batch, S.M, S.N, output=True)): """Performs a batched matrix multiplication of two 3D inputs. Numeric casting is performed on the operands to the inner multiply, promoting @@ -127,12 +133,11 @@ def batch_matmul( @linalg_structured_op -def quantized_batch_matmul( - A=TensorDef(T1, Batch, S.M, S.K), - B=TensorDef(T2, Batch, S.K, S.N), - AZp=ScalarDef(I32), - BZp=ScalarDef(I32), - C=TensorDef(U, Batch, S.M, S.N, output=True)): +def quantized_batch_matmul(A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K, S.N), + AZp=ScalarDef(I32), + BZp=ScalarDef(I32), + C=TensorDef(U, Batch, S.M, S.N, output=True)): """Performs a batched matrix multiplication of two 3D inputs. Numeric casting is performed on the operands to the inner multiply, promoting @@ -141,16 +146,15 @@ def quantized_batch_matmul( matmul. 
""" domain(D.b, D.m, D.n, D.k) - C[D.b, D.m, D.n] += ( - TypeFn.cast_signed(U, A[D.b, D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * ( - TypeFn.cast_signed(U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp)) + C[D.b, D.m, D.n] += (TypeFn.cast_signed(U, A[D.b, D.m, D.k]) - + TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed( + U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp)) @linalg_structured_op -def matvec( - A=TensorDef(T1, S.M, S.N), - y=TensorDef(T2, S.N), - x=TensorDef(U, S.M, output=True)): +def matvec(A=TensorDef(T1, S.M, S.N), + y=TensorDef(T2, S.N), + x=TensorDef(U, S.M, output=True)): """Performs a matrix-vector multiplication. Numeric casting is performed on the operands to the inner multiply, promoting @@ -162,10 +166,9 @@ def matvec( @linalg_structured_op -def vecmat( - y=TensorDef(T1, S.M), - A=TensorDef(T2, S.M, S.N), - x=TensorDef(U, S.N, output=True)): +def vecmat(y=TensorDef(T1, S.M), + A=TensorDef(T2, S.M, S.N), + x=TensorDef(U, S.N, output=True)): """Performs a vector-matrix multiplication. Numeric casting is performed on the operands to the inner multiply, promoting @@ -177,10 +180,9 @@ def vecmat( @linalg_structured_op -def batch_matvec( - A=TensorDef(T1, Batch, S.M, S.K), - B=TensorDef(T2, Batch, S.K), - C=TensorDef(U, Batch, S.M, output=True)): +def batch_matvec(A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K), + C=TensorDef(U, Batch, S.M, output=True)): """Performs a batched matrix-vector multiplication. Numeric casting is performed on the operands to the inner multiply, promoting @@ -193,8 +195,8 @@ def batch_matvec( @linalg_structured_op -def dot( - A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)): +def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, + output=True)): """Performs a dot product of two vectors to a scalar result. 
Numeric casting is performed on the operands to the inner multiply, promoting @@ -205,10 +207,9 @@ def dot( @linalg_structured_op -def conv_1d( - I=TensorDef(T1, S.OW + S.KW), - K=TensorDef(T2, S.KW), - O=TensorDef(U, S.OW, output=True)): +def conv_1d(I=TensorDef(T1, S.OW + S.KW), + K=TensorDef(T2, S.KW), + O=TensorDef(U, S.OW, output=True)): """Performs 1-D convolution with no channels. Numeric casting is performed on the operands to the inner multiply, promoting @@ -221,10 +222,9 @@ def conv_1d( @linalg_structured_op -def conv_2d( - I=TensorDef(T1, S.OH + S.KH, S.OW + S.KW), - K=TensorDef(T2, S.KH, S.KW), - O=TensorDef(U, S.OH, S.OW, output=True)): +def conv_2d(I=TensorDef(T1, S.OH + S.KH, S.OW + S.KW), + K=TensorDef(T2, S.KH, S.KW), + O=TensorDef(U, S.OH, S.OW, output=True)): """Performs 2-D convolution with no channels. Numeric casting is performed on the operands to the inner multiply, promoting @@ -237,10 +237,9 @@ def conv_2d( @linalg_structured_op -def conv_3d( - I=TensorDef(T1, S.OD + S.KD, S.OH + S.KH, S.OW + S.KW), - K=TensorDef(T2, S.KD, S.KH, S.KW), - O=TensorDef(U, S.OD, S.OH, S.OW, output=True)): +def conv_3d(I=TensorDef(T1, S.OD + S.KD, S.OH + S.KH, S.OW + S.KW), + K=TensorDef(T2, S.KD, S.KH, S.KW), + O=TensorDef(U, S.OD, S.OH, S.OW, output=True)): """Performs 3-D convolution with no channels. Numeric casting is performed on the operands to the inner multiply, promoting @@ -254,12 +253,11 @@ def conv_3d( @linalg_structured_op -def conv_1d_nwc_wcf( - I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KW, S.C, S.F), - O=TensorDef(U, S.N, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): +def conv_1d_nwc_wcf(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, S.C, S.F), + O=TensorDef(U, S.N, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1])): """Performs 1-D convolution. 
Numeric casting is performed on the operands to the inner multiply, promoting @@ -267,19 +265,18 @@ def conv_1d_nwc_wcf( """ implements(ConvolutionOpInterface) domain(D.n, D.ow, D.f, D.kw, D.c) - O[D.n, D.ow, - D.f] += TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, - D.c]) * TypeFn.cast_signed(U, K[D.kw, D.c, D.f]) + O[D.n, D.ow, D.f] += TypeFn.cast_signed( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed( + U, K[D.kw, D.c, D.f]) @linalg_structured_op -def conv_2d_nhwc_hwcf( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, - S.C), - K=TensorDef(T2, S.KH, S.KW, S.C, S.F), - O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): +def conv_2d_nhwc_hwcf(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, S.C, S.F), + O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs 2-D convolution. Layout: @@ -297,15 +294,14 @@ def conv_2d_nhwc_hwcf( @linalg_structured_op -def conv_2d_nhwc_hwcf_q( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, - S.C), - K=TensorDef(T2, S.KH, S.KW, S.C, S.F), - IZp=ScalarDef(I32), - KZp=ScalarDef(I32), - O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): +def conv_2d_nhwc_hwcf_q(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, S.C, S.F), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs 2-D convolution with zero point offsets. 
Layout: @@ -321,19 +317,17 @@ def conv_2d_nhwc_hwcf_q( O[D.n, D.oh, D.ow, D.f] += (TypeFn.cast_signed( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) - - TypeFn.cast_signed(U, IZp)) * ( - TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) - - TypeFn.cast_signed(U, KZp)) + TypeFn.cast_signed(U, IZp)) * (TypeFn.cast_signed( + U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp)) @linalg_structured_op -def conv_2d_nchw_fchw( - I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.F, S.C, S.KH, S.KW), - O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): +def conv_2d_nchw_fchw(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.F, S.C, S.KH, S.KW), + O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs 2-D convolution. Layout: @@ -351,13 +345,19 @@ def conv_2d_nchw_fchw( @linalg_structured_op -def conv_3d_ndhwc_dhwcf( - I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F), - O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])): +def conv_3d_ndhwc_dhwcf(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SD, + S.SH, + S.SW, + default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, + S.DH, + S.DW, + default=[1, 1, 1])): """Performs 3-D convolution. 
Numeric casting is performed on the operands to the inner multiply, promoting @@ -372,12 +372,12 @@ def conv_3d_ndhwc_dhwcf( @linalg_structured_op -def depthwise_conv_1d_nwc_wc( - I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC), - K=TensorDef(T2, S.KW, S.IC), - O=TensorDef(U, S.N, S.OW, S.IC, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): +def depthwise_conv_1d_nwc_wc(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, + S.IC), + K=TensorDef(T2, S.KW, S.IC), + O=TensorDef(U, S.N, S.OW, S.IC, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1])): """Performs depth-wise 1-D convolution. Numeric casting is performed on the operands to the inner multiply, promoting @@ -392,13 +392,19 @@ def depthwise_conv_1d_nwc_wc( @linalg_structured_op -def depthwise_conv_2d_nhwc_hwc( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, - S.IC), - K=TensorDef(T2, S.KH, S.KW, S.IC), - O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): +def depthwise_conv_2d_nhwc_hwc(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC), + O=TensorDef(U, + S.N, + S.OH, + S.OW, + S.IC, + output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, + S.DW, + default=[1, 1])): """Performs depth-wise 2-D convolution. 
Numeric casting is performed on the operands to the inner multiply, promoting @@ -413,15 +419,23 @@ def depthwise_conv_2d_nhwc_hwc( @linalg_structured_op -def depthwise_conv_2d_nhwc_hwc_q( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, - S.IC), - K=TensorDef(T2, S.KH, S.KW, S.IC), - IZp=ScalarDef(I32), - KZp=ScalarDef(I32), - O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): +def depthwise_conv_2d_nhwc_hwc_q(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, + S.N, + S.OH, + S.OW, + S.IC, + output=True), + strides=IndexAttrDef(S.SH, + S.SW, + default=[1, 1]), + dilations=IndexAttrDef(S.DH, + S.DW, + default=[1, 1])): """Performs depth-wise 2-D convolution. Numeric casting is performed on the operands to the inner multiply, promoting @@ -437,13 +451,21 @@ def depthwise_conv_2d_nhwc_hwc_q( @linalg_structured_op -def depthwise_conv_2d_nhwc_hwcm( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, - S.IC), - K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), - O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): +def depthwise_conv_2d_nhwc_hwcm(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), + O=TensorDef(U, + S.N, + S.OH, + S.OW, + S.IC, + S.CM, + output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, + 1]), + dilations=IndexAttrDef(S.DH, + S.DW, + default=[1, 1])): """Performs depth-wise 2-D convolution. 
Numeric casting is performed on the operands to the inner multiply, promoting @@ -457,15 +479,25 @@ def depthwise_conv_2d_nhwc_hwcm( @linalg_structured_op -def depthwise_conv_2d_nhwc_hwcm_q( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, - S.IC), - K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), - IZp=ScalarDef(I32), - KZp=ScalarDef(I32), - O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): +def depthwise_conv_2d_nhwc_hwcm_q(I=TensorDef(T1, S.N, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, + S.N, + S.OH, + S.OW, + S.IC, + S.CM, + output=True), + strides=IndexAttrDef(S.SH, + S.SW, + default=[1, 1]), + dilations=IndexAttrDef(S.DH, + S.DW, + default=[1, 1])): """Performs depth-wise 2-D convolution. Numeric casting is performed on the operands to the inner multiply, promoting @@ -482,13 +514,12 @@ def depthwise_conv_2d_nhwc_hwcm_q( @linalg_structured_op -def pooling_nhwc_sum( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, - S.C), - K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): +def pooling_nhwc_sum(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs sum pooling. 
Numeric casting is performed on the input operand, promoting it to the same @@ -501,13 +532,12 @@ def pooling_nhwc_sum( @linalg_structured_op -def pooling_nhwc_max( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, - S.C), - K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): +def pooling_nhwc_max(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs max pooling. Numeric casting is performed on the input operand, promoting it to the same @@ -515,19 +545,21 @@ def pooling_nhwc_max( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw]( - TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw](TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @linalg_structured_op -def pooling_nhwc_max_unsigned( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, - S.C), - K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): +def pooling_nhwc_max_unsigned(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, + S.KH, + S.KW, + index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, 
default=[1, + 1])): """Performs unsigned max pooling. Numeric casting is performed on the input operand, promoting it to the same @@ -535,19 +567,18 @@ def pooling_nhwc_max_unsigned( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw]( - TypeFn.cast_unsigned( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) + O[D.n, D.oh, D.ow, + D.c] = ReduceFn.max_unsigned[D.kh, D.kw](TypeFn.cast_unsigned( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @linalg_structured_op -def pooling_nchw_max( - I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): +def pooling_nchw_max(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs max pooling. 
Numeric casting is performed on the input operand, promoting it to the same @@ -555,20 +586,17 @@ def pooling_nchw_max( """ implements(ConvolutionOpInterface) domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) - O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw]( - TypeFn.cast_signed( - U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW,])) + O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw](TypeFn.cast_signed( + U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,])) @linalg_structured_op -def pooling_nhwc_min( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, - S.C), - K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): +def pooling_nhwc_min(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs min pooling. 
Numeric casting is performed on the input operand, promoting it to the same @@ -576,19 +604,21 @@ def pooling_nhwc_min( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw]( - TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw](TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) @linalg_structured_op -def pooling_nhwc_min_unsigned( - I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, - S.C), - K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): +def pooling_nhwc_min_unsigned(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, + S.KH, + S.KW, + index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, + 1])): """Performs unsigned min pooling. 
Numeric casting is performed on the input operand, promoting it to the same @@ -596,19 +626,26 @@ def pooling_nhwc_min_unsigned( """ implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw]( - TypeFn.cast_unsigned( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) - - -@linalg_structured_op -def pooling_ndhwc_sum( - I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), - O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])): + O[D.n, D.oh, D.ow, + D.c] = ReduceFn.min_unsigned[D.kh, D.kw](TypeFn.cast_unsigned( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) + + +@linalg_structured_op +def pooling_ndhwc_sum(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, + S.KD, + S.KH, + S.KW, + index_dims=[D.kd, D.kh, D.kw]), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, + S.DH, + S.DW, + default=[1, 1, 1])): """Performs 3D sum pooling. 
Numeric casting is performed on the input operand, promoting it to the same @@ -622,13 +659,20 @@ def pooling_ndhwc_sum( @linalg_structured_op -def pooling_ndhwc_max( - I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), - O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])): +def pooling_ndhwc_max(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, + S.KD, + S.KH, + S.KW, + index_dims=[D.kd, D.kh, D.kw]), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, + S.DH, + S.DW, + default=[1, 1, 1])): """Performs 3D max pooling. Numeric casting is performed on the input operand, promoting it to the same @@ -636,20 +680,27 @@ def pooling_ndhwc_max( """ implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) - O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw]( - TypeFn.cast_signed( - U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, D.c])) - - -@linalg_structured_op -def pooling_ndhwc_min( - I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), - O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1])): + O[D.n, D.od, D.oh, D.ow, + D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw](TypeFn.cast_signed( + U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, D.c])) + + 
+@linalg_structured_op +def pooling_ndhwc_min(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, + S.KD, + S.KH, + S.KW, + index_dims=[D.kd, D.kh, D.kw]), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, + S.DH, + S.DW, + default=[1, 1, 1])): """Performs 3D min pooling. Numeric casting is performed on the input operand, promoting it to the same @@ -657,10 +708,10 @@ def pooling_ndhwc_min( """ implements(ConvolutionOpInterface) domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) - O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw]( - TypeFn.cast_signed( - U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, D.c])) + O[D.n, D.od, D.oh, D.ow, + D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw](TypeFn.cast_signed( + U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, D.c])) @linalg_structured_op @@ -677,11 +728,10 @@ def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)): @linalg_structured_op -def fill_rng_2d( - min=ScalarDef(F64), - max=ScalarDef(F64), - seed=ScalarDef(I32), - O=TensorDef(T, S.M, S.N, output=True)): +def fill_rng_2d(min=ScalarDef(F64), + max=ScalarDef(F64), + seed=ScalarDef(I32), + O=TensorDef(T, S.M, S.N, output=True)): """Fills the output tensor with pseudo random numbers. The operation generations pseudo random numbers using a linear congruential From 56cb327d992fd1ab0fb9179457bec075372528cc Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Thu, 10 Mar 2022 09:08:41 -0800 Subject: [PATCH 256/915] [mlir][linalg] Add a few unary operations. Add operations abs, ceil, floor, and neg to the C++ API and Python API. Add test cases. 
Reviewed By: gysit Differential Revision: https://reviews.llvm.org/D121339 --- .../linalg/opdsl/lang/comprehension.py | 4 ++++ .../dialects/linalg/opdsl/lang/emitter.py | 20 +++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 47083de62..135f55ea5 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -274,6 +274,10 @@ class UnaryFn: """Unary function namespace.""" exp = UnaryFnType("exp") log = UnaryFnType("log") + abs = UnaryFnType("abs") + ceil = UnaryFnType("ceil") + floor = UnaryFnType("floor") + negf = UnaryFnType("negf") class BinaryFnType: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 93baef14b..2e71e561a 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -390,6 +390,26 @@ def _unary_log(self, x: Value) -> Value: return math.LogOp(x).result raise NotImplementedError("Unsupported 'log' operand: {x}") + def _unary_abs(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return math.AbsOp(x).result + raise NotImplementedError("Unsupported 'abs' operand: {x}") + + def _unary_ceil(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return math.CeilOp(x).result + raise NotImplementedError("Unsupported 'ceil' operand: {x}") + + def _unary_floor(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return math.FloorOp(x).result + raise NotImplementedError("Unsupported 'floor' operand: {x}") + + def _unary_negf(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return arith.NegFOp(x).result + raise NotImplementedError("Unsupported 'negf' operand: {x}") + def _binary_add(self, lhs: Value, rhs: Value) -> Value: if 
_is_floating_point_type(lhs.type): return arith.AddFOp(lhs, rhs).result From 979bceb389227890a6552780c83f5f3ba1f8c13c Mon Sep 17 00:00:00 2001 From: Yun Long Date: Fri, 11 Mar 2022 10:50:10 +0100 Subject: [PATCH 257/915] [MLIR][python binding] Add OpaqueAttribute to python binding. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D120847 --- mlir/lib/Bindings/Python/IRAttributes.cpp | 38 +++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index bef3b95a2..1093d50c8 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -320,6 +320,43 @@ class PyFlatSymbolRefAttribute } }; +class PyOpaqueAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; + static constexpr const char *pyClassName = "OpaqueAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::string dialectNamespace, py::buffer buffer, PyType &type, + DefaultingPyMlirContext context) { + const py::buffer_info bufferInfo = buffer.request(); + intptr_t bufferSize = bufferInfo.size; + MlirAttribute attr = mlirOpaqueAttrGet( + context->get(), toMlirStringRef(dialectNamespace), bufferSize, + static_cast(bufferInfo.ptr), type); + return PyOpaqueAttribute(context->getRef(), attr); + }, + py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"), + py::arg("context") = py::none(), "Gets an Opaque attribute."); + c.def_property_readonly( + "dialect_namespace", + [](PyOpaqueAttribute &self) { + MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); + return py::str(stringRef.data, stringRef.length); + }, + "Returns the dialect namespace for the Opaque attribute as a string"); + c.def_property_readonly( + "data", + [](PyOpaqueAttribute &self) { + MlirStringRef stringRef = 
mlirOpaqueAttrGetData(self); + return py::str(stringRef.data, stringRef.length); + }, + "Returns the data for the Opaqued attributes as a string"); + } +}; + class PyStringAttribute : public PyConcreteAttribute { public: static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; @@ -862,6 +899,7 @@ void mlir::python::populateIRAttributes(py::module &m) { PyDenseIntElementsAttribute::bind(m); PyDictAttribute::bind(m); PyFlatSymbolRefAttribute::bind(m); + PyOpaqueAttribute::bind(m); PyFloatAttribute::bind(m); PyIntegerAttribute::bind(m); PyStringAttribute::bind(m); From 45a10db864de655c3eded7debf5191e668bbdd68 Mon Sep 17 00:00:00 2001 From: chhzh123 Date: Sun, 13 Mar 2022 05:24:00 +0000 Subject: [PATCH 258/915] [MLIR][Python] Add SCFIfOp Python binding Current generated Python binding for the SCF dialect does not allow users to call IfOp to create if-else branches on their own. This PR sets up the default binding generation for scf.if operation to address this problem. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D121076 --- mlir/python/mlir/dialects/_scf_ops_ext.py | 41 +++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py index a8924a750..3c3e67302 100644 --- a/mlir/python/mlir/dialects/_scf_ops_ext.py +++ b/mlir/python/mlir/dialects/_scf_ops_ext.py @@ -64,3 +64,44 @@ def inner_iter_args(self): To obtain the loop-carried operands, use `iter_args`. """ return self.body.arguments[1:] + + +class IfOp: + """Specialization for the SCF if op class.""" + + def __init__(self, + cond, + results_=[], + *, + hasElse=False, + loc=None, + ip=None): + """Creates an SCF `if` operation. + + - `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed. + - `hasElse` determines whether the if operation has the else branch. 
+ """ + operands = [] + operands.append(cond) + results = [] + results.extend(results_) + super().__init__( + self.build_generic( + regions=2, + results=results, + operands=operands, + loc=loc, + ip=ip)) + self.regions[0].blocks.append(*[]) + if hasElse: + self.regions[1].blocks.append(*[]) + + @property + def then_block(self): + """Returns the then block of the if operation.""" + return self.regions[0].blocks[0] + + @property + def else_block(self): + """Returns the else block of the if operation.""" + return self.regions[1].blocks[0] From af420e43d29caf603189862731745c9478ece008 Mon Sep 17 00:00:00 2001 From: gysit Date: Mon, 14 Mar 2022 10:45:04 +0000 Subject: [PATCH 259/915] [mlir][linalg] Replace linalg.fill by OpDSL variant. The revision removes the linalg.fill operation and renames the OpDSL generated linalg.fill_tensor operation to replace it. After the change, all named structured operations are defined via OpDSL and there are no handwritten operations left. A side-effect of the change is that the pretty printed form changes from: ``` %1 = linalg.fill(%cst, %0) : f32, tensor -> tensor ``` changes to ``` %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor) -> tensor ``` Additionally, the builder signature now takes input and output value ranges as it is the case for all other OpDSL operations: ``` rewriter.create(loc, val, output) ``` changes to ``` rewriter.create(loc, ValueRange{val}, ValueRange{output}) ``` All other changes remain minimal. In particular, the canonicalization patterns are the same and the `value()`, `output()`, and `result()` methods are now implemented by the FillOpInterface. 
Depends On D120726 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D120728 --- mlir/python/mlir/dialects/_linalg_ops_ext.py | 16 ---------------- .../dialects/linalg/opdsl/ops/core_named_ops.py | 2 +- 2 files changed, 1 insertion(+), 17 deletions(-) diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py index 167a9232d..e3fb46055 100644 --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -20,22 +20,6 @@ def isa(cls: Type, ty: Type): return False -class FillOp: - """Extends the linalg.fill op.""" - - def __init__(self, output: Value, value: Value, *, loc=None, ip=None): - results = [] - if isa(RankedTensorType, output.type): - results = [output.type] - op = self.build_generic( - results=results, - operands=[_get_op_result_or_value(o) for o in [value, output]], - attributes=None, - loc=loc, - ip=ip) - OpView.__init__(self, op) - fill_builtin_region(self.operation) - class InitTensorOp: """Extends the linalg.init_tensor op.""" diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 5774cbc6c..2c6291bad 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -715,7 +715,7 @@ def pooling_ndhwc_min(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, @linalg_structured_op -def fill_tensor(value=ScalarDef(T1), O=TensorDef(U, output=True)): +def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)): """Fills the output tensor with the given value. Works for arbitrary ranked output tensors since the operation performs scalar From 895e786c1d3fad37eb162dccc4512e2e98fa709b Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 14 Mar 2022 14:10:08 -0700 Subject: [PATCH 260/915] NFC: Remove unterminated string from Python pyi file. 
--- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index affe54c3e..8bd9822f4 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -168,7 +168,7 @@ class Attribute: @property def _CAPIPtr(self) -> object: ... @property - def context(self) -> "Context"": ... + def context(self) -> "Context": ... @property def type(self) -> "Type": ... From a9c9dde9291dd0cf115398db793f63f4647899ba Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Thu, 17 Mar 2022 00:17:53 +0100 Subject: [PATCH 261/915] [mlir] Add C API for ControlFlow dialect Add basic C API for the ControlFlow dialect. Follows the format of the other dialects. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D121867 --- mlir/include/mlir-c/Dialect/ControlFlow.h | 25 +++++++++++++++++++++++ mlir/lib/CAPI/Dialect/CMakeLists.txt | 9 ++++++++ mlir/lib/CAPI/Dialect/ControlFlow.cpp | 14 +++++++++++++ 3 files changed, 48 insertions(+) create mode 100644 mlir/include/mlir-c/Dialect/ControlFlow.h create mode 100644 mlir/lib/CAPI/Dialect/ControlFlow.cpp diff --git a/mlir/include/mlir-c/Dialect/ControlFlow.h b/mlir/include/mlir-c/Dialect/ControlFlow.h new file mode 100644 index 000000000..1ca7054d6 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/ControlFlow.h @@ -0,0 +1,25 @@ +//===-- mlir-c/Dialect/ControlFlow.h - C API for ControlFlow ------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_CONTROLFLOW_H +#define MLIR_C_DIALECT_CONTROLFLOW_H + +#include "mlir-c/Registration.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(ControlFlow, cf); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_CONTROLFLOW_H diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index 9822f059e..37bc121f2 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -13,6 +13,15 @@ add_mlir_upstream_c_api_library(MLIRCAPIAsync MLIRPass ) +add_mlir_upstream_c_api_library(MLIRCAPIControlFlow + ControlFlow.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRControlFlow +) + add_mlir_upstream_c_api_library(MLIRCAPIGPU GPU.cpp GPUPasses.cpp diff --git a/mlir/lib/CAPI/Dialect/ControlFlow.cpp b/mlir/lib/CAPI/Dialect/ControlFlow.cpp new file mode 100644 index 000000000..1e5b2de1c --- /dev/null +++ b/mlir/lib/CAPI/Dialect/ControlFlow.cpp @@ -0,0 +1,14 @@ +//===- ControlFlow.cpp - C Interface for ControlFlow dialect --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/ControlFlow.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(ControlFlow, cf, + mlir::cf::ControlFlowDialect) From 35e6f0b3f5296ccd4ebc50f0eafe8f232eac289d Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 7 Mar 2022 19:16:03 -0800 Subject: [PATCH 262/915] [mlir] Move the Builtin FuncOp to the Func dialect This commit moves FuncOp out of the builtin dialect, and into the Func dialect. This move has been planned in some capacity from the moment we made FuncOp an operation (years ago). This commit handles the functional aspects of the move, but various aspects are left untouched to ease migration: func::FuncOp is re-exported into mlir to reduce the actual API churn, the assembly format still accepts the unqualified `func`. These temporary measures will remain for a little while to simplify migration before being removed. 
Differential Revision: https://reviews.llvm.org/D121266 --- mlir/python/mlir/dialects/_builtin_ops_ext.py | 212 ----------------- mlir/python/mlir/dialects/_func_ops_ext.py | 213 +++++++++++++++++- 2 files changed, 210 insertions(+), 215 deletions(-) diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py index a3a147469..b69163fa4 100644 --- a/mlir/python/mlir/dialects/_builtin_ops_ext.py +++ b/mlir/python/mlir/dialects/_builtin_ops_ext.py @@ -3,17 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from typing import Optional, Sequence, Union - - import inspect - from ..ir import * except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e -ARGUMENT_ATTRIBUTE_NAME = "arg_attrs" -RESULT_ATTRIBUTE_NAME = "res_attrs" - class ModuleOp: """Specialization for the module op class.""" @@ -25,208 +18,3 @@ def __init__(self, *, loc=None, ip=None): @property def body(self): return self.regions[0].blocks[0] - - -class FuncOp: - """Specialization for the func op class.""" - - def __init__(self, - name, - type, - *, - visibility=None, - body_builder=None, - loc=None, - ip=None): - """ - Create a FuncOp with the provided `name`, `type`, and `visibility`. - - `name` is a string representing the function name. - - `type` is either a FunctionType or a pair of list describing inputs and - results. - - `visibility` is a string matching `public`, `private`, or `nested`. None - implies private visibility. - - `body_builder` is an optional callback, when provided a new entry block - is created and the callback is invoked with the new op as argument within - an InsertionPoint context already set for the block. The callback is - expected to insert a terminator in the block. - """ - sym_name = StringAttr.get(str(name)) - - # If the type is passed as a tuple, build a FunctionType on the fly. 
- if isinstance(type, tuple): - type = FunctionType.get(inputs=type[0], results=type[1]) - - type = TypeAttr.get(type) - sym_visibility = StringAttr.get( - str(visibility)) if visibility is not None else None - super().__init__(sym_name, type, sym_visibility, loc=loc, ip=ip) - if body_builder: - entry_block = self.add_entry_block() - with InsertionPoint(entry_block): - body_builder(self) - - @property - def is_external(self): - return len(self.regions[0].blocks) == 0 - - @property - def body(self): - return self.regions[0] - - @property - def type(self): - return FunctionType(TypeAttr(self.attributes["type"]).value) - - @property - def visibility(self): - return self.attributes["sym_visibility"] - - @property - def name(self) -> StringAttr: - return StringAttr(self.attributes["sym_name"]) - - @property - def entry_block(self): - if self.is_external: - raise IndexError('External function does not have a body') - return self.regions[0].blocks[0] - - def add_entry_block(self): - """ - Add an entry block to the function body using the function signature to - infer block arguments. 
- Returns the newly created block - """ - if not self.is_external: - raise IndexError('The function already has an entry block!') - self.body.blocks.append(*self.type.inputs) - return self.body.blocks[0] - - @property - def arg_attrs(self): - return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) - - @arg_attrs.setter - def arg_attrs(self, attribute: Union[ArrayAttr, list]): - if isinstance(attribute, ArrayAttr): - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute - else: - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( - attribute, context=self.context) - - @property - def arguments(self): - return self.entry_block.arguments - - @property - def result_attrs(self): - return self.attributes[RESULT_ATTRIBUTE_NAME] - - @result_attrs.setter - def result_attrs(self, attribute: ArrayAttr): - self.attributes[RESULT_ATTRIBUTE_NAME] = attribute - - @classmethod - def from_py_func(FuncOp, - *inputs: Type, - results: Optional[Sequence[Type]] = None, - name: Optional[str] = None): - """Decorator to define an MLIR FuncOp specified as a python function. - - Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are - active for the current thread (i.e. established in a `with` block). - - When applied as a decorator to a Python function, an entry block will - be constructed for the FuncOp with types as specified in `*inputs`. The - block arguments will be passed positionally to the Python function. In - addition, if the Python function accepts keyword arguments generally or - has a corresponding keyword argument, the following will be passed: - * `func_op`: The `func` op being defined. - - By default, the function name will be the Python function `__name__`. This - can be overriden by passing the `name` argument to the decorator. - - If `results` is not specified, then the decorator will implicitly - insert a `ReturnOp` with the `Value`'s returned from the decorated - function. It will also set the `FuncOp` type with the actual return - value types. 
If `results` is specified, then the decorated function - must return `None` and no implicit `ReturnOp` is added (nor are the result - types updated). The implicit behavior is intended for simple, single-block - cases, and users should specify result types explicitly for any complicated - cases. - - The decorated function can further be called from Python and will insert - a `CallOp` at the then-current insertion point, returning either None ( - if no return values), a unary Value (for one result), or a list of Values). - This mechanism cannot be used to emit recursive calls (by construction). - """ - - def decorator(f): - from . import func - # Introspect the callable for optional features. - sig = inspect.signature(f) - has_arg_func_op = False - for param in sig.parameters.values(): - if param.kind == param.VAR_KEYWORD: - has_arg_func_op = True - if param.name == "func_op" and (param.kind - == param.POSITIONAL_OR_KEYWORD or - param.kind == param.KEYWORD_ONLY): - has_arg_func_op = True - - # Emit the FuncOp. - implicit_return = results is None - symbol_name = name or f.__name__ - function_type = FunctionType.get( - inputs=inputs, results=[] if implicit_return else results) - func_op = FuncOp(name=symbol_name, type=function_type) - with InsertionPoint(func_op.add_entry_block()): - func_args = func_op.entry_block.arguments - func_kwargs = {} - if has_arg_func_op: - func_kwargs["func_op"] = func_op - return_values = f(*func_args, **func_kwargs) - if not implicit_return: - return_types = list(results) - assert return_values is None, ( - "Capturing a python function with explicit `results=` " - "requires that the wrapped function returns None.") - else: - # Coerce return values, add ReturnOp and rewrite func type. - if return_values is None: - return_values = [] - elif isinstance(return_values, tuple): - return_values = list(return_values) - elif isinstance(return_values, Value): - # Returning a single value is fine, coerce it into a list. 
- return_values = [return_values] - elif isinstance(return_values, OpView): - # Returning a single operation is fine, coerce its results a list. - return_values = return_values.operation.results - elif isinstance(return_values, Operation): - # Returning a single operation is fine, coerce its results a list. - return_values = return_values.results - else: - return_values = list(return_values) - func.ReturnOp(return_values) - # Recompute the function type. - return_types = [v.type for v in return_values] - function_type = FunctionType.get(inputs=inputs, results=return_types) - func_op.attributes["type"] = TypeAttr.get(function_type) - - def emit_call_op(*call_args): - call_op = func.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name), - call_args) - if return_types is None: - return None - elif len(return_types) == 1: - return call_op.result - else: - return call_op.results - - wrapped = emit_call_op - wrapped.__name__ = f.__name__ - wrapped.func_op = func_op - return wrapped - - return decorator diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py index 850562673..6932efd79 100644 --- a/mlir/python/mlir/dialects/_func_ops_ext.py +++ b/mlir/python/mlir/dialects/_func_ops_ext.py @@ -4,13 +4,16 @@ try: from ..ir import * - from .builtin import FuncOp from ._ods_common import get_default_loc_context as _get_default_loc_context - from typing import Any, List, Optional, Union + import inspect + + from typing import Any, List, Optional, Sequence, Union except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e +ARGUMENT_ATTRIBUTE_NAME = "arg_attrs" +RESULT_ATTRIBUTE_NAME = "res_attrs" class ConstantOp: """Specialization for the constant op class.""" @@ -23,6 +26,210 @@ def type(self): return self.results[0].type +class FuncOp: + """Specialization for the func op class.""" + + def __init__(self, + name, + type, + *, + visibility=None, + body_builder=None, + loc=None, + ip=None): + """ 
+ Create a FuncOp with the provided `name`, `type`, and `visibility`. + - `name` is a string representing the function name. + - `type` is either a FunctionType or a pair of list describing inputs and + results. + - `visibility` is a string matching `public`, `private`, or `nested`. None + implies private visibility. + - `body_builder` is an optional callback, when provided a new entry block + is created and the callback is invoked with the new op as argument within + an InsertionPoint context already set for the block. The callback is + expected to insert a terminator in the block. + """ + sym_name = StringAttr.get(str(name)) + + # If the type is passed as a tuple, build a FunctionType on the fly. + if isinstance(type, tuple): + type = FunctionType.get(inputs=type[0], results=type[1]) + + type = TypeAttr.get(type) + sym_visibility = StringAttr.get( + str(visibility)) if visibility is not None else None + super().__init__(sym_name, type, sym_visibility, loc=loc, ip=ip) + if body_builder: + entry_block = self.add_entry_block() + with InsertionPoint(entry_block): + body_builder(self) + + @property + def is_external(self): + return len(self.regions[0].blocks) == 0 + + @property + def body(self): + return self.regions[0] + + @property + def type(self): + return FunctionType(TypeAttr(self.attributes["type"]).value) + + @property + def visibility(self): + return self.attributes["sym_visibility"] + + @property + def name(self) -> StringAttr: + return StringAttr(self.attributes["sym_name"]) + + @property + def entry_block(self): + if self.is_external: + raise IndexError('External function does not have a body') + return self.regions[0].blocks[0] + + def add_entry_block(self): + """ + Add an entry block to the function body using the function signature to + infer block arguments. 
+ Returns the newly created block + """ + if not self.is_external: + raise IndexError('The function already has an entry block!') + self.body.blocks.append(*self.type.inputs) + return self.body.blocks[0] + + @property + def arg_attrs(self): + return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) + + @arg_attrs.setter + def arg_attrs(self, attribute: Union[ArrayAttr, list]): + if isinstance(attribute, ArrayAttr): + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute + else: + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( + attribute, context=self.context) + + @property + def arguments(self): + return self.entry_block.arguments + + @property + def result_attrs(self): + return self.attributes[RESULT_ATTRIBUTE_NAME] + + @result_attrs.setter + def result_attrs(self, attribute: ArrayAttr): + self.attributes[RESULT_ATTRIBUTE_NAME] = attribute + + @classmethod + def from_py_func(FuncOp, + *inputs: Type, + results: Optional[Sequence[Type]] = None, + name: Optional[str] = None): + """Decorator to define an MLIR FuncOp specified as a python function. + + Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are + active for the current thread (i.e. established in a `with` block). + + When applied as a decorator to a Python function, an entry block will + be constructed for the FuncOp with types as specified in `*inputs`. The + block arguments will be passed positionally to the Python function. In + addition, if the Python function accepts keyword arguments generally or + has a corresponding keyword argument, the following will be passed: + * `func_op`: The `func` op being defined. + + By default, the function name will be the Python function `__name__`. This + can be overriden by passing the `name` argument to the decorator. + + If `results` is not specified, then the decorator will implicitly + insert a `ReturnOp` with the `Value`'s returned from the decorated + function. It will also set the `FuncOp` type with the actual return + value types. 
If `results` is specified, then the decorated function + must return `None` and no implicit `ReturnOp` is added (nor are the result + types updated). The implicit behavior is intended for simple, single-block + cases, and users should specify result types explicitly for any complicated + cases. + + The decorated function can further be called from Python and will insert + a `CallOp` at the then-current insertion point, returning either None ( + if no return values), a unary Value (for one result), or a list of Values). + This mechanism cannot be used to emit recursive calls (by construction). + """ + + def decorator(f): + from . import func + # Introspect the callable for optional features. + sig = inspect.signature(f) + has_arg_func_op = False + for param in sig.parameters.values(): + if param.kind == param.VAR_KEYWORD: + has_arg_func_op = True + if param.name == "func_op" and (param.kind + == param.POSITIONAL_OR_KEYWORD or + param.kind == param.KEYWORD_ONLY): + has_arg_func_op = True + + # Emit the FuncOp. + implicit_return = results is None + symbol_name = name or f.__name__ + function_type = FunctionType.get( + inputs=inputs, results=[] if implicit_return else results) + func_op = FuncOp(name=symbol_name, type=function_type) + with InsertionPoint(func_op.add_entry_block()): + func_args = func_op.entry_block.arguments + func_kwargs = {} + if has_arg_func_op: + func_kwargs["func_op"] = func_op + return_values = f(*func_args, **func_kwargs) + if not implicit_return: + return_types = list(results) + assert return_values is None, ( + "Capturing a python function with explicit `results=` " + "requires that the wrapped function returns None.") + else: + # Coerce return values, add ReturnOp and rewrite func type. + if return_values is None: + return_values = [] + elif isinstance(return_values, tuple): + return_values = list(return_values) + elif isinstance(return_values, Value): + # Returning a single value is fine, coerce it into a list. 
+ return_values = [return_values] + elif isinstance(return_values, OpView): + # Returning a single operation is fine, coerce its results a list. + return_values = return_values.operation.results + elif isinstance(return_values, Operation): + # Returning a single operation is fine, coerce its results a list. + return_values = return_values.results + else: + return_values = list(return_values) + func.ReturnOp(return_values) + # Recompute the function type. + return_types = [v.type for v in return_values] + function_type = FunctionType.get(inputs=inputs, results=return_types) + func_op.attributes["type"] = TypeAttr.get(function_type) + + def emit_call_op(*call_args): + call_op = func.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name), + call_args) + if return_types is None: + return None + elif len(return_types) == 1: + return call_op.result + else: + return call_op.results + + wrapped = emit_call_op + wrapped.__name__ = f.__name__ + wrapped.func_op = func_op + return wrapped + + return decorator + class CallOp: """Specialization for the call op class.""" @@ -45,7 +252,7 @@ def __init__(self, For example - f = builtin.FuncOp("foo", ...) + f = func.FuncOp("foo", ...) func.CallOp(f, [args]) func.CallOp([result_types], "foo", [args]) From 439b44b1f34c78022e694a50536fba40d4a598c0 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 15 Mar 2022 17:36:15 -0700 Subject: [PATCH 263/915] [mlir:FunctionOpInterface] Rename the "type" attribute to "function_type" This removes any potential confusion with the `getType` accessors which correspond to SSA results of an operation, and makes it clear what the intent is (i.e. to represent the type of the function). 
Differential Revision: https://reviews.llvm.org/D121762 --- mlir/python/mlir/dialects/_func_ops_ext.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py index 6932efd79..6fe3ff530 100644 --- a/mlir/python/mlir/dialects/_func_ops_ext.py +++ b/mlir/python/mlir/dialects/_func_ops_ext.py @@ -74,7 +74,7 @@ def body(self): @property def type(self): - return FunctionType(TypeAttr(self.attributes["type"]).value) + return FunctionType(TypeAttr(self.attributes["function_type"]).value) @property def visibility(self): @@ -211,7 +211,7 @@ def decorator(f): # Recompute the function type. return_types = [v.type for v in return_values] function_type = FunctionType.get(inputs=inputs, results=return_types) - func_op.attributes["type"] = TypeAttr.get(function_type) + func_op.attributes["function_type"] = TypeAttr.get(function_type) def emit_call_op(*call_args): call_op = func.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name), From 1ae3cfcb0bf6851970fbda4216dbb3f2cc07d4a0 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Sun, 13 Mar 2022 22:09:20 -0700 Subject: [PATCH 264/915] [mlir:PDL] Remove the ConstantParams support from native Constraints/Rewrites This support has never really worked well, and is incredibly clunky to use (it effectively creates two argument APIs), and clunky to generate (it isn't clear how we should actually expose this from PDL frontends). Treating these as just attribute arguments is much much cleaner in every aspect of the stack. If we need to optimize lots of constant parameters, it would be better to investigate internal representation optimizations (e.g. batch attribute creation), that do not affect the user (we want a clean external API). 
Differential Revision: https://reviews.llvm.org/D121569 --- mlir/python/mlir/dialects/_pdl_ops_ext.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py index 364db5385..fb5b519c7 100644 --- a/mlir/python/mlir/dialects/_pdl_ops_ext.py +++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py @@ -59,14 +59,12 @@ class ApplyNativeConstraintOp: def __init__(self, name: Union[str, StringAttr], args: Sequence[Union[OpView, Operation, Value]] = [], - params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, *, loc=None, ip=None): name = _get_str_attr(name) args = _get_values(args) - params = params if params is None else _get_array_attr(params) - super().__init__(name, args, params, loc=loc, ip=ip) + super().__init__(name, args, loc=loc, ip=ip) class ApplyNativeRewriteOp: @@ -76,14 +74,12 @@ def __init__(self, results: Sequence[Type], name: Union[str, StringAttr], args: Sequence[Union[OpView, Operation, Value]] = [], - params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, *, loc=None, ip=None): name = _get_str_attr(name) args = _get_values(args) - params = params if params is None else _get_array_attr(params) - super().__init__(results, name, args, params, loc=loc, ip=ip) + super().__init__(results, name, args, loc=loc, ip=ip) class AttributeOp: @@ -236,15 +232,13 @@ def __init__(self, root: Optional[Union[OpView, Operation, Value]] = None, name: Optional[Union[StringAttr, str]] = None, args: Sequence[Union[OpView, Operation, Value]] = [], - params: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, *, loc=None, ip=None): root = root if root is None else _get_value(root) name = name if name is None else _get_str_attr(name) args = _get_values(args) - params = params if params is None else _get_array_attr(params) - super().__init__(root, name, args, params, loc=loc, ip=ip) + super().__init__(root, name, args, loc=loc, ip=ip) def add_body(self): 
     """Add body (block) to the rewrite."""

From d79083b931cf1be3757f8a2687fa4070446ad84c Mon Sep 17 00:00:00 2001
From: Sergei Lebedev
Date: Fri, 25 Mar 2022 13:38:13 +0100
Subject: [PATCH 265/915] Updated MLIR type stubs to work with pytype

The diff is big, but there are in fact only three kinds of changes

* ir.py had a syntax error -- unterminated [
* forward references are unnecessary in .pyi files (see
  https://github.com/python/typeshed/blob/9a76b13127ffa8365431dcc105fc111cdd267e7e/CONTRIBUTING.md?plain=1#L450-L454)
* methods defined via .def_static() are now decorated with @staticmethod

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D122300
---
 mlir/python/mlir/_mlir_libs/_mlir/ir.pyi      | 531 ++++++++++--------
 .../mlir/_mlir_libs/_mlir/passmanager.pyi     |   2 +-
 2 files changed, 312 insertions(+), 221 deletions(-)

diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
index 8bd9822f4..5bfb9202e 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi
@@ -95,31 +95,31 @@ __all__ = [

 # Base classes: declared first to simplify declarations below.
 class _OperationBase:
-    def detach_from_parent(self) -> "OpView": ...
+    def detach_from_parent(self) -> OpView: ...
     def get_asm(self,
                 binary: bool = False,
                 large_elements_limit: Optional[int] = None,
                 enable_debug_info: bool = False,
                 pretty_debug_info: bool = False,
                 print_generic_op_form: bool = False,
                 use_local_scope: bool = False,
                 assume_verified: bool = False) -> object: ...
-    def move_after(self, other: "_OperationBase") -> None: ...
-    def move_before(self, other: "_OperationBase") -> None: ...
+    def move_after(self, other: _OperationBase) -> None: ...
+    def move_before(self, other: _OperationBase) -> None: ...
def print(self, file: Optional[Any] = None, binary: bool = False, large_elements_limit: Optional[int] = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, use_local_scope: bool = False, assume_verified: bool = False) -> None: ... def verify(self) -> bool: ... @overload - def __eq__(self, arg0: "_OperationBase") -> bool: ... + def __eq__(self, arg0: _OperationBase) -> bool: ... @overload def __eq__(self, arg0: object) -> bool: ... def __hash__(self) -> int: ... @property def _CAPIPtr(self) -> object: ... @property - def attributes(self) -> "OpAttributeMap": ... + def attributes(self) -> OpAttributeMap: ... @property - def location(self) -> "Location": ... + def location(self) -> Location: ... @property - def operands(self) -> "OpOperandList": ... + def operands(self) -> OpOperandList: ... @property - def regions(self) -> "RegionSequence": ... + def regions(self) -> RegionSequence: ... @property - def result(self) -> "OpResult": ... + def result(self) -> OpResult: ... @property - def results(self) -> "OpResultList": ... + def results(self) -> OpResultList: ... # TODO: Auto-generated. Audit and fix. class AffineExpr: @@ -154,69 +154,72 @@ class AffineExpr: def context(self) -> object: ... class Attribute: - def __init__(self, cast_from_type: "Attribute") -> None: ... - def _CAPICreate(self) -> "Attribute": ... + def __init__(self, cast_from_type: Attribute) -> None: ... + def _CAPICreate(self) -> Attribute: ... def dump(self) -> None: ... def get_named(self, *args, **kwargs) -> Any: ... @staticmethod - def parse(asm: str, context: Optional["Context"] = None) -> Any: ... + def parse(asm: str, context: Optional[Context] = None) -> Any: ... @overload - def __eq__(self, arg0: "Attribute") -> bool: ... + def __eq__(self, arg0: Attribute) -> bool: ... @overload def __eq__(self, arg0: object) -> bool: ... def __hash__(self) -> int: ... @property def _CAPIPtr(self) -> object: ... 
@property - def context(self) -> "Context": ... + def context(self) -> Context: ... @property - def type(self) -> "Type": ... + def type(self) -> Type: ... class Type: - def __init__(self, cast_from_type: "Type") -> None: ... + def __init__(self, cast_from_type: Type) -> None: ... def _CAPICreate(self) -> Type: ... def dump(self) -> None: ... @staticmethod - def parse(asm: str, context: Optional["Context"] = None) -> "Type": ... + def parse(asm: str, context: Optional[Context] = None) -> Type: ... @overload - def __eq__(self, arg0: "Type") -> bool: ... + def __eq__(self, arg0: Type) -> bool: ... @overload def __eq__(self, arg0: object) -> bool: ... def __hash__(self) -> int: ... @property def _CAPIPtr(self) -> object: ... @property - def context(self) -> "Context": ... + def context(self) -> Context: ... class Value: - def _CAPICreate(self) -> "Value": ... + def _CAPICreate(self) -> Value: ... def dump(self) -> None: ... @overload - def __eq__(self, arg0: "Value") -> bool: ... + def __eq__(self, arg0: Value) -> bool: ... @overload def __eq__(self, arg0: object) -> bool: ... def __hash__(self) -> int: ... @property def _CAPIPtr(self) -> object: ... @property - def context(self) -> "Context": ... + def context(self) -> Context: ... @property - def owner(self) -> "_OperationBase": ... + def owner(self) -> _OperationBase: ... @property - def type(self) -> "Type": ... + def type(self) -> Type: ... # Classes with no particular order sensitivity in alpha order. # TODO: Auto-generated. Audit and fix. class AffineAddExpr(AffineBinaryExpr): def __init__(self, expr: AffineExpr) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> AffineAddExpr: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... # TODO: Auto-generated. Audit and fix. class AffineBinaryExpr(AffineExpr): def __init__(self, expr: AffineExpr) -> None: ... 
- def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... @property def lhs(self) -> AffineExpr: ... @property @@ -225,22 +228,28 @@ class AffineBinaryExpr(AffineExpr): # TODO: Auto-generated. Audit and fix. class AffineCeilDivExpr(AffineBinaryExpr): def __init__(self, expr: AffineExpr) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> AffineCeilDivExpr: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... # TODO: Auto-generated. Audit and fix. class AffineConstantExpr(AffineExpr): def __init__(self, expr: AffineExpr) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> AffineConstantExpr: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... @property def value(self) -> int: ... # TODO: Auto-generated. Audit and fix. class AffineDimExpr(AffineExpr): def __init__(self, expr: AffineExpr) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> AffineDimExpr: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... @property def position(self) -> int: ... @@ -257,22 +266,29 @@ class AffineExprList: # TODO: Auto-generated. Audit and fix. class AffineFloorDivExpr(AffineBinaryExpr): def __init__(self, expr: AffineExpr) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + def get(*args, **kwargs) -> AffineFloorDivExpr: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... # TODO: Auto-generated. Audit and fix. class AffineMap: def __init__(self, *args, **kwargs) -> None: ... def _CAPICreate(self) -> AffineMap: ... - def compress_unused_symbols(self, *args, **kwargs) -> Any: ... 
+ @staticmethod + def compress_unused_symbols(*args, **kwargs) -> Any: ... def dump(self) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def get_constant(self, *args, **kwargs) -> Any: ... - def get_empty(self, *args, **kwargs) -> Any: ... - def get_identity(self, *args, **kwargs) -> Any: ... - def get_major_submap(self, n_results: int) -> AffineMap: ... - def get_minor_identity(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> AffineMap: ... + @staticmethod + def get_constant(*args, **kwargs) -> AffineMap: ... + @staticmethod + def get_empty(*args, **kwargs) -> AffineMap: ... + @staticmethod + def get_identity(*args, **kwargs) -> AffineMap: ... + @staticmethod + def get_minor_identity(*args, **kwargs) -> AffineMap: ... def get_minor_submap(self, n_results: int) -> AffineMap: ... + def get_major_submap(self, n_results: int) -> AffineMap: ... def get_permutation(self, *args, **kwargs) -> Any: ... def get_submap(self, result_positions: List[int]) -> AffineMap: ... def replace(self, expr: AffineExpr, replacement: AffineExpr, n_result_dims: int, n_result_syms: int) -> AffineMap: ... @@ -301,36 +317,46 @@ class AffineMap: # TODO: Auto-generated. Audit and fix. class AffineMapAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> AffineMapAttr: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... @property def type(self) -> Type: ... # TODO: Auto-generated. Audit and fix. class AffineModExpr(AffineBinaryExpr): def __init__(self, expr: AffineExpr) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> AffineModExpr: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... # TODO: Auto-generated. Audit and fix. 
class AffineMulExpr(AffineBinaryExpr): def __init__(self, expr: AffineExpr) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> AffineMulExpr: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... # TODO: Auto-generated. Audit and fix. class AffineSymbolExpr(AffineExpr): def __init__(self, expr: AffineExpr) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> AffineSymbolExpr: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... @property def position(self) -> int: ... # TODO: Auto-generated. Audit and fix. class ArrayAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> ArrayAttr: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... def __add__(self, arg0: list) -> ArrayAttr: ... def __getitem__(self, arg0: int) -> Attribute: ... def __iter__(self) -> Any: ... @@ -347,64 +373,69 @@ class ArrayAttributeIterator: # TODO: Auto-generated. Audit and fix. class BF16Type(Type): def __init__(self, cast_from_type: Type) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> BF16Type: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... class Block: __hash__: ClassVar[None] = ... def append(self, operation: _OperationBase) -> None: ... - def create_after(self, *args: "Type") -> "Block": ... + def create_after(self, *args: Type) -> Block: ... @staticmethod - def create_at_start(parent: "Region", arg_types: List["Type"]) -> "Block": ... - def create_before(self, *args: "Type") -> "Block": ... + def create_at_start(parent: Region, arg_types: List[Type]) -> Block: ... 
+ def create_before(self, *args: Type) -> Block: ... @overload - def __eq__(self, arg0: "Block") -> bool: ... + def __eq__(self, arg0: Block) -> bool: ... @overload def __eq__(self, arg0: object) -> bool: ... def __iter__(self) -> Any: ... @property - def arguments(self) -> "BlockArgumentList": ... + def arguments(self) -> BlockArgumentList: ... @property - def operations(self) -> "OperationList": ... + def operations(self) -> OperationList: ... @property - def owner(self) -> "OpView": ... + def owner(self) -> OpView: ... @property - def region(self) -> "Region": ... + def region(self) -> Region: ... class BlockArgument(Value): - def isinstance(self, *args, **kwargs) -> Any: ... - def set_type(self, type: "Type") -> None: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... + def set_type(self, type: Type) -> None: ... @property def arg_number(self) -> int: ... @property - def owner(self) -> "Block": ... + def owner(self) -> Block: ... class BlockArgumentList: - def __add__(self, arg0: "BlockArgumentList") -> List["BlockArgument"]: ... + def __add__(self, arg0: BlockArgumentList) -> List[BlockArgument]: ... @overload - def __getitem__(self, arg0: int) -> "BlockArgument": ... + def __getitem__(self, arg0: int) -> BlockArgument: ... @overload - def __getitem__(self, arg0: slice) -> "BlockArgumentList": ... + def __getitem__(self, arg0: slice) -> BlockArgumentList: ... def __len__(self) -> int: ... @property - def types(self) -> List["Type"]: ... + def types(self) -> List[Type]: ... class BlockIterator: def __init__(self, *args, **kwargs) -> None: ... - def __iter__(self) -> "BlockIterator": ... - def __next__(self) -> "Block": ... + def __iter__(self) -> BlockIterator: ... + def __next__(self) -> Block: ... class BlockList: - def append(self, *args) -> "Block": ... - def __getitem__(self, arg0: int) -> "Block": ... - def __iter__(self) -> "BlockIterator": ... + def append(self, *args) -> Block: ... + def __getitem__(self, arg0: int) -> Block: ... 
+ def __iter__(self) -> BlockIterator: ... def __len__(self) -> int: ... # TODO: Auto-generated. Audit and fix. class BoolAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> BoolAttr: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... @property def type(self) -> Type: ... @property @@ -413,40 +444,45 @@ class BoolAttr(Attribute): # TODO: Auto-generated. Audit and fix. class ComplexType(Type): def __init__(self, cast_from_type: Type) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> ComplexType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... @property def element_type(self) -> Type: ... class Context: - current: ClassVar["Context"] = ... # read-only + current: ClassVar[Context] = ... # read-only allow_unregistered_dialects: bool def __init__(self) -> None: ... def _CAPICreate(self) -> object: ... - def _get_context_again(self) -> "Context": ... + def _get_context_again(self) -> Context: ... @staticmethod def _get_live_count() -> int: ... def _get_live_module_count(self) -> int: ... def _get_live_operation_count(self) -> int: ... - def attach_diagnostic_handler(self, callback: Callable[["Diagnostic"], bool]) -> "DiagnosticHandler": ... + def attach_diagnostic_handler(self, callback: Callable[[Diagnostic], bool]) -> DiagnosticHandler: ... def enable_multithreading(self, enable: bool) -> None: ... - def get_dialect_descriptor(dialect_name: str) -> "DialectDescriptor": ... + def get_dialect_descriptor(dialect_name: str) -> DialectDescriptor: ... def is_registered_operation(self, operation_name: str) -> bool: ... - def __enter__(self) -> "Context": ... + def __enter__(self) -> Context: ... def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... 
@property def _CAPIPtr(self) -> object: ... @property - def d(self) -> "Dialects": ... + def d(self) -> Dialects: ... @property - def dialects(self) -> "Dialects": ... + def dialects(self) -> Dialects: ... # TODO: Auto-generated. Audit and fix. class DenseElementsAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def get_splat(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> DenseElementsAttr: ... + @staticmethod + def get_splat(*args, **kwargs) -> Any: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... def __len__(self) -> int: ... @property def is_splat(self) -> bool: ... @@ -456,7 +492,10 @@ class DenseElementsAttr(Attribute): # TODO: Auto-generated. Audit and fix. class DenseFPElementsAttr(DenseElementsAttr): def __init__(self, cast_from_attr: Attribute) -> None: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> DenseFPElementsAttr: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... def __getitem__(self, arg0: int) -> float: ... @property def type(self) -> Type: ... @@ -464,15 +503,18 @@ class DenseFPElementsAttr(DenseElementsAttr): # TODO: Auto-generated. Audit and fix. class DenseIntElementsAttr(DenseElementsAttr): def __init__(self, cast_from_attr: Attribute) -> None: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> DenseIntElementsAttr: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... def __getitem__(self, arg0: int) -> int: ... @property def type(self) -> Type: ... class Dialect: - def __init__(self, descriptor: "DialectDescriptor") -> None: ... + def __init__(self, descriptor: DialectDescriptor) -> None: ... @property - def descriptor(self) -> "DialectDescriptor": ... + def descriptor(self) -> DialectDescriptor: ... 
class DialectDescriptor: @property @@ -480,18 +522,18 @@ class DialectDescriptor: class Dialects: def __init__(self, *args, **kwargs) -> None: ... - def __getattr__(self, arg0: str) -> "Dialect": ... - def __getitem__(self, arg0: str) -> "Dialect": ... + def __getattr__(self, arg0: str) -> Dialect: ... + def __getitem__(self, arg0: str) -> Dialect: ... class Diagnostic: @property - def severity(self) -> "DiagnosticSeverity": ... + def severity(self) -> DiagnosticSeverity: ... @property - def location(self) -> "Location": ... + def location(self) -> Location: ... @property def message(self) -> str: ... @property - def notes(self) -> Tuple["Diagnostic"]: ... + def notes(self) -> Tuple[Diagnostic]: ... class DiagnosticHandler: def detach(self) -> None: ... @@ -499,20 +541,22 @@ class DiagnosticHandler: def attached(self) -> bool: ... @property def had_error(self) -> bool: ... - def __enter__(self) -> "DiagnosticHandler": ... + def __enter__(self) -> DiagnosticHandler: ... def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... class DiagnosticSeverity: - ERROR: "DiagnosticSeverity" - WARNING: "DiagnosticSeverity" - NOTE: "DiagnosticSeverity" - REMARK: "DiagnosticSeverity" + ERROR: DiagnosticSeverity + WARNING: DiagnosticSeverity + NOTE: DiagnosticSeverity + REMARK: DiagnosticSeverity # TODO: Auto-generated. Audit and fix. class DictAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> DictAttr: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... def __contains__(self, arg0: str) -> bool: ... @overload def __getitem__(self, arg0: str) -> Attribute: ... @@ -525,26 +569,34 @@ class DictAttr(Attribute): # TODO: Auto-generated. Audit and fix. class F16Type(Type): def __init__(self, cast_from_type: Type) -> None: ... - def get(self, *args, **kwargs) -> Any: ... 
- def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> F16Type: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... # TODO: Auto-generated. Audit and fix. class F32Type(Type): def __init__(self, cast_from_type: Type) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> F32Type: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... # TODO: Auto-generated. Audit and fix. class F64Type(Type): def __init__(self, cast_from_type: Type) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> F64Type: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... # TODO: Auto-generated. Audit and fix. class FlatSymbolRefAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> FlatSymbolRefAttr: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... @property def type(self) -> Type: ... @property @@ -553,10 +605,14 @@ class FlatSymbolRefAttr(Attribute): # TODO: Auto-generated. Audit and fix. class FloatAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def get_f32(self, *args, **kwargs) -> Any: ... - def get_f64(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> FloatAttr: ... + @staticmethod + def get_f32(*args, **kwargs) -> FloatAttr: ... + @staticmethod + def get_f64(*args, **kwargs) -> FloatAttr: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... @property def type(self) -> Type: ... @property @@ -565,8 +621,10 @@ class FloatAttr(Attribute): # TODO: Auto-generated. Audit and fix. 
class FunctionType(Type): def __init__(self, cast_from_type: Type) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> FunctionType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... @property def inputs(self) -> list: ... @property @@ -575,38 +633,42 @@ class FunctionType(Type): # TODO: Auto-generated. Audit and fix. class IndexType(Type): def __init__(self, cast_from_type: Type) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> IndexType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... class InferTypeOpInterface: - def __init__(self, object: object, context: Optional["Context"] = None) -> None: ... - def inferReturnTypes(self, operands: Optional[List["Value"]] = None, attributes: Optional["Attribute"] = None, regions: Optional[List["Region"]] = None, context: Optional["Context"] = None, loc: Optional["Location"] = None) -> List[Type]: ... + def __init__(self, object: object, context: Optional[Context] = None) -> None: ... + def inferReturnTypes(self, operands: Optional[List[Value]] = None, attributes: Optional[Attribute] = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ... @property - def operation(self) -> "_OperationBase": ... + def operation(self) -> Operation: ... @property - def opview(self) -> "OpView": ... + def opview(self) -> OpView: ... class InsertionPoint: - current: ClassVar["InsertionPoint"] = ... # read-only + current: ClassVar[InsertionPoint] = ... # read-only @overload - def __init__(self, block: "Block") -> None: ... + def __init__(self, block: Block) -> None: ... @overload - def __init__(self, beforeOperation: "_OperationBase") -> None: ... + def __init__(self, beforeOperation: _OperationBase) -> None: ... 
@staticmethod - def at_block_begin(block: "Block") -> "InsertionPoint": ... + def at_block_begin(block: Block) -> InsertionPoint: ... @staticmethod - def at_block_terminator(block: "Block") -> "InsertionPoint": ... - def insert(self, operation: "_OperationBase") -> None: ... - def __enter__(self) -> "InsertionPoint": ... + def at_block_terminator(block: Block) -> InsertionPoint: ... + def insert(self, operation: _OperationBase) -> None: ... + def __enter__(self) -> InsertionPoint: ... def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... @property - def block(self) -> "Block": ... + def block(self) -> Block: ... # TODO: Auto-generated. Audit and fix. class IntegerAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> IntegerAttr: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... @property def type(self) -> Type: ... @property @@ -617,8 +679,10 @@ class IntegerSet: def __init__(self, *args, **kwargs) -> None: ... def _CAPICreate(self) -> IntegerSet: ... def dump(self) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def get_empty(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> IntegerSet: ... + @staticmethod + def get_empty(*args, **kwargs) -> IntegerSet: ... def get_replaced(self, dim_exprs: list, symbol_exprs: list, num_result_dims: int, num_result_symbols: int) -> IntegerSet: ... @overload def __eq__(self, arg0: IntegerSet) -> bool: ... @@ -665,10 +729,14 @@ class IntegerSetConstraintList: # TODO: Auto-generated. Audit and fix. class IntegerType(Type): def __init__(self, cast_from_type: Type) -> None: ... - def get_signed(self, *args, **kwargs) -> Any: ... - def get_signless(self, *args, **kwargs) -> Any: ... - def get_unsigned(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... 
+ @staticmethod + def get_signed(*args, **kwargs) -> IntegerType: ... + @staticmethod + def get_signless(*args, **kwargs) -> IntegerType: ... + @staticmethod + def get_unsigned(*args, **kwargs) -> IntegerType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... @property def is_signed(self) -> bool: ... @property @@ -679,35 +747,37 @@ class IntegerType(Type): def width(self) -> int: ... class Location: - current: ClassVar["Location"] = ... # read-only + current: ClassVar[Location] = ... # read-only __hash__: ClassVar[None] = ... - def _CAPICreate(self) -> "Location": ... + def _CAPICreate(self) -> Location: ... @staticmethod - def callsite(callee: "Location", frames: Sequence["Location"], context: Optional["Context"] = None) -> "Location": ... + def callsite(callee: Location, frames: Sequence[Location], context: Optional[Context] = None) -> Location: ... @staticmethod - def file(filename: str, line: int, col: int, context: Optional["Context"] = None) -> "Location": ... + def file(filename: str, line: int, col: int, context: Optional[Context] = None) -> Location: ... @staticmethod - def fused(locations: Sequence["Location"], metadata: Optional["Attribute"] = None, context: Optional["Context"] = None) -> "Location": ... + def fused(locations: Sequence[Location], metadata: Optional[Attribute] = None, context: Optional[Context] = None) -> Location: ... @staticmethod - def name(name: str, childLoc: Optional["Location"] = None, context: Optional["Context"] = None) -> "Location": ... + def name(name: str, childLoc: Optional[Location] = None, context: Optional[Context] = None) -> Location: ... @staticmethod - def unknown(context: Optional["Context"] = None) -> Any: ... - def __enter__(self) -> "Location": ... + def unknown(context: Optional[Context] = None) -> Any: ... + def __enter__(self) -> Location: ... @overload - def __eq__(self, arg0: "Location") -> bool: ... + def __eq__(self, arg0: Location) -> bool: ... 
@overload def __eq__(self, arg0: object) -> bool: ... def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... @property def _CAPIPtr(self) -> object: ... @property - def context(self) -> "Context": ... + def context(self) -> Context: ... # TODO: Auto-generated. Audit and fix. class MemRefType(ShapedType): def __init__(self, cast_from_type: Type) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> MemRefType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... @property def affine_map(self) -> AffineMap: ... @property @@ -717,143 +787,148 @@ class MemRefType(ShapedType): class Module: def _CAPICreate(self) -> object: ... - def create(loc: Optional["Location"] = None) -> "Module": ... + def create(loc: Optional[Location] = None) -> Module: ... def dump(self) -> None: ... @staticmethod - def parse(asm: str, context: Optional[Context] = None) -> "Module": ... + def parse(asm: str, context: Optional[Context] = None) -> Module: ... @property def _CAPIPtr(self) -> object: ... @property - def body(self) -> "Block": ... + def body(self) -> Block: ... @property def context(self) -> object: ... @property - def operation(self) -> "_OperationBase": ... + def operation(self) -> Operation: ... class NamedAttribute: @property - def attr(self) -> "Attribute": ... + def attr(self) -> Attribute: ... @property def name(self) -> str: ... # TODO: Auto-generated. Audit and fix. class NoneType(Type): def __init__(self, cast_from_type: Type) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> NoneType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... class OpAttributeMap: def __contains__(self, arg0: str) -> bool: ... def __delitem__(self, arg0: str) -> None: ... @overload - def __getitem__(self, arg0: str) -> "Attribute": ... 
+ def __getitem__(self, arg0: str) -> Attribute: ... @overload - def __getitem__(self, arg0: int) -> "NamedAttribute": ... + def __getitem__(self, arg0: int) -> NamedAttribute: ... def __len__(self) -> int: ... - def __setitem__(self, arg0: str, arg1: "Attribute") -> None: ... + def __setitem__(self, arg0: str, arg1: Attribute) -> None: ... class OpOperandList: - def __add__(self, arg0: "OpOperandList") -> List[Value]: ... + def __add__(self, arg0: OpOperandList) -> List[Value]: ... @overload def __getitem__(self, arg0: int) -> Value: ... @overload - def __getitem__(self, arg0: slice) -> "OpOperandList": ... + def __getitem__(self, arg0: slice) -> OpOperandList: ... def __len__(self) -> int: ... def __setitem__(self, arg0: int, arg1: Value) -> None: ... class OpResult(Value): def __init__(self, value: Value) -> None: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... @property - def owner(self) -> "_OperationBase": ... + def owner(self) -> _OperationBase: ... @property def result_number(self) -> int: ... class OpResultList: - def __add__(self, arg0: "OpResultList") -> List["OpResult"]: ... + def __add__(self, arg0: OpResultList) -> List[OpResult]: ... @overload - def __getitem__(self, arg0: int) -> "OpResult": ... + def __getitem__(self, arg0: int) -> OpResult: ... @overload - def __getitem__(self, arg0: slice) -> "OpResultList": ... + def __getitem__(self, arg0: slice) -> OpResultList: ... def __len__(self) -> int: ... @property - def types(self) -> List["Type"]: ... + def types(self) -> List[Type]: ... class OpView(_OperationBase): _ODS_OPERAND_SEGMENTS: ClassVar[None] = ... _ODS_REGIONS: ClassVar[tuple] = ... _ODS_RESULT_SEGMENTS: ClassVar[None] = ... - def __init__(self, operation: "_OperationBase") -> None: ... + def __init__(self, operation: _OperationBase) -> None: ... 
@classmethod - def build_generic(cls, results: Optional[Sequence["Type"]] = None, - operands: Optional[Sequence["Value"]] = None, - attributes: Optional[Dict[str, "Attribute"]] = None, - successors: Optional[Sequence["Block"]] = None, + def build_generic(cls, results: Optional[Sequence[Type]] = None, + operands: Optional[Sequence[Value]] = None, + attributes: Optional[Dict[str, Attribute]] = None, + successors: Optional[Sequence[Block]] = None, regions: Optional[int] = None, - loc: Optional["Location"] = None, - ip: Optional["InsertionPoint"] = None) -> "_OperationBase": ... + loc: Optional[Location] = None, + ip: Optional[InsertionPoint] = None) -> _OperationBase: ... @property - def context(self) -> "Context": ... + def context(self) -> Context: ... @property - def operation(self) -> "_OperationBase": ... + def operation(self) -> _OperationBase: ... class Operation(_OperationBase): def _CAPICreate(self) -> object: ... @staticmethod - def create(name: str, results: Optional[Sequence["Type"]] = None, - operands: Optional[Sequence["Value"]] = None, - attributes: Optional[Dict[str, "Attribute"]] = None, - successors: Optional[Sequence["Block"]] = None, + def create(name: str, results: Optional[Sequence[Type]] = None, + operands: Optional[Sequence[Value]] = None, + attributes: Optional[Dict[str, Attribute]] = None, + successors: Optional[Sequence[Block]] = None, regions: int = 0, - loc: Optional["Location"] = None, - ip: Optional["InsertionPoint"] = None) -> "_OperationBase": ... + loc: Optional[Location] = None, + ip: Optional[InsertionPoint] = None) -> _OperationBase: ... def erase(self) -> None: ... @property def _CAPIPtr(self) -> object: ... @property - def context(self) -> "Context": ... + def context(self) -> Context: ... @property def name(self) -> str: ... @property - def opview(self) -> "OpView": ... + def opview(self) -> OpView: ... @property - def parent(self) -> Optional["_OperationBase"]: ... + def parent(self) -> Optional[_OperationBase]: ... 
class OperationIterator: - def __iter__(self) -> "OperationIterator": ... - def __next__(self) -> "OpView": ... + def __iter__(self) -> OperationIterator: ... + def __next__(self) -> OpView: ... class OperationList: - def __getitem__(self, arg0: int) -> "OpView": ... - def __iter__(self) -> "OperationIterator": ... + def __getitem__(self, arg0: int) -> OpView: ... + def __iter__(self) -> OperationIterator: ... def __len__(self) -> int: ... # TODO: Auto-generated. Audit and fix. class RankedTensorType(ShapedType): def __init__(self, cast_from_type: Type) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> RankedTensorType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... @property def encoding(self) -> Optional[Attribute]: ... class Region: __hash__: ClassVar[None] = ... @overload - def __eq__(self, arg0: "Region") -> bool: ... + def __eq__(self, arg0: Region) -> bool: ... @overload def __eq__(self, arg0: object) -> bool: ... - def __iter__(self) -> "BlockIterator": ... + def __iter__(self) -> BlockIterator: ... @property - def blocks(self) -> "BlockList": ... + def blocks(self) -> BlockList: ... @property - def owner(self) -> "OpView": ... + def owner(self) -> OpView: ... class RegionIterator: - def __iter__(self) -> "RegionIterator": ... - def __next__(self) -> "Region": ... + def __iter__(self) -> RegionIterator: ... + def __next__(self) -> Region: ... class RegionSequence: - def __getitem__(self, arg0: int) -> "Region": ... + def __getitem__(self, arg0: int) -> Region: ... def __len__(self) -> int: ... # TODO: Auto-generated. Audit and fix. @@ -863,7 +938,8 @@ class ShapedType(Type): def is_dynamic_dim(self, dim: int) -> bool: ... def is_dynamic_size(self, *args, **kwargs) -> Any: ... def is_dynamic_stride_or_offset(self, dim_size: int) -> bool: ... - def isinstance(self, *args, **kwargs) -> Any: ... 
+ @staticmethod + def isinstance(arg: Any) -> bool: ... @property def element_type(self) -> Type: ... @property @@ -878,9 +954,12 @@ class ShapedType(Type): # TODO: Auto-generated. Audit and fix. class StringAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def get_typed(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> Any: ... + @staticmethod + def get_typed(*args, **kwargs) -> Any: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... @property def type(self) -> Type: ... @property @@ -901,25 +980,29 @@ class SymbolTable: @staticmethod def set_visibility(symbol: _OperationBase, visibility: str) -> None: ... @staticmethod - def walk_symbol_tables(from_op: _OperationBase, all_sym_uses_visible: bool, callback: Callable[[_OperationBase, bool], None) -> None: ... + def walk_symbol_tables(from_op: _OperationBase, all_sym_uses_visible: bool, callback: Callable[[_OperationBase, bool], None]) -> None: ... def __contains__(self, arg0: str) -> bool: ... def __delitem__(self, arg0: str) -> None: ... - def __getitem__(self, arg0: str) -> "OpView": ... + def __getitem__(self, arg0: str) -> OpView: ... # TODO: Auto-generated. Audit and fix. class TupleType(Type): def __init__(self, cast_from_type: Type) -> None: ... - def get_tuple(self, *args, **kwargs) -> Any: ... + @staticmethod + def get_tuple(*args, **kwargs) -> TupleType: ... def get_type(self, pos: int) -> Type: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... @property def num_types(self) -> int: ... # TODO: Auto-generated. Audit and fix. class TypeAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> Any: ... 
+ @staticmethod + def isinstance(arg: Any) -> bool: ... @property def type(self) -> Type: ... @property @@ -928,30 +1011,38 @@ class TypeAttr(Attribute): # TODO: Auto-generated. Audit and fix. class UnitAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> Any: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... @property def type(self) -> Type: ... # TODO: Auto-generated. Audit and fix. class UnrankedMemRefType(ShapedType): def __init__(self, cast_from_type: Type) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> UnrankedMemRefType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... @property def memory_space(self) -> Attribute: ... # TODO: Auto-generated. Audit and fix. class UnrankedTensorType(ShapedType): def __init__(self, cast_from_type: Type) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> UnrankedTensorType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... # TODO: Auto-generated. Audit and fix. class VectorType(ShapedType): def __init__(self, cast_from_type: Type) -> None: ... - def get(self, *args, **kwargs) -> Any: ... - def isinstance(self, *args, **kwargs) -> Any: ... + @staticmethod + def get(*args, **kwargs) -> VectorType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... class _GlobalDebug: flag: ClassVar[bool] = ... 
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi index 728f46418..44d22255e 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi @@ -19,7 +19,7 @@ class PassManager: def enable_ir_printing(self) -> None: ... def enable_verifier(self, enable: bool) -> None: ... @staticmethod - def parse(pipeline: str, context: Optional[_ir.Context] = None) -> "PassManager": ... + def parse(pipeline: str, context: Optional[_ir.Context] = None) -> PassManager: ... def run(self, module: _ir.Module) -> None: ... @property def _CAPIPtr(self) -> object: ... From 9601957671bfea32f39dc4d72da017a9cea2bd39 Mon Sep 17 00:00:00 2001 From: Dominik Grewe Date: Mon, 28 Mar 2022 15:45:40 +0200 Subject: [PATCH 266/915] Expose MlirOperationClone in Python bindings. Expose MlirOperationClone in Python bindings. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D122526 --- mlir/lib/Bindings/Python/IRCore.cpp | 38 ++++++++++++++++++++--------- mlir/lib/Bindings/Python/IRModule.h | 3 +++ 2 files changed, 29 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 621c09502..1225c2648 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1075,6 +1075,21 @@ py::object PyOperation::createFromCapsule(py::object capsule) { .releaseObject(); } +static void maybeInsertOperation(PyOperationRef &op, + const py::object &maybeIp) { + // InsertPoint active? 
+ if (!maybeIp.is(py::cast(false))) { + PyInsertionPoint *ip; + if (maybeIp.is_none()) { + ip = PyThreadContextEntry::getDefaultInsertionPoint(); + } else { + ip = py::cast(maybeIp); + } + if (ip) + ip->insert(*op.get()); + } +} + py::object PyOperation::create( const std::string &name, llvm::Optional> results, llvm::Optional> operands, @@ -1192,22 +1207,20 @@ py::object PyOperation::create( MlirOperation operation = mlirOperationCreate(&state); PyOperationRef created = PyOperation::createDetached(location->getContext(), operation); - - // InsertPoint active? - if (!maybeIp.is(py::cast(false))) { - PyInsertionPoint *ip; - if (maybeIp.is_none()) { - ip = PyThreadContextEntry::getDefaultInsertionPoint(); - } else { - ip = py::cast(maybeIp); - } - if (ip) - ip->insert(*created.get()); - } + maybeInsertOperation(created, maybeIp); return created->createOpView(); } +py::object PyOperation::clone(const py::object &maybeIp) { + MlirOperation clonedOperation = mlirOperationClone(operation); + PyOperationRef cloned = + PyOperation::createDetached(getContext(), clonedOperation); + maybeInsertOperation(cloned, maybeIp); + + return cloned->createOpView(); +} + py::object PyOperation::createOpView() { checkValid(); MlirIdentifier ident = mlirOperationGetName(get()); @@ -2616,6 +2629,7 @@ void mlir::python::populateIRCore(py::module &m) { return py::none(); }) .def("erase", &PyOperation::erase) + .def("clone", &PyOperation::clone, py::arg("ip") = py::none()) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index b1424a994..2046ce0c1 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -575,6 +575,9 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// parent context's live operations map, and sets the valid bit false. 
void erase(); + /// Clones this operation. + pybind11::object clone(const pybind11::object &ip); + private: PyOperation(PyMlirContextRef contextRef, MlirOperation operation); static PyOperationRef createInstance(PyMlirContextRef contextRef, From ccaba2ea73bb4627e798a180e2bdf1f9c6a9e9fe Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 31 Mar 2022 11:55:51 +0200 Subject: [PATCH 267/915] Fixed mypy type errors in MLIR Python type stubs This commit fixes or disables all errors reported by python3 -m mypy -p mlir --show-error-codes Note that unhashable types cannot be currently expressed in a way compatible with typeshed. See https://github.com/python/typeshed/issues/6243 for details. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D122790 --- mlir/python/mlir/_mlir_libs/__init__.py | 8 +++++++- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 26 ++++++++++++++++-------- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index 4e2e5f453..23bc50267 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -2,13 +2,19 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Sequence +from typing import Any, Sequence import os _this_dir = os.path.dirname(__file__) +# These submodules have no type stubs and are thus opaque to the type checker. +_mlirConversions: Any +_mlirTransforms: Any +_mlirAllPassesRegistration: Any + + def get_lib_dirs() -> Sequence[str]: """Gets the lib directory for linking to shared libraries. 
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 5bfb9202e..7b1667fa5 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -7,7 +7,10 @@ # * Local edits to signatures and types that MyPy did not auto detect (or # detected incorrectly). -from typing import Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple +from typing import ( + Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple, + Type as _Type, TypeVar +) from typing import overload @@ -121,6 +124,8 @@ class _OperationBase: @property def results(self) -> OpResultList: ... +_TOperation = TypeVar("_TOperation", bound=_OperationBase) + # TODO: Auto-generated. Audit and fix. class AffineExpr: def __init__(self, *args, **kwargs) -> None: ... @@ -379,7 +384,7 @@ class BF16Type(Type): def isinstance(arg: Any) -> bool: ... class Block: - __hash__: ClassVar[None] = ... + __hash__: ClassVar[None] = ... # type: ignore def append(self, operation: _OperationBase) -> None: ... def create_after(self, *args: Type) -> Block: ... @staticmethod @@ -406,7 +411,7 @@ class BlockArgument(Value): @property def arg_number(self) -> int: ... @property - def owner(self) -> Block: ... + def owner(self) -> Block: ... # type: ignore[override] class BlockArgumentList: def __add__(self, arg0: BlockArgumentList) -> List[BlockArgument]: ... @@ -463,7 +468,7 @@ class Context: def _get_live_operation_count(self) -> int: ... def attach_diagnostic_handler(self, callback: Callable[[Diagnostic], bool]) -> DiagnosticHandler: ... def enable_multithreading(self, enable: bool) -> None: ... - def get_dialect_descriptor(dialect_name: str) -> DialectDescriptor: ... + def get_dialect_descriptor(self, dialect_name: str) -> DialectDescriptor: ... def is_registered_operation(self, operation_name: str) -> bool: ... def __enter__(self) -> Context: ... def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... 
@@ -748,7 +753,7 @@ class IntegerType(Type): class Location: current: ClassVar[Location] = ... # read-only - __hash__: ClassVar[None] = ... + __hash__: ClassVar[None] = ... # type: ignore def _CAPICreate(self) -> Location: ... @staticmethod def callsite(callee: Location, frames: Sequence[Location], context: Optional[Context] = None) -> Location: ... @@ -787,6 +792,7 @@ class MemRefType(ShapedType): class Module: def _CAPICreate(self) -> object: ... + @staticmethod def create(loc: Optional[Location] = None) -> Module: ... def dump(self) -> None: ... @staticmethod @@ -858,17 +864,19 @@ class OpView(_OperationBase): _ODS_RESULT_SEGMENTS: ClassVar[None] = ... def __init__(self, operation: _OperationBase) -> None: ... @classmethod - def build_generic(cls, results: Optional[Sequence[Type]] = None, + def build_generic( + cls: _Type[_TOperation], + results: Optional[Sequence[Type]] = None, operands: Optional[Sequence[Value]] = None, attributes: Optional[Dict[str, Attribute]] = None, successors: Optional[Sequence[Block]] = None, regions: Optional[int] = None, loc: Optional[Location] = None, - ip: Optional[InsertionPoint] = None) -> _OperationBase: ... + ip: Optional[InsertionPoint] = None) -> _TOperation: ... @property def context(self) -> Context: ... @property - def operation(self) -> _OperationBase: ... + def operation(self) -> Operation: ... class Operation(_OperationBase): def _CAPICreate(self) -> object: ... @@ -912,7 +920,7 @@ class RankedTensorType(ShapedType): def encoding(self) -> Optional[Attribute]: ... class Region: - __hash__: ClassVar[None] = ... + __hash__: ClassVar[None] = ... # type: ignore @overload def __eq__(self, arg0: Region) -> bool: ... 
@overload From 0cf08ebc93d350e3853ddcfc9485c4987666379e Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 31 Mar 2022 11:57:07 +0200 Subject: [PATCH 268/915] Added an empty __init__.py file to the MLIR Python bindings While not strictly required after PEP-420, it is better to have one, since not all tooling supports implicit namespace packages. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D122794 --- mlir/python/mlir/__init__.py | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 mlir/python/mlir/__init__.py diff --git a/mlir/python/mlir/__init__.py b/mlir/python/mlir/__init__.py new file mode 100644 index 000000000..a0f4e3d5e --- /dev/null +++ b/mlir/python/mlir/__init__.py @@ -0,0 +1,3 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception From bdbeb7a20ef5b15eb3b24c7d898b1c009eef2d30 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 31 Mar 2022 14:27:49 +0200 Subject: [PATCH 269/915] Fixed the type of context in type stubs for MLIR Python bindings Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D122795 --- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 7b1667fa5..20d9c919c 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -156,7 +156,7 @@ class AffineExpr: @property def _CAPIPtr(self) -> object: ... @property - def context(self) -> object: ... + def context(self) -> Context: ... class Attribute: def __init__(self, cast_from_type: Attribute) -> None: ... @@ -305,7 +305,7 @@ class AffineMap: @property def _CAPIPtr(self) -> object: ... @property - def context(self) -> object: ... + def context(self) -> Context: ... @property def is_permutation(self) -> bool: ... 
@property @@ -699,7 +699,7 @@ class IntegerSet: @property def constraints(self) -> Any: ... @property - def context(self) -> object: ... + def context(self) -> Context: ... @property def is_canonical_empty(self) -> bool: ... @property @@ -802,7 +802,7 @@ class Module: @property def body(self) -> Block: ... @property - def context(self) -> object: ... + def context(self) -> Context: ... @property def operation(self) -> Operation: ... From 51c1842f6e09f1032bd0584623c38deed2835469 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 31 Mar 2022 20:01:12 +0200 Subject: [PATCH 270/915] Revert "Added an empty __init__.py file to the MLIR Python bindings" This reverts commit 0cf08ebc93d350e3853ddcfc9485c4987666379e. Post-commit review pointed out that adding this file will require the entire Python tree (including out-of-tree projects) to come from the same directory, which might be problematic in non-default installations. Reverting pending further discussion. --- mlir/python/mlir/__init__.py | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 mlir/python/mlir/__init__.py diff --git a/mlir/python/mlir/__init__.py b/mlir/python/mlir/__init__.py deleted file mode 100644 index a0f4e3d5e..000000000 --- a/mlir/python/mlir/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception From bc531731b9a08722efa82540abf5d3c474136ab6 Mon Sep 17 00:00:00 2001 From: Daniel Resnick Date: Wed, 16 Mar 2022 16:31:08 -0600 Subject: [PATCH 271/915] [mlir][capi] Add external pass creation to MLIR C-API Adds the ability to create external passes using the C-API. This allows passes to be written in C or languages that use the C-bindings. 
Differential Revision: https://reviews.llvm.org/D121866 --- mlir/include/mlir-c/IR.h | 14 ----- mlir/include/mlir-c/Pass.h | 55 +++++++++++++++++++ mlir/include/mlir-c/Support.h | 43 +++++++++++++++ mlir/include/mlir/CAPI/IR.h | 1 - mlir/include/mlir/CAPI/Support.h | 5 ++ mlir/lib/CAPI/IR/IR.cpp | 12 ----- mlir/lib/CAPI/IR/Pass.cpp | 91 ++++++++++++++++++++++++++++++++ mlir/lib/CAPI/IR/Support.cpp | 39 +++++++++++++- 8 files changed, 232 insertions(+), 28 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index d99955466..1c0517f59 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -62,7 +62,6 @@ DEFINE_C_API_STRUCT(MlirIdentifier, const void); DEFINE_C_API_STRUCT(MlirLocation, const void); DEFINE_C_API_STRUCT(MlirModule, const void); DEFINE_C_API_STRUCT(MlirType, const void); -DEFINE_C_API_STRUCT(MlirTypeID, const void); DEFINE_C_API_STRUCT(MlirValue, const void); #undef DEFINE_C_API_STRUCT @@ -757,19 +756,6 @@ MLIR_CAPI_EXPORTED bool mlirIdentifierEqual(MlirIdentifier ident, /// Gets the string value of the identifier. MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident); -//===----------------------------------------------------------------------===// -// TypeID API. -//===----------------------------------------------------------------------===// - -/// Checks whether a type id is null. -static inline bool mlirTypeIDIsNull(MlirTypeID typeID) { return !typeID.ptr; } - -/// Checks if two type ids are equal. -MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2); - -/// Returns the hash value of the type id. -MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID); - //===----------------------------------------------------------------------===// // Symbol and SymbolTable API. 
//===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index d8b216812..cdb947bde 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -15,6 +15,7 @@ #define MLIR_C_PASS_H #include "mlir-c/IR.h" +#include "mlir-c/Registration.h" #include "mlir-c/Support.h" #ifdef __cplusplus @@ -41,11 +42,16 @@ extern "C" { typedef struct name name DEFINE_C_API_STRUCT(MlirPass, void); +DEFINE_C_API_STRUCT(MlirExternalPass, void); DEFINE_C_API_STRUCT(MlirPassManager, void); DEFINE_C_API_STRUCT(MlirOpPassManager, void); #undef DEFINE_C_API_STRUCT +//===----------------------------------------------------------------------===// +// PassManager/OpPassManager APIs. +//===----------------------------------------------------------------------===// + /// Create a new top-level PassManager. MLIR_CAPI_EXPORTED MlirPassManager mlirPassManagerCreate(MlirContext ctx); @@ -112,6 +118,55 @@ MLIR_CAPI_EXPORTED void mlirPrintPassPipeline(MlirOpPassManager passManager, MLIR_CAPI_EXPORTED MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline); +//===----------------------------------------------------------------------===// +// External Pass API. +// +// This API allows to define passes outside of MLIR, not necessarily in +// C++, and register them with the MLIR pass management infrastructure. +// +//===----------------------------------------------------------------------===// + +/// Structure of external `MlirPass` callbacks. +/// All callbacks are required to be set unless otherwise specified. +struct MlirExternalPassCallbacks { + /// This callback is called from the pass is created. + /// This is analogous to a C++ pass constructor. + void (*construct)(void *userData); + + /// This callback is called when the pass is destroyed + /// This is analogous to a C++ pass destructor. 
+ void (*destruct)(void *userData); + + /// This callback is optional. + /// The callback is called before the pass is run, allowing a chance to + /// initialize any complex state necessary for running the pass. + /// See Pass::initialize(MLIRContext *). + MlirLogicalResult (*initialize)(MlirContext ctx, void *userData); + + /// This callback is called when the pass is cloned. + /// See Pass::clonePass(). + void *(*clone)(void *userData); + + /// This callback is called when the pass is run. + /// See Pass::runOnOperation(). + void (*run)(MlirOperation op, MlirExternalPass pass, void *userData); +}; +typedef struct MlirExternalPassCallbacks MlirExternalPassCallbacks; + +/// Creates an external `MlirPass` that calls the supplied `callbacks` using the +/// supplied `userData`. If `opName` is empty, the pass is a generic operation +/// pass. Otherwise it is an operation pass specific to the specified pass name. +MLIR_CAPI_EXPORTED MlirPass mlirCreateExternalPass( + MlirTypeID passID, MlirStringRef name, MlirStringRef argument, + MlirStringRef description, MlirStringRef opName, + intptr_t nDependentDialects, MlirDialectHandle *dependentDialects, + MlirExternalPassCallbacks callbacks, void *userData); + +/// This signals that the pass has failed. This is only valid to call during +/// the `run` callback of `MlirExternalPassCallbacks`. +/// See Pass::signalPassFailure(). 
+MLIR_CAPI_EXPORTED void mlirExternalPassSignalFailure(MlirExternalPass pass); + #ifdef __cplusplus } #endif diff --git a/mlir/include/mlir-c/Support.h b/mlir/include/mlir-c/Support.h index f20e58fe6..5d20fb78d 100644 --- a/mlir/include/mlir-c/Support.h +++ b/mlir/include/mlir-c/Support.h @@ -50,6 +50,17 @@ extern "C" { #endif +#define DEFINE_C_API_STRUCT(name, storage) \ + struct name { \ + storage *ptr; \ + }; \ + typedef struct name name + +DEFINE_C_API_STRUCT(MlirTypeID, const void); +DEFINE_C_API_STRUCT(MlirTypeIDAllocator, void); + +#undef DEFINE_C_API_STRUCT + //===----------------------------------------------------------------------===// // MlirStringRef. //===----------------------------------------------------------------------===// @@ -127,6 +138,38 @@ inline static MlirLogicalResult mlirLogicalResultFailure() { return res; } +//===----------------------------------------------------------------------===// +// TypeID API. +//===----------------------------------------------------------------------===// + +/// `ptr` must be 8 byte aligned and unique to a type valid for the duration of +/// the returned type id's usage +MLIR_CAPI_EXPORTED MlirTypeID mlirTypeIDCreate(const void *ptr); + +/// Checks whether a type id is null. +static inline bool mlirTypeIDIsNull(MlirTypeID typeID) { return !typeID.ptr; } + +/// Checks if two type ids are equal. +MLIR_CAPI_EXPORTED bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2); + +/// Returns the hash value of the type id. +MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID); + +//===----------------------------------------------------------------------===// +// TypeIDAllocator API. 
+//===----------------------------------------------------------------------===// + +/// Creates a type id allocator for dynamic type id creation +MLIR_CAPI_EXPORTED MlirTypeIDAllocator mlirTypeIDAllocatorCreate(); + +/// Deallocates the allocator and all allocated type ids +MLIR_CAPI_EXPORTED void +mlirTypeIDAllocatorDestroy(MlirTypeIDAllocator allocator); + +/// Allocates a type id that is valid for the lifetime of the allocator +MLIR_CAPI_EXPORTED MlirTypeID +mlirTypeIDAllocatorAllocateTypeID(MlirTypeIDAllocator allocator); + #ifdef __cplusplus } #endif diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h index 06cf7762a..899b41167 100644 --- a/mlir/include/mlir/CAPI/IR.h +++ b/mlir/include/mlir/CAPI/IR.h @@ -34,7 +34,6 @@ DEFINE_C_API_METHODS(MlirIdentifier, mlir::StringAttr) DEFINE_C_API_METHODS(MlirLocation, mlir::Location) DEFINE_C_API_METHODS(MlirModule, mlir::ModuleOp) DEFINE_C_API_METHODS(MlirType, mlir::Type) -DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID) DEFINE_C_API_METHODS(MlirValue, mlir::Value) #endif // MLIR_CAPI_IR_H diff --git a/mlir/include/mlir/CAPI/Support.h b/mlir/include/mlir/CAPI/Support.h index 6d9a59abf..f3e8a67e0 100644 --- a/mlir/include/mlir/CAPI/Support.h +++ b/mlir/include/mlir/CAPI/Support.h @@ -16,7 +16,9 @@ #define MLIR_CAPI_SUPPORT_H #include "mlir-c/Support.h" +#include "mlir/CAPI/Wrap.h" #include "mlir/Support/LogicalResult.h" +#include "mlir/Support/TypeID.h" #include "llvm/ADT/StringRef.h" /// Converts a StringRef into its MLIR C API equivalent. 
@@ -39,4 +41,7 @@ inline mlir::LogicalResult unwrap(MlirLogicalResult res) { return mlir::success(mlirLogicalResultIsSuccess(res)); } +DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID) +DEFINE_C_API_PTR_METHODS(MlirTypeIDAllocator, mlir::TypeIDAllocator) + #endif // MLIR_CAPI_SUPPORT_H diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 75ac93ffb..527aa4eaf 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -787,18 +787,6 @@ MlirStringRef mlirIdentifierStr(MlirIdentifier ident) { return wrap(unwrap(ident).strref()); } -//===----------------------------------------------------------------------===// -// TypeID API. -//===----------------------------------------------------------------------===// - -bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) { - return unwrap(typeID1) == unwrap(typeID2); -} - -size_t mlirTypeIDHashValue(MlirTypeID typeID) { - return hash_value(unwrap(typeID)); -} - //===----------------------------------------------------------------------===// // Symbol and SymbolTable API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index 4bfc9d013..a2998939a 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -77,3 +77,94 @@ MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager, // stream and redirect to a diagnostic. return wrap(mlir::parsePassPipeline(unwrap(pipeline), *unwrap(passManager))); } + +//===----------------------------------------------------------------------===// +// External Pass API. 
+//===----------------------------------------------------------------------===// + +namespace mlir { +class ExternalPass; +} // namespace mlir +DEFINE_C_API_PTR_METHODS(MlirExternalPass, mlir::ExternalPass) + +namespace mlir { +/// This pass class wraps external passes defined in other languages using the +/// MLIR C-interface +class ExternalPass : public Pass { +public: + ExternalPass(TypeID passID, StringRef name, StringRef argument, + StringRef description, Optional opName, + ArrayRef dependentDialects, + MlirExternalPassCallbacks callbacks, void *userData) + : Pass(passID, opName), id(passID), name(name), argument(argument), + description(description), dependentDialects(dependentDialects), + callbacks(callbacks), userData(userData) { + callbacks.construct(userData); + } + + ~ExternalPass() override { callbacks.destruct(userData); } + + StringRef getName() const override { return name; } + StringRef getArgument() const override { return argument; } + StringRef getDescription() const override { return description; } + + void getDependentDialects(DialectRegistry ®istry) const override { + MlirDialectRegistry cRegistry = wrap(®istry); + for (MlirDialectHandle dialect : dependentDialects) + mlirDialectHandleInsertDialect(dialect, cRegistry); + } + + void signalPassFailure() { Pass::signalPassFailure(); } + +protected: + LogicalResult initialize(MLIRContext *ctx) override { + if (callbacks.initialize) + return unwrap(callbacks.initialize(wrap(ctx), userData)); + return success(); + } + + bool canScheduleOn(RegisteredOperationName opName) const override { + if (Optional specifiedOpName = getOpName()) + return opName.getStringRef() == specifiedOpName; + return true; + } + + void runOnOperation() override { + callbacks.run(wrap(getOperation()), wrap(this), userData); + } + + std::unique_ptr clonePass() const override { + void *clonedUserData = callbacks.clone(userData); + return std::make_unique(id, name, argument, description, + getOpName(), dependentDialects, + 
callbacks, clonedUserData); + } + +private: + TypeID id; + std::string name; + std::string argument; + std::string description; + std::vector dependentDialects; + MlirExternalPassCallbacks callbacks; + void *userData; +}; +} // namespace mlir + +MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name, + MlirStringRef argument, + MlirStringRef description, MlirStringRef opName, + intptr_t nDependentDialects, + MlirDialectHandle *dependentDialects, + MlirExternalPassCallbacks callbacks, + void *userData) { + return wrap(static_cast(new mlir::ExternalPass( + unwrap(passID), unwrap(name), unwrap(argument), unwrap(description), + opName.length > 0 ? Optional(unwrap(opName)) : None, + {dependentDialects, static_cast(nDependentDialects)}, callbacks, + userData))); +} + +void mlirExternalPassSignalFailure(MlirExternalPass pass) { + unwrap(pass)->signalPassFailure(); +} diff --git a/mlir/lib/CAPI/IR/Support.cpp b/mlir/lib/CAPI/IR/Support.cpp index b6e1f9180..cbfbb5476 100644 --- a/mlir/lib/CAPI/IR/Support.cpp +++ b/mlir/lib/CAPI/IR/Support.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir-c/Support.h" +#include "mlir/CAPI/Support.h" #include "llvm/ADT/StringRef.h" #include @@ -19,3 +19,40 @@ bool mlirStringRefEqual(MlirStringRef string, MlirStringRef other) { return llvm::StringRef(string.data, string.length) == llvm::StringRef(other.data, other.length); } + +//===----------------------------------------------------------------------===// +// TypeID API. 
+//===----------------------------------------------------------------------===// + +MlirTypeID mlirTypeIDCreate(const void *ptr) { + assert(reinterpret_cast(ptr) % 8 == 0 && + "ptr must be 8 byte aligned"); + // This is essentially a no-op that returns back `ptr`, but by going through + // the `TypeID` functions we can get compiler errors in case the `TypeID` + // api/representation changes + return wrap(mlir::TypeID::getFromOpaquePointer(ptr)); +} + +bool mlirTypeIDEqual(MlirTypeID typeID1, MlirTypeID typeID2) { + return unwrap(typeID1) == unwrap(typeID2); +} + +size_t mlirTypeIDHashValue(MlirTypeID typeID) { + return hash_value(unwrap(typeID)); +} + +//===----------------------------------------------------------------------===// +// TypeIDAllocator API. +//===----------------------------------------------------------------------===// + +MlirTypeIDAllocator mlirTypeIDAllocatorCreate() { + return wrap(new mlir::TypeIDAllocator()); +} + +void mlirTypeIDAllocatorDestroy(MlirTypeIDAllocator allocator) { + delete unwrap(allocator); +} + +MlirTypeID mlirTypeIDAllocatorAllocateTypeID(MlirTypeIDAllocator allocator) { + return wrap(unwrap(allocator)->allocate()); +} From 2fa5f91dabf1536a09d1fd95fa9ca2f3905f7837 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Mon, 4 Apr 2022 18:51:35 +0200 Subject: [PATCH 272/915] [mlir][capi] Unbreak Interfaces CAPI after bc531731b9 No idea why check-mlir doesn't build this. 
--- mlir/lib/CAPI/Interfaces/Interfaces.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp index f752a57b5..3f1155fe1 100644 --- a/mlir/lib/CAPI/Interfaces/Interfaces.cpp +++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp @@ -9,6 +9,7 @@ #include "mlir-c/Interfaces.h" #include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" #include "mlir/CAPI/Wrap.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "llvm/ADT/ScopeExit.h" From de0c8c30801aa44354e4882bd1d19d04798b1dfd Mon Sep 17 00:00:00 2001 From: John Demme Date: Tue, 5 Apr 2022 17:10:20 -0700 Subject: [PATCH 273/915] [MLIR] [Python] Pybind adaptors: coerce None to default MlirLocation Add default source location coercion to enable location elision in Python code. --- mlir/include/mlir/Bindings/Python/PybindAdaptors.h | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 661ed48f9..582855a33 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -125,11 +125,16 @@ struct type_caster { }; /// Casts object <-> MlirLocation. -// TODO: Coerce None to default MlirLocation. template <> struct type_caster { PYBIND11_TYPE_CASTER(MlirLocation, _("MlirLocation")); bool load(handle src, bool) { + if (src.is_none()) { + // Gets the current thread-bound context. 
+ src = py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Location") + .attr("current"); + } py::object capsule = mlirApiObjectToCapsule(src); value = mlirPythonCapsuleToLocation(capsule.ptr()); return !mlirLocationIsNull(value); From cbfc0968de247bdcabc1bb7f61a6f1dc39febd5a Mon Sep 17 00:00:00 2001 From: John Demme Date: Wed, 6 Apr 2022 10:06:30 -0700 Subject: [PATCH 274/915] [MLIR] Add block detach func to CAPI and use it in Python bindings Adds `mlirBlockDetach` to the CAPI to remove a block from its parent region. Use it in the Python bindings to implement `Block.append_to(region)`. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D123165 --- mlir/include/mlir-c/IR.h | 3 +++ mlir/lib/Bindings/Python/IRCore.cpp | 9 +++++++++ mlir/lib/CAPI/IR/IR.cpp | 5 +++++ 3 files changed, 17 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 1c0517f59..34c6dd678 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -558,6 +558,9 @@ MLIR_CAPI_EXPORTED MlirBlock mlirBlockCreate(intptr_t nArgs, /// Takes a block owned by the caller and destroys it. MLIR_CAPI_EXPORTED void mlirBlockDestroy(MlirBlock block); +/// Detach a block from the owning region and assume ownership. +MLIR_CAPI_EXPORTED void mlirBlockDetach(MlirBlock block); + /// Checks whether a block is null. 
static inline bool mlirBlockIsNull(MlirBlock block) { return !block.ptr; } diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 1225c2648..25391ebd0 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2755,6 +2755,15 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("parent"), py::arg("arg_types") = py::list(), "Creates and returns a new Block at the beginning of the given " "region (with given argument types).") + .def( + "append_to", + [](PyBlock &self, PyRegion ®ion) { + MlirBlock b = self.get(); + if (!mlirRegionIsNull(mlirBlockGetParentRegion(b))) + mlirBlockDetach(b); + mlirRegionAppendOwnedBlock(region.get(), b); + }, + "Append this block to a region, transferring ownership if necessary") .def( "create_before", [](PyBlock &self, py::args pyArgTypes) { diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 527aa4eaf..6959ad322 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -634,6 +634,11 @@ void mlirBlockInsertOwnedOperationBefore(MlirBlock block, void mlirBlockDestroy(MlirBlock block) { delete unwrap(block); } +void mlirBlockDetach(MlirBlock block) { + Block *b = unwrap(block); + b->getParent()->getBlocks().remove(b); +} + intptr_t mlirBlockGetNumArguments(MlirBlock block) { return static_cast(unwrap(block)->getNumArguments()); } From 0e7b720f7d51db363bb620148bbbf40fc9d722d7 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Fri, 8 Apr 2022 17:56:26 +0530 Subject: [PATCH 275/915] [mlir][Linalg] Add pooling_nchw_sum op. This commit adds pooling_nchw_sum as a yaml op. 
Reviewed By: cathyzhyi, gysit Differential Revision: https://reviews.llvm.org/D123013 --- .../linalg/opdsl/ops/core_named_ops.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 2c6291bad..023a95dc7 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -522,6 +522,10 @@ def pooling_nhwc_sum(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs sum pooling. + Layout: + * Input: NHWC. + * Kernel: HW. + Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ @@ -531,6 +535,28 @@ def pooling_nhwc_sum(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) +@linalg_structured_op +def pooling_nchw_sum(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): + """Performs sum pooling. + + Layout: + * Input: NCHW. + * Kernel: HW. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) + O[D.n, D.c, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]) + + @linalg_structured_op def pooling_nhwc_max(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), From 3b4fb268d48abf769515990f7a07d69f87578a72 Mon Sep 17 00:00:00 2001 From: Arthur Eubanks Date: Fri, 8 Apr 2022 15:18:16 -0700 Subject: [PATCH 276/915] [mlir] Remove uses of LLVM's legacy pass manager Use the new pass manager. This also removes the ability to run arbitrary sets of passes. Not sure if this functionality is used, but it doesn't seem to be tested. No need to initialize passes outside of constructing the PassBuilder with the new pass manager. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D123425 --- mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index 4c5553087..2190566b2 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -48,8 +48,8 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, // Create a transformer to run all LLVM optimization passes at the // specified optimization level. 
auto llvmOptLevel = static_cast(optLevel); - auto transformer = mlir::makeLLVMPassesTransformer( - /*passes=*/{}, llvmOptLevel, /*targetMachine=*/tmOrError->get()); + auto transformer = mlir::makeOptimizingTransformer( + llvmOptLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get()); ExecutionEngineOptions jitOptions; jitOptions.transformer = transformer; jitOptions.jitCodeGenOptLevel = llvmOptLevel; From 126329828d39588fcc7f7cd1fa5e88c9e86ed16a Mon Sep 17 00:00:00 2001 From: Arthur Eubanks Date: Mon, 11 Apr 2022 16:45:19 -0700 Subject: [PATCH 277/915] Revert "[mlir] Remove uses of LLVM's legacy pass manager" This reverts commit 3b4fb268d48abf769515990f7a07d69f87578a72. Causes test failures: https://lab.llvm.org/buildbot#builders/61/builds/24879 --- mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index 2190566b2..4c5553087 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -48,8 +48,8 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, // Create a transformer to run all LLVM optimization passes at the // specified optimization level. auto llvmOptLevel = static_cast(optLevel); - auto transformer = mlir::makeOptimizingTransformer( - llvmOptLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get()); + auto transformer = mlir::makeLLVMPassesTransformer( + /*passes=*/{}, llvmOptLevel, /*targetMachine=*/tmOrError->get()); ExecutionEngineOptions jitOptions; jitOptions.transformer = transformer; jitOptions.jitCodeGenOptLevel = llvmOptLevel; From 00e46e393d2587c6c2e0b0076eb5b05f3818e07e Mon Sep 17 00:00:00 2001 From: Arthur Eubanks Date: Fri, 8 Apr 2022 15:18:16 -0700 Subject: [PATCH 278/915] Reland [mlir] Remove uses of LLVM's legacy pass manager Use the new pass manager. 
This also removes the ability to run arbitrary sets of passes. Not sure if this functionality is used, but it doesn't seem to be tested. No need to initialize passes outside of constructing the PassBuilder with the new pass manager. Reland: Fixed custom calls to `-lower-matrix-intrinsics` in integration tests by replacing them with `-O0 -enable-matrix`. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D123425 --- mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index 4c5553087..2190566b2 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -48,8 +48,8 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, // Create a transformer to run all LLVM optimization passes at the // specified optimization level. auto llvmOptLevel = static_cast(optLevel); - auto transformer = mlir::makeLLVMPassesTransformer( - /*passes=*/{}, llvmOptLevel, /*targetMachine=*/tmOrError->get()); + auto transformer = mlir::makeOptimizingTransformer( + llvmOptLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get()); ExecutionEngineOptions jitOptions; jitOptions.transformer = transformer; jitOptions.jitCodeGenOptLevel = llvmOptLevel; From cc281e8d4a4739deec403da641c96748590f6e05 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Mon, 11 Apr 2022 11:03:29 +0000 Subject: [PATCH 279/915] [mlir] Prefix pass manager options with `mlir-` With this change, there's going to be a clear distinction between LLVM and MLIR pass maanger options (e.g. `-mlir-print-after-all` vs `-print-after-all`). This change is desirable from the point of view of projects that depend on both LLVM and MLIR, e.g. Flang. For consistency, all pass manager options in MLIR are prefixed with `mlir-`, even options that don't have equivalents in LLVM . 
Differential Revision: https://reviews.llvm.org/D123495 --- mlir/include/mlir-c/IR.h | 2 +- mlir/include/mlir-c/Pass.h | 2 +- mlir/lib/Bindings/Python/Pass.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 34c6dd678..edd0e44ec 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -126,7 +126,7 @@ mlirContextGetNumLoadedDialects(MlirContext context); MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MlirStringRef name); -/// Set threading mode (must be set to false to print-ir-after-all). +/// Set threading mode (must be set to false to mlir-print-ir-after-all). MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, bool enable); diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index cdb947bde..cb6112954 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -71,7 +71,7 @@ mlirPassManagerGetAsOpPassManager(MlirPassManager passManager); MLIR_CAPI_EXPORTED MlirLogicalResult mlirPassManagerRun(MlirPassManager passManager, MlirModule module); -/// Enable print-ir-after-all. +/// Enable mlir-print-ir-after-all. 
MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting(MlirPassManager passManager); diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index dba2231a1..3278d3a91 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -73,7 +73,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { [](PyPassManager &passManager) { mlirPassManagerEnableIRPrinting(passManager.get()); }, - "Enable print-ir-after-all.") + "Enable mlir-print-ir-after-all.") .def( "enable_verifier", [](PyPassManager &passManager, bool enable) { From 6b300fb76aadab8b3f8ea07e96dfa6eb736bddc0 Mon Sep 17 00:00:00 2001 From: Ashay Rane Date: Tue, 19 Apr 2022 18:53:29 +0000 Subject: [PATCH 280/915] [MLIR] Add function to create BFloat16 array attribute This patch adds a new function `mlirDenseElementsAttrBFloat16Get()`, which accepts the shaped type, the number of BFloat16 values, and a pointer to an array of BFloat16 values, each of which is a `uint16_t` value. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D123981 --- mlir/include/mlir-c/BuiltinAttributes.h | 2 ++ mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index bb4431f7b..ce4514094 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -379,6 +379,8 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrFloatGet( MlirType shapedType, intptr_t numElements, const float *elements); MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrDoubleGet( MlirType shapedType, intptr_t numElements, const double *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrBFloat16Get( + MlirType shapedType, intptr_t numElements, const uint16_t *elements); /// Creates a dense elements attribute with the given shaped type from string /// elements. 
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 9ea277b74..aa498b2c1 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -474,6 +474,13 @@ MlirAttribute mlirDenseElementsAttrDoubleGet(MlirType shapedType, const double *elements) { return getDenseAttribute(shapedType, numElements, elements); } +MlirAttribute mlirDenseElementsAttrBFloat16Get(MlirType shapedType, + intptr_t numElements, + const uint16_t *elements) { + size_t bufferSize = numElements * 2; + const void *buffer = static_cast(elements); + return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer); +} MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType, intptr_t numElements, From 2adce92e88edb6aa0cb5084726a92ed34a93f142 Mon Sep 17 00:00:00 2001 From: John Demme Date: Tue, 19 Apr 2022 15:03:15 -0700 Subject: [PATCH 281/915] [MLIR] [Python] Add a method to clear live operations map Introduce a method on PyMlirContext (and plumb it through to Python) to invalidate all of the operations in the live operations map and clear it. Since Python has no notion of private data, an end-developer could reach into some 3rd party API which uses the MLIR Python API (that is behaving correctly with regard to holding references) and grab a reference to an MLIR Python Operation, preventing it from being deconstructed out of the live operations map. This allows the API developer to clear the map when it calls C++ code which could delete operations, protecting itself from its users. 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D123895 --- mlir/lib/Bindings/Python/IRCore.cpp | 9 +++++++++ mlir/lib/Bindings/Python/IRModule.h | 9 +++++++++ 2 files changed, 18 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 25391ebd0..d1877a11b 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -505,6 +505,14 @@ size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } +size_t PyMlirContext::clearLiveOperations() { + for (auto &op : liveOperations) + op.second.second->setInvalid(); + size_t numInvalidated = liveOperations.size(); + liveOperations.clear(); + return numInvalidated; +} + size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } pybind11::object PyMlirContext::contextEnter() { @@ -2208,6 +2216,7 @@ void mlir::python::populateIRCore(py::module &m) { return ref.releaseObject(); }) .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) + .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 2046ce0c1..371157a56 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -201,6 +201,12 @@ class PyMlirContext { /// Used for testing. size_t getLiveOperationCount(); + /// Clears the live operations map, returning the number of entries which were + /// invalidated. To be used as a safety mechanism so that API end-users can't + /// corrupt by holding references they shouldn't have accessed in the first + /// place. + size_t clearLiveOperations(); + /// Gets the count of live modules associated with this context. 
/// Used for testing. size_t getLiveModuleCount(); @@ -575,6 +581,9 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// parent context's live operations map, and sets the valid bit false. void erase(); + /// Invalidate the operation. + void setInvalid() { valid = false; } + /// Clones this operation. pybind11::object clone(const pybind11::object &ip); From bd8284961a26b69accdbfa0fd96c397c3a01d300 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 26 Apr 2022 11:00:35 -0700 Subject: [PATCH 282/915] [mlir:ODS] Support using attributes in AllTypesMatch to automatically add InferTypeOpInterface This allows for using attribute types in result type inference for use with InferTypeOpInterface. This was a TODO before, but it isn't much additional work to properly support this. After this commit, arith::ConstantOp can now have its InferTypeOpInterface implementation automatically generated. Differential Revision: https://reviews.llvm.org/D124580 --- mlir/python/mlir/dialects/_arith_ops_ext.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/python/mlir/dialects/_arith_ops_ext.py b/mlir/python/mlir/dialects/_arith_ops_ext.py index e35f5f2a4..c755df255 100644 --- a/mlir/python/mlir/dialects/_arith_ops_ext.py +++ b/mlir/python/mlir/dialects/_arith_ops_ext.py @@ -41,11 +41,11 @@ def __init__(self, loc=None, ip=None): if isinstance(value, int): - super().__init__(result, IntegerAttr.get(result, value), loc=loc, ip=ip) + super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip) elif isinstance(value, float): - super().__init__(result, FloatAttr.get(result, value), loc=loc, ip=ip) + super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip) else: - super().__init__(result, value, loc=loc, ip=ip) + super().__init__(value, loc=loc, ip=ip) @classmethod def create_index(cls, value: int, *, loc=None, ip=None): From 5cb79fc140d89ced2ec03aa826922a6840824cd7 Mon Sep 17 00:00:00 2001 From: Stella Stamenova Date: Mon, 9 
May 2022 19:46:47 -0700 Subject: [PATCH 283/915] [mlir] Fix python bindings build on Windows in Debug Currently, building mlir with the python bindings enabled on Windows in Debug is broken because pybind11, python and cmake don't like to play together. This change normalizes how the three interact, so that the builds can now run and succeed. The main issue is that python and cmake both make assumptions about which libraries are needed in a Windows build based on the flavor. - cmake assumes that a debug (or a debug-like) flavor of the build will always require pythonX_d.lib and provides no option/hint to tell it to use a different library. cmake does find both the debug and release versions, but then uses the debug library. - python (specifically pyconfig.h and by extension python.h) hardcodes the dependency on pythonX_d.lib or pythonX.lib depending on whether `_DEBUG` is defined. This is NOT transparent - it does not show up anywhere in the build logs until the link step fails with `pythonX_d.lib is missing` (or `pythonX.lib is missing`) - pybind11 tries to "fix" this by implementing a workaround - unless Py_DEBUG is defined, `_DEBUG` is explicitly undefined right before including python headers. This also requires some windows headers to be included differently, so while clever, this is a non-trivial workaround. mlir itself includes the pybind11 headers (which contain the workaround) AS WELL AS python.h, essentially always requiring both pythonX.lib and pythonX_d.lib for linking. cmake explicitly only adds one or the other, so the build fails. This change does a couple of things: - In the cmake files, explicitly add the release version of the python library on Windows builds regardless of flavor. 
Since Py_DEBUG is not defined, pybind11 will always require release and it will be satisfied - To satisfy python as well, this change removes any explicit inclusions of Python.h on Windows instead relying on the fact that pybind11 headers will bring in what is needed There are a few additional things that we could do but I rejected as unnecessary at this time: - define Py_DEBUG based on the CMAKE_BUILD_TYPE - this will *mostly* work, we'd have to think through multiconfig generators like VS, but it's possible. There doesn't seem to be a need to link against debug python at the moment, so I chose not to overcomplicate the build and always default to release - similar to above, but define Py_DEBUG based on the CMAKE_BUILD_TYPE *as well as* the presence of the debug python library (`Python3_LIBRARY_DEBUG`). Similar to above, this seems unnecessary right now. I think it's slightly better than above because most people don't actually have the debug version of python installed, so this would prevent breaks in that case. - similar to the two above, but add a cmake variable to control the logic - implement the pybind11 workaround directly in mlir (specifically in Interop.h) so that Python.h can still be included directly. This seems prone to error and a pain to maintain in lock step with pybind11 - reorganize how the pybind11 headers are included and place at least one of them in Interop.h directly, so that the header has all of its dependencies included as was the original intention. 
I decided against this because it really doesn't need pybind11 logic and it's always included after pybind11 is, so we don't necessarily need the python includes Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D125284 --- mlir/include/mlir-c/Bindings/Python/Interop.h | 10 ++++++++++ mlir/lib/Bindings/Python/ExecutionEngineModule.cpp | 1 - mlir/lib/Bindings/Python/IRCore.cpp | 1 - 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h index 7efe8500a..86b57cd73 100644 --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -21,7 +21,17 @@ #ifndef MLIR_C_BINDINGS_PYTHON_INTEROP_H #define MLIR_C_BINDINGS_PYTHON_INTEROP_H +// We *should*, in theory, include Python.h here in order to import the correct +// definitions for what we need below, however, importing Python.h directly on +// Windows results in the enforcement of either pythonX.lib or pythonX_d.lib +// depending on the build flavor. Instead, we rely on the fact that this file +// (Interop.h) is always included AFTER pybind11 and will therefore have access +// to the definitions from Python.h in addition to having a workaround applied +// through the pybind11 headers that allows us to control which python library +// is used. 
+#if !defined(_MSC_VER) #include +#endif #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp index 901690018..f5179bd7c 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -6,7 +6,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/ExecutionEngine.h" #include "mlir/Bindings/Python/PybindAdaptors.h" diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index d1877a11b..37ca3b953 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -19,7 +19,6 @@ #include "mlir-c/Registration.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include #include From bdad8f784f8441a99a7d94673f397d0d7d7452bf Mon Sep 17 00:00:00 2001 From: Chris Lattner Date: Thu, 12 May 2022 05:32:16 +0100 Subject: [PATCH 284/915] [DenseElementAttr] Simplify the public API for creating these. Instead of requiring the client to compute the "isSplat" bit, compute it internally. This makes the logic more consistent and defines away a lot of "elements.size()==1" in the clients. 
This addresses Issue #55185 Differential Revision: https://reviews.llvm.org/D125447 --- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index aa498b2c1..759b70895 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -348,11 +348,9 @@ MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType, rawBufferSize); bool isSplat = false; if (!DenseElementsAttr::isValidRawBuffer(shapedTypeCpp, rawBufferCpp, - isSplat)) { + isSplat)) return mlirAttributeGetNull(); - } - return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp, - isSplat)); + return wrap(DenseElementsAttr::getFromRawBuffer(shapedTypeCpp, rawBufferCpp)); } MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType, From 61d058a8d6997ec272c1423f460ec451d52087a8 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 17 May 2022 22:42:39 -0700 Subject: [PATCH 285/915] [mlir][python] Add Python bindings for ml_program dialect. 
Differential Revision: https://reviews.llvm.org/D125852 --- mlir/python/CMakeLists.txt | 9 ++ mlir/python/mlir/dialects/MLProgramOps.td | 15 +++ .../mlir/dialects/_ml_program_ops_ext.py | 116 ++++++++++++++++++ mlir/python/mlir/dialects/ml_program.py | 5 + 4 files changed, 145 insertions(+) create mode 100644 mlir/python/mlir/dialects/MLProgramOps.td create mode 100644 mlir/python/mlir/dialects/_ml_program_ops_ext.py create mode 100644 mlir/python/mlir/dialects/ml_program.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 1477ffea7..3e14a1974 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -132,6 +132,15 @@ declare_mlir_dialect_python_bindings( dialects/_memref_ops_ext.py DIALECT_NAME memref) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/MLProgramOps.td + SOURCES + dialects/ml_program.py + dialects/_ml_program_ops_ext.py + DIALECT_NAME ml_program) + declare_mlir_python_sources( MLIRPythonSources.Dialects.quant ADD_TO_PARENT MLIRPythonSources.Dialects diff --git a/mlir/python/mlir/dialects/MLProgramOps.td b/mlir/python/mlir/dialects/MLProgramOps.td new file mode 100644 index 000000000..5ac45ca1b --- /dev/null +++ b/mlir/python/mlir/dialects/MLProgramOps.td @@ -0,0 +1,15 @@ +//===-- MLProgramOps.td - Entry point for MLProgramOps -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_MLPROGRAM_OPS +#define PYTHON_BINDINGS_MLPROGRAM_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/MLProgram/IR/MLProgramOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/_ml_program_ops_ext.py b/mlir/python/mlir/dialects/_ml_program_ops_ext.py new file mode 100644 index 000000000..a3df7ff03 --- /dev/null +++ b/mlir/python/mlir/dialects/_ml_program_ops_ext.py @@ -0,0 +1,116 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from typing import Union + from ..ir import * + from ._ods_common import get_default_loc_context as _get_default_loc_context +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from ._ml_program_ops_gen import * + + +ARGUMENT_ATTRIBUTE_NAME = "arg_attrs" +RESULT_ATTRIBUTE_NAME = "res_attrs" + + +class FuncOp: + """Specialization for the func op class.""" + + def __init__(self, + name, + type, + *, + visibility=None, + body_builder=None, + loc=None, + ip=None): + """ + Create a FuncOp with the provided `name`, `type`, and `visibility`. + - `name` is a string representing the function name. + - `type` is either a FunctionType or a pair of list describing inputs and + results. + - `visibility` is a string matching `public`, `private`, or `nested`. None + implies private visibility. + - `body_builder` is an optional callback, when provided a new entry block + is created and the callback is invoked with the new op as argument within + an InsertionPoint context already set for the block. The callback is + expected to insert a terminator in the block. 
+ """ + sym_name = StringAttr.get(str(name)) + + # If the type is passed as a tuple, build a FunctionType on the fly. + if isinstance(type, tuple): + type = FunctionType.get(inputs=type[0], results=type[1]) + + type = TypeAttr.get(type) + sym_visibility = StringAttr.get( + str(visibility)) if visibility is not None else None + super().__init__(sym_name, type, sym_visibility, loc=loc, ip=ip) + if body_builder: + entry_block = self.add_entry_block() + with InsertionPoint(entry_block): + body_builder(self) + + @property + def is_external(self): + return len(self.regions[0].blocks) == 0 + + @property + def body(self): + return self.regions[0] + + @property + def type(self): + return FunctionType(TypeAttr(self.attributes["function_type"]).value) + + @property + def visibility(self): + return self.attributes["sym_visibility"] + + @property + def name(self) -> StringAttr: + return StringAttr(self.attributes["sym_name"]) + + @property + def entry_block(self): + if self.is_external: + raise IndexError('External function does not have a body') + return self.regions[0].blocks[0] + + def add_entry_block(self): + """ + Add an entry block to the function body using the function signature to + infer block arguments. 
+ Returns the newly created block + """ + if not self.is_external: + raise IndexError('The function already has an entry block!') + self.body.blocks.append(*self.type.inputs) + return self.body.blocks[0] + + @property + def arg_attrs(self): + return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) + + @arg_attrs.setter + def arg_attrs(self, attribute: Union[ArrayAttr, list]): + if isinstance(attribute, ArrayAttr): + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute + else: + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( + attribute, context=self.context) + + @property + def arguments(self): + return self.entry_block.arguments + + @property + def result_attrs(self): + return self.attributes[RESULT_ATTRIBUTE_NAME] + + @result_attrs.setter + def result_attrs(self, attribute: ArrayAttr): + self.attributes[RESULT_ATTRIBUTE_NAME] = attribute diff --git a/mlir/python/mlir/dialects/ml_program.py b/mlir/python/mlir/dialects/ml_program.py new file mode 100644 index 000000000..a654529b4 --- /dev/null +++ b/mlir/python/mlir/dialects/ml_program.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._ml_program_ops_gen import * From a79d9847037a6231d0516db5c045c1b46438c1d0 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Sat, 21 May 2022 02:31:07 +0200 Subject: [PATCH 286/915] [mlir][bufferization] Add bufferization.alloc_tensor op This change adds a new op `alloc_tensor` to the bufferization dialect. During bufferization, this op is always lowered to a buffer allocation (unless it is "eliminated" by a pre-processing pass). It is useful to have such an op in tensor land, because it allows users to model tensor SSA use-def chains (which drive bufferization decisions) and because tensor SSA use-def chains can be analyzed by One-Shot Bufferize, while memref values cannot. 
This change also replaces all uses of linalg.init_tensor in bufferization-related code with bufferization.alloc_tensor. linalg.init_tensor and bufferization.alloc_tensor are similar, but the purpose of the former one is just to carry a shape. It does not indicate a memory allocation. linalg.init_tensor is not suitable for modelling SSA use-def chains for bufferization purposes, because linalg.init_tensor is marked as not having side effects (in contrast to alloc_tensor). As such, it is legal to move linalg.init_tensor ops around/CSE them/etc. This is not desirable for alloc_tensor; it represents an explicit buffer allocation while still in tensor land and such allocations should not suddenly disappear or get moved around when running the canonicalizer/CSE/etc. BEGIN_PUBLIC No public commit message needed for presubmit. END_PUBLIC Differential Revision: https://reviews.llvm.org/D126003 --- mlir/python/mlir/dialects/BufferizationOps.td | 15 ++++++ .../mlir/dialects/_bufferization_ops_ext.py | 51 +++++++++++++++++++ 2 files changed, 66 insertions(+) create mode 100644 mlir/python/mlir/dialects/BufferizationOps.td create mode 100644 mlir/python/mlir/dialects/_bufferization_ops_ext.py diff --git a/mlir/python/mlir/dialects/BufferizationOps.td b/mlir/python/mlir/dialects/BufferizationOps.td new file mode 100644 index 000000000..c5170cee3 --- /dev/null +++ b/mlir/python/mlir/dialects/BufferizationOps.td @@ -0,0 +1,15 @@ +//===-- BufferizationOps.td - Entry point for BufferizationOps bindings ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_BUFFERIZATION_OPS +#define PYTHON_BINDINGS_BUFFERIZATION_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/Bufferization/IR/BufferizationOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/_bufferization_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_ops_ext.py new file mode 100644 index 000000000..2414c8b74 --- /dev/null +++ b/mlir/python/mlir/dialects/_bufferization_ops_ext.py @@ -0,0 +1,51 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from typing import Sequence, Union + from ..ir import * + from ._ods_common import get_default_loc_context as _get_default_loc_context + + from typing import Any, List, Union +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + + +class AllocTensorOp: + """Extends the bufferization.alloc_tensor op.""" + + def __init__(self, + sizes: Union[Sequence[int], Sequence[Value]], + element_type: Type, + *, + loc=None, + ip=None): + """Constructs an `alloc_tensor` with either static or dynamic sizes.""" + context = get_default_loc_context(loc) + operands = [] + attributes = {} + # TODO: Refactor the AllocTensorOp to take an element type attribute and + # then use normal result type inference, unifying the Python and C++ side + # with a standard mechanism (versus stashing that in builders). + if sizes and isinstance(sizes[0], Value): + # Dynamic sizes. + operands.extend(sizes) + static_size_ints = [-1] * len(sizes) + result_type = RankedTensorType.get(static_size_ints, element_type) + else: + # Static sizes. 
+ result_type = RankedTensorType.get(sizes, element_type) + static_size_ints = sizes + + i64_type = IntegerType.get_signless(64) + attributes["static_sizes"] = ArrayAttr.get( + [IntegerAttr.get(i64_type, s) for s in static_size_ints], + context=context) + op = self.build_generic( + results=[result_type], + operands=operands, + attributes=attributes, + loc=loc, + ip=ip) + OpView.__init__(self, op) From a68c7752cb43814b40f59b6d5cdc2d95b2180748 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Sat, 21 May 2022 13:25:24 +0200 Subject: [PATCH 287/915] [mlir] Move diagnostic handlers instead of copying This also allows using unique_ptr instead of shared_ptr for the CAPI user data. NFCI. --- mlir/lib/CAPI/IR/Diagnostics.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mlir/lib/CAPI/IR/Diagnostics.cpp b/mlir/lib/CAPI/IR/Diagnostics.cpp index 40639c7ba..4a13ae576 100644 --- a/mlir/lib/CAPI/IR/Diagnostics.cpp +++ b/mlir/lib/CAPI/IR/Diagnostics.cpp @@ -59,11 +59,12 @@ MlirDiagnosticHandlerID mlirContextAttachDiagnosticHandler( assert(handler && "unexpected null diagnostic handler"); if (deleteUserData == nullptr) deleteUserData = deleteUserDataNoop; - std::shared_ptr sharedUserData(userData, deleteUserData); DiagnosticEngine::HandlerID id = unwrap(context)->getDiagEngine().registerHandler( - [handler, sharedUserData](Diagnostic &diagnostic) { - return unwrap(handler(wrap(diagnostic), sharedUserData.get())); + [handler, + ownedUserData = std::unique_ptr( + userData, deleteUserData)](Diagnostic &diagnostic) { + return unwrap(handler(wrap(diagnostic), ownedUserData.get())); }); return static_cast(id); } From e089a5bbd31baecca50a729937aec13c5a27c2b8 Mon Sep 17 00:00:00 2001 From: Jeremy Furtek Date: Sat, 21 May 2022 21:12:11 -0700 Subject: [PATCH 288/915] [mlir][tblgen][ods][python] Use keyword-only arguments for optional builder arguments in generated Python bindings This diff modifies `mlir-tblgen` to generate Python Operation class `__init__()` 
functions that use Python keyword-only arguments. Previously, all `__init__()` function arguments were positional. Python code to create MLIR Operations was required to provide values for ALL builder arguments, including optional arguments (attributes and operands). Callers that did not provide, for example, an optional attribute would be forced to provide `None` as an argument for EACH optional attribute. Proposed changes in this diff use `tblgen` record information (as provided by ODS) to generate keyword arguments for: - optional operands - optional attributes (which includes unit attributes) - default-valued attributes These `__init__()` function keyword arguments have default `None` values (i.e. the argument form is `optionalAttr=None`), allowing callers to create Operations more easily. Note that since optional arguments become keyword-only arguments (since they are placed after the bare `*` argument), this diff will require ALL optional operands and attributes to be provided using explicit keyword syntax. This may, in the short term, break any out-of-tree Python code that provided values via positional arguments. However, in the long term, it seems that requiring keywords for optional arguments will be more robust to operation changes that add arguments. Tests were modified to reflect the updated Operation builder calling convention. This diff partially addresses the requests made in the github issue below. 
https://github.com/llvm/llvm-project/issues/54932 Reviewed By: stellaraccident, mikeurbach Differential Revision: https://reviews.llvm.org/D124717 --- mlir/python/mlir/dialects/_func_ops_ext.py | 2 +- .../mlir/dialects/_ml_program_ops_ext.py | 2 +- mlir/python/mlir/dialects/_pdl_ops_ext.py | 20 +++++++++---------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py index 6fe3ff530..79577463d 100644 --- a/mlir/python/mlir/dialects/_func_ops_ext.py +++ b/mlir/python/mlir/dialects/_func_ops_ext.py @@ -58,7 +58,7 @@ def __init__(self, type = TypeAttr.get(type) sym_visibility = StringAttr.get( str(visibility)) if visibility is not None else None - super().__init__(sym_name, type, sym_visibility, loc=loc, ip=ip) + super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) if body_builder: entry_block = self.add_entry_block() with InsertionPoint(entry_block): diff --git a/mlir/python/mlir/dialects/_ml_program_ops_ext.py b/mlir/python/mlir/dialects/_ml_program_ops_ext.py index a3df7ff03..8db82cf81 100644 --- a/mlir/python/mlir/dialects/_ml_program_ops_ext.py +++ b/mlir/python/mlir/dialects/_ml_program_ops_ext.py @@ -48,7 +48,7 @@ def __init__(self, type = TypeAttr.get(type) sym_visibility = StringAttr.get( str(visibility)) if visibility is not None else None - super().__init__(sym_name, type, sym_visibility, loc=loc, ip=ip) + super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) if body_builder: entry_block = self.add_entry_block() with InsertionPoint(entry_block): diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py index fb5b519c7..bb63fe64d 100644 --- a/mlir/python/mlir/dialects/_pdl_ops_ext.py +++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py @@ -93,7 +93,7 @@ def __init__(self, ip=None): type = type if type is None else _get_value(type) result = pdl.AttributeType.get() - 
super().__init__(result, type, value, loc=loc, ip=ip) + super().__init__(result, type=type, value=value, loc=loc, ip=ip) class EraseOp: @@ -118,7 +118,7 @@ def __init__(self, ip=None): type = type if type is None else _get_value(type) result = pdl.ValueType.get() - super().__init__(result, type, loc=loc, ip=ip) + super().__init__(result, type=type, loc=loc, ip=ip) class OperandsOp: @@ -131,7 +131,7 @@ def __init__(self, ip=None): types = types if types is None else _get_value(types) result = pdl.RangeType.get(pdl.ValueType.get()) - super().__init__(result, types, loc=loc, ip=ip) + super().__init__(result, type=types, loc=loc, ip=ip) class OperationOp: @@ -155,7 +155,7 @@ def __init__(self, attributeNames = ArrayAttr.get(attributeNames) types = _get_values(types) result = pdl.OperationType.get() - super().__init__(result, name, args, attributeValues, attributeNames, types, loc=loc, ip=ip) + super().__init__(result, args, attributeValues, attributeNames, types, name=name, loc=loc, ip=ip) class PatternOp: @@ -170,7 +170,7 @@ def __init__(self, """Creates an PDL `pattern` operation.""" name_attr = None if name is None else _get_str_attr(name) benefit_attr = _get_int_attr(16, benefit) - super().__init__(benefit_attr, name_attr, loc=loc, ip=ip) + super().__init__(benefit_attr, sym_name=name_attr, loc=loc, ip=ip) self.regions[0].blocks.append() @property @@ -192,7 +192,7 @@ def __init__(self, op = _get_value(op) with_op = with_op if with_op is None else _get_value(with_op) with_values = _get_values(with_values) - super().__init__(op, with_op, with_values, loc=loc, ip=ip) + super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip) class ResultOp: @@ -222,7 +222,7 @@ def __init__(self, ip=None): parent = _get_value(parent) index = index if index is None else _get_int_attr(32, index) - super().__init__(result, parent, index, loc=loc, ip=ip) + super().__init__(result, parent, index=index, loc=loc, ip=ip) class RewriteOp: @@ -238,7 +238,7 @@ def __init__(self, 
root = root if root is None else _get_value(root) name = name if name is None else _get_str_attr(name) args = _get_values(args) - super().__init__(root, name, args, loc=loc, ip=ip) + super().__init__(args, root=root,name=name, loc=loc, ip=ip) def add_body(self): """Add body (block) to the rewrite.""" @@ -261,7 +261,7 @@ def __init__(self, ip=None): type = type if type is None else _get_type_attr(type) result = pdl.TypeType.get() - super().__init__(result, type, loc=loc, ip=ip) + super().__init__(result, type=type, loc=loc, ip=ip) class TypesOp: @@ -275,4 +275,4 @@ def __init__(self, types = _get_array_attr([_get_type_attr(ty) for ty in types]) types = None if not types else types result = pdl.RangeType.get(pdl.TypeType.get()) - super().__init__(result, types, loc=loc, ip=ip) + super().__init__(result, types=types, loc=loc, ip=ip) From 109dac7e7de9ec8df650e721db5c272189844d42 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 23 May 2022 18:10:12 +0200 Subject: [PATCH 289/915] [mlir][bufferization] Fix Python bindings Differential Revision: https://reviews.llvm.org/D126179 --- mlir/python/CMakeLists.txt | 9 +++++ .../mlir/dialects/_bufferization_ops_ext.py | 33 ++++--------------- mlir/python/mlir/dialects/bufferization.py | 5 +++ 3 files changed, 21 insertions(+), 26 deletions(-) create mode 100644 mlir/python/mlir/dialects/bufferization.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 3e14a1974..d280bf105 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -63,6 +63,15 @@ declare_mlir_dialect_python_bindings( SOURCES_GLOB dialects/async_dialect/*.py DIALECT_NAME async_dialect) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/BufferizationOps.td + SOURCES + dialects/bufferization.py + dialects/_bufferization_ops_ext.py + DIALECT_NAME bufferization) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT 
MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/_bufferization_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_ops_ext.py index 2414c8b74..c720844af 100644 --- a/mlir/python/mlir/dialects/_bufferization_ops_ext.py +++ b/mlir/python/mlir/dialects/_bufferization_ops_ext.py @@ -5,7 +5,7 @@ try: from typing import Sequence, Union from ..ir import * - from ._ods_common import get_default_loc_context as _get_default_loc_context + from ._ods_common import get_default_loc_context from typing import Any, List, Union except ImportError as e: @@ -16,36 +16,17 @@ class AllocTensorOp: """Extends the bufferization.alloc_tensor op.""" def __init__(self, - sizes: Union[Sequence[int], Sequence[Value]], - element_type: Type, + tensor_type: Type, + dynamic_sizes: Sequence[Value], *, loc=None, ip=None): - """Constructs an `alloc_tensor` with either static or dynamic sizes.""" + """Constructs an `alloc_tensor` with static and/or dynamic sizes.""" context = get_default_loc_context(loc) - operands = [] - attributes = {} - # TODO: Refactor the AllocTensorOp to take an element type attribute and - # then use normal result type inference, unifying the Python and C++ side - # with a standard mechanism (versus stashing that in builders). - if sizes and isinstance(sizes[0], Value): - # Dynamic sizes. - operands.extend(sizes) - static_size_ints = [-1] * len(sizes) - result_type = RankedTensorType.get(static_size_ints, element_type) - else: - # Static sizes. 
- result_type = RankedTensorType.get(sizes, element_type) - static_size_ints = sizes - - i64_type = IntegerType.get_signless(64) - attributes["static_sizes"] = ArrayAttr.get( - [IntegerAttr.get(i64_type, s) for s in static_size_ints], - context=context) op = self.build_generic( - results=[result_type], - operands=operands, - attributes=attributes, + results=[tensor_type], + operands=dynamic_sizes, + attributes={}, loc=loc, ip=ip) OpView.__init__(self, op) diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py new file mode 100644 index 000000000..2121122f1 --- /dev/null +++ b/mlir/python/mlir/dialects/bufferization.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._bufferization_ops_gen import * From 7bb391bd57d494f0d2788c6d61a9e6d15f24067e Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 30 May 2022 15:14:02 +0200 Subject: [PATCH 290/915] [mlir] provide Python bindings for the Transform dialect Python bindings for extensions of the Transform dialect are defined in separate Python source files that can be imported on-demand, i.e., that are not imported with the "main" transform dialect. This requires a minor addition to the ODS-based bindings generator. This approach is consistent with the current model for downstream projects that are expected to bundle MLIR Python bindings: such projects can include their custom extensions into the bundle similarly to how they include their dialects. 
Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D126208 --- mlir/python/CMakeLists.txt | 19 ++ .../dialects/LinalgStructuredTransformOps.td | 21 +++ mlir/python/mlir/dialects/TransformOps.td | 15 ++ .../dialects/_structured_transform_ops_ext.py | 178 ++++++++++++++++++ .../mlir/dialects/_transform_ops_ext.py | 106 +++++++++++ .../mlir/dialects/transform/__init__.py | 5 + .../mlir/dialects/transform/structured.py | 5 + 7 files changed, 349 insertions(+) create mode 100644 mlir/python/mlir/dialects/LinalgStructuredTransformOps.td create mode 100644 mlir/python/mlir/dialects/TransformOps.td create mode 100644 mlir/python/mlir/dialects/_structured_transform_ops_ext.py create mode 100644 mlir/python/mlir/dialects/_transform_ops_ext.py create mode 100644 mlir/python/mlir/dialects/transform/__init__.py create mode 100644 mlir/python/mlir/dialects/transform/structured.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index d280bf105..17048e8cb 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -116,6 +116,25 @@ declare_mlir_dialect_python_bindings( DIALECT_NAME linalg DEPENDS LinalgOdsGen) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/TransformOps.td + SOURCES + dialects/_transform_ops_ext.py + dialects/transform/__init__.py + DIALECT_NAME transform) + +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/LinalgStructuredTransformOps.td + SOURCES + dialects/_structured_transform_ops_ext.py + dialects/transform/structured.py + DIALECT_NAME transform + EXTENSION_NAME structured_transform) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/LinalgStructuredTransformOps.td 
b/mlir/python/mlir/dialects/LinalgStructuredTransformOps.td new file mode 100644 index 000000000..a9a53fe6d --- /dev/null +++ b/mlir/python/mlir/dialects/LinalgStructuredTransformOps.td @@ -0,0 +1,21 @@ +//===-- LinalgStructuredTransformOps.td --------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the structured transform ops +// provided by Linalg (and other dialects). +// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS +#define PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td" + +#endif // PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/TransformOps.td b/mlir/python/mlir/dialects/TransformOps.td new file mode 100644 index 000000000..7f0d80ead --- /dev/null +++ b/mlir/python/mlir/dialects/TransformOps.td @@ -0,0 +1,15 @@ +//===-- TransformOps.td - Transform ops bind entry point ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_TRANSFORM_OPS +#define PYTHON_BINDINGS_TRANSFORM_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/Transform/IR/TransformOps.td" + +#endif // PYTHON_BINDINGS_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py new file mode 100644 index 000000000..70e39be52 --- /dev/null +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -0,0 +1,178 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * + from ._ods_common import get_op_result_or_value as _get_op_result_or_value + from ..dialects import pdl +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import List, Optional, Sequence, Union + +IntOrAttrList = Sequence[Union[IntegerAttr, int]] +OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]] + + +def _get_array_attr( + values: Optional[Union[ArrayAttr, Sequence[Attribute]]]) -> ArrayAttr: + """Creates an array attribute from its operand.""" + if values is None: + return ArrayAttr.get([]) + if isinstance(values, ArrayAttr): + return values + + return ArrayAttr.get(values) + + +def _get_int_array_attr( + values: Optional[Union[ArrayAttr, Sequence[Union[IntegerAttr, int]]]] +) -> ArrayAttr: + """Creates an integer array attribute from its operand. + + If the operand is already an array attribute, forwards it. Otherwise treats + the operand as a list of attributes or integers, possibly intersperced, to + create a new array attribute containing integer attributes. 
Expects the + thread-local MLIR context to have been set by the context manager. + """ + if values is None: + return ArrayAttr.get([]) + if isinstance(values, ArrayAttr): + return values + + attributes = [] + for value in values: + if isinstance(value, IntegerAttr): + attributes.append(value) + else: + attributes.append(IntegerAttr.get(IntegerType.get_signless(64), value)) + return ArrayAttr.get(attributes) + + +def _get_int_int_array_attr( + values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, + IntOrAttrList]]]] +) -> ArrayAttr: + """Creates an array attribute containing array attributes of integers. + + If the operand is already an array attribute, forwards it. Otherwise treats + the operand as a list of attributes or integers, potentially interpserced, to + create a new array-of-array attribute. Expects the thread-local MLIR context + to have been set by the context manager. + """ + if values is None: + return ArrayAttr.get([]) + if isinstance(values, ArrayAttr): + return values + + return ArrayAttr.get([_get_int_array_attr(value) for value in values]) + + +class InterchangeOp: + """Specialization for InterchangeOp class.""" + + def __init__(self, + target: Union[Operation, Value], + *, + iterator_interchange: OptionalIntList = None, + loc=None, + ip=None): + pdl_operation_type = pdl.OperationType.get() + interchange_attr = _get_int_array_attr(iterator_interchange) + super().__init__( + pdl_operation_type, + _get_op_result_or_value(target), + iterator_interchange=interchange_attr, + loc=loc, + ip=ip) + + +class PadOp: + """Specialization for PadOp class.""" + + def __init__(self, + target: Union[Operation, Value], + *, + padding_values: Optional[Union[ArrayAttr, + Sequence[Attribute]]] = None, + padding_dimensions: OptionalIntList = None, + pack_paddings: OptionalIntList = None, + hoist_paddings: OptionalIntList = None, + transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[ + ArrayAttr, IntOrAttrList]]]] = None, + loc=None, + ip=None): + 
pdl_operation_type = pdl.OperationType.get() + padding_values_attr = _get_array_attr(padding_values) + padding_dimensions_attr = _get_int_array_attr(padding_dimensions) + pack_paddings_attr = _get_int_array_attr(pack_paddings) + hoist_paddings_attr = _get_int_array_attr(hoist_paddings) + transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings) + super().__init__( + pdl_operation_type, + _get_op_result_or_value(target), + padding_values=padding_values_attr, + padding_dimensions=padding_dimensions_attr, + pack_paddings=pack_paddings_attr, + hoist_paddings=hoist_paddings_attr, + transpose_paddings=transpose_paddings_attr, + loc=loc, + ip=ip) + + +class ScalarizeOp: + """Specialization for ScalarizeOp class.""" + + def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): + pdl_operation_type = pdl.OperationType.get() + super().__init__( + pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip) + + +class TileOp: + """Specialization for TileOp class.""" + + def __init__(self, + target: Union[Operation, Value], + *, + sizes: OptionalIntList = None, + interchange: OptionalIntList = None, + loc=None, + ip=None): + pdl_operation_type = pdl.OperationType.get() + sizes_attr = _get_int_array_attr(sizes) + num_loops = sum( + v if v == 0 else 1 for v in self.__extract_values(sizes_attr)) + super().__init__( + pdl_operation_type, [pdl_operation_type] * num_loops, + _get_op_result_or_value(target), + sizes=sizes_attr, + interchange=_get_int_array_attr(interchange) if interchange else None, + loc=loc, + ip=ip) + + def __extract_values(self, attr: Optional[ArrayAttr]) -> List[int]: + if not attr: + return [] + return [IntegerAttr(element).value for element in attr] + + +class VectorizeOp: + """Specialization for VectorizeOp class.""" + + def __init__(self, + target: Union[Operation, Value], + *, + vectorize_padding: Union[bool, BoolAttr] = False, + loc=None, + ip=None): + pdl_operation_type = pdl.OperationType.get() + if 
isinstance(vectorize_padding, bool): + vectorize_padding = BoolAttr.get(vectorize_padding) + super().__init__( + pdl_operation_type, + _get_op_result_or_value(target), + vectorize_padding=vectorize_padding, + loc=loc, + ip=ip) diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py new file mode 100644 index 000000000..138195dca --- /dev/null +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -0,0 +1,106 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * + from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values + from ..dialects import pdl +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, overload, Sequence, Union + + +def _get_symbol_ref_attr(value: Union[Attribute, str]): + if isinstance(value, Attribute): + return value + return FlatSymbolRefAttr.get(value) + + +class GetClosestIsolatedParentOp: + + def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): + super().__init__( + pdl.OperationType.get(), + _get_op_result_or_value(target), + loc=loc, + ip=ip) + + +class PDLMatchOp: + + def __init__(self, + target: Union[Operation, Value], + pattern_name: Union[Attribute, str], + *, + loc=None, + ip=None): + super().__init__( + pdl.OperationType.get(), + _get_op_result_or_value(target), + _get_symbol_ref_attr(pattern_name), + loc=loc, + ip=ip) + + +class SequenceOp: + + @overload + def __init__(self, resultsOrRoot: Sequence[Type], + optionalRoot: Optional[Union[Operation, Value]]): + ... + + @overload + def __init__(self, resultsOrRoot: Optional[Union[Operation, Value]], + optionalRoot: NoneType): + ... 
+ + def __init__(self, resultsOrRoot=None, optionalRoot=None): + results = resultsOrRoot if isinstance(resultsOrRoot, Sequence) else [] + root = ( + resultsOrRoot + if not isinstance(resultsOrRoot, Sequence) else optionalRoot) + root = _get_op_result_or_value(root) if root else None + super().__init__(results_=results, root=root) + self.regions[0].blocks.append(pdl.OperationType.get()) + + @property + def body(self) -> Block: + return self.regions[0].blocks[0] + + @property + def bodyTarget(self) -> Value: + return self.body.arguments[0] + + +class WithPDLPatternsOp: + + def __init__(self, + target: Optional[Union[Operation, Value]] = None, + *, + loc=None, + ip=None): + super().__init__( + root=_get_op_result_or_value(target) if target else None, + loc=loc, + ip=ip) + self.regions[0].blocks.append(pdl.OperationType.get()) + + @property + def body(self) -> Block: + return self.regions[0].blocks[0] + + @property + def bodyTarget(self) -> Value: + return self.body.arguments[0] + + +class YieldOp: + + def __init__(self, + operands: Union[Operation, Sequence[Value]] = [], + *, + loc=None, + ip=None): + super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py new file mode 100644 index 000000000..ab4fa5631 --- /dev/null +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._transform_ops_gen import * diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py new file mode 100644 index 000000000..b8ee48c42 --- /dev/null +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._structured_transform_ops_gen import * From 599cc507d56da2e6ece3f634a4407c6802745a39 Mon Sep 17 00:00:00 2001 From: Nathaniel McVicar Date: Tue, 31 May 2022 10:03:48 -0700 Subject: [PATCH 291/915] [windows] Remove unused pybind exception params Resolve MSVC warning C4104 for unreferenced variable Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D126683 --- mlir/lib/Bindings/Python/IRInterfaces.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index 1fc66fef4..3f50d3bc0 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -63,13 +63,13 @@ class PyConcreteOpInterface { : obj(std::move(object)) { try { operation = &py::cast(obj); - } catch (py::cast_error &err) { + } catch (py::cast_error &) { // Do nothing. } try { operation = &py::cast(obj).getOperation(); - } catch (py::cast_error &err) { + } catch (py::cast_error &) { // Do nothing. } @@ -86,7 +86,7 @@ class PyConcreteOpInterface { } else { try { opName = obj.attr("OPERATION_NAME").template cast(); - } catch (py::cast_error &err) { + } catch (py::cast_error &) { throw py::type_error( "Op interface does not refer to an operation or OpView class"); } From 3fb5abdd66a2b2401d9608a7d5a3778828851a6a Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Thu, 26 May 2022 15:17:31 -0700 Subject: [PATCH 292/915] [mlir][python][ctypes] fix ctype python binding complication for complex There is no direct ctypes for MLIR's complex (and thus np.complex128 and np.complex64) yet, causing the mlir python binding methods for memrefs to crash. This revision fixes this by passing complex arrays as tuples of floats, correcting at the boundaries for the proper view. 
NOTE: some of these changes (4 -> 2) were forced by the new "linting" Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D126422 --- mlir/python/mlir/runtime/np_to_memref.py | 175 ++++++++++++----------- 1 file changed, 92 insertions(+), 83 deletions(-) diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py index 43ef95435..de5b8d6f7 100644 --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -8,112 +8,121 @@ import ctypes +class C128(ctypes.Structure): + """A ctype representation for MLIR's Double Complex.""" + _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)] + + +class C64(ctypes.Structure): + """A ctype representation for MLIR's Float Complex.""" + _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)] + + +def as_ctype(dtp): + """Converts dtype to ctype.""" + if dtp is np.dtype(np.complex128): + return C128 + if dtp is np.dtype(np.complex64): + return C64 + return np.ctypeslib.as_ctypes_type(dtp) + + def make_nd_memref_descriptor(rank, dtype): - class MemRefDescriptor(ctypes.Structure): - """ - Build an empty descriptor for the given rank/dtype, where rank>0. - """ - _fields_ = [ - ("allocated", ctypes.c_longlong), - ("aligned", ctypes.POINTER(dtype)), - ("offset", ctypes.c_longlong), - ("shape", ctypes.c_longlong * rank), - ("strides", ctypes.c_longlong * rank), - ] + class MemRefDescriptor(ctypes.Structure): + """Builds an empty descriptor for the given rank/dtype, where rank>0.""" - return MemRefDescriptor + _fields_ = [ + ("allocated", ctypes.c_longlong), + ("aligned", ctypes.POINTER(dtype)), + ("offset", ctypes.c_longlong), + ("shape", ctypes.c_longlong * rank), + ("strides", ctypes.c_longlong * rank), + ] + + return MemRefDescriptor def make_zero_d_memref_descriptor(dtype): - class MemRefDescriptor(ctypes.Structure): - """ - Build an empty descriptor for the given dtype, where rank=0. 
- """ - _fields_ = [ - ("allocated", ctypes.c_longlong), - ("aligned", ctypes.POINTER(dtype)), - ("offset", ctypes.c_longlong), - ] + class MemRefDescriptor(ctypes.Structure): + """Builds an empty descriptor for the given dtype, where rank=0.""" - return MemRefDescriptor + _fields_ = [ + ("allocated", ctypes.c_longlong), + ("aligned", ctypes.POINTER(dtype)), + ("offset", ctypes.c_longlong), + ] + return MemRefDescriptor -class UnrankedMemRefDescriptor(ctypes.Structure): - """ Creates a ctype struct for memref descriptor""" - _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)] +class UnrankedMemRefDescriptor(ctypes.Structure): + """Creates a ctype struct for memref descriptor""" + _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)] def get_ranked_memref_descriptor(nparray): - """ - Return a ranked memref descriptor for the given numpy array. - """ - if nparray.ndim == 0: - x = make_zero_d_memref_descriptor(np.ctypeslib.as_ctypes_type(nparray.dtype))() - x.allocated = nparray.ctypes.data - x.aligned = nparray.ctypes.data_as( - ctypes.POINTER(np.ctypeslib.as_ctypes_type(nparray.dtype)) - ) - x.offset = ctypes.c_longlong(0) - return x - - x = make_nd_memref_descriptor( - nparray.ndim, np.ctypeslib.as_ctypes_type(nparray.dtype) - )() + """Returns a ranked memref descriptor for the given numpy array.""" + ctp = as_ctype(nparray.dtype) + if nparray.ndim == 0: + x = make_zero_d_memref_descriptor(ctp)() x.allocated = nparray.ctypes.data - x.aligned = nparray.ctypes.data_as( - ctypes.POINTER(np.ctypeslib.as_ctypes_type(nparray.dtype)) - ) + x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp)) x.offset = ctypes.c_longlong(0) - x.shape = nparray.ctypes.shape - - # Numpy uses byte quantities to express strides, MLIR OTOH uses the - # torch abstraction which specifies strides in terms of elements. 
- strides_ctype_t = ctypes.c_longlong * nparray.ndim - x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides]) return x + x = make_nd_memref_descriptor(nparray.ndim, ctp)() + x.allocated = nparray.ctypes.data + x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp)) + x.offset = ctypes.c_longlong(0) + x.shape = nparray.ctypes.shape + + # Numpy uses byte quantities to express strides, MLIR OTOH uses the + # torch abstraction which specifies strides in terms of elements. + strides_ctype_t = ctypes.c_longlong * nparray.ndim + x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides]) + return x + def get_unranked_memref_descriptor(nparray): - """ - Return a generic/unranked memref descriptor for the given numpy array. - """ - d = UnrankedMemRefDescriptor() - d.rank = nparray.ndim - x = get_ranked_memref_descriptor(nparray) - d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p) - return d + """Returns a generic/unranked memref descriptor for the given numpy array.""" + d = UnrankedMemRefDescriptor() + d.rank = nparray.ndim + x = get_ranked_memref_descriptor(nparray) + d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p) + return d def unranked_memref_to_numpy(unranked_memref, np_dtype): - """ - Converts unranked memrefs to numpy arrays. 
- """ - descriptor = make_nd_memref_descriptor( - unranked_memref[0].rank, np.ctypeslib.as_ctypes_type(np_dtype) - ) - val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor)) - np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape) - strided_arr = np.lib.stride_tricks.as_strided( - np_arr, - np.ctypeslib.as_array(val[0].shape), - np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize, - ) - return strided_arr + """Converts unranked memrefs to numpy arrays.""" + ctp = as_ctype(np_dtype) + descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp) + val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor)) + np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape) + strided_arr = np.lib.stride_tricks.as_strided( + np_arr, + np.ctypeslib.as_array(val[0].shape), + np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize, + ) + if strided_arr.dtype == C128: + return strided_arr.view("complex128") + if strided_arr.dtype == C64: + return strided_arr.view("complex64") + return strided_arr def ranked_memref_to_numpy(ranked_memref): - """ - Converts ranked memrefs to numpy arrays. 
- """ - np_arr = np.ctypeslib.as_array( - ranked_memref[0].aligned, shape=ranked_memref[0].shape - ) - strided_arr = np.lib.stride_tricks.as_strided( - np_arr, - np.ctypeslib.as_array(ranked_memref[0].shape), - np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize, - ) - return strided_arr + """Converts ranked memrefs to numpy arrays.""" + np_arr = np.ctypeslib.as_array( + ranked_memref[0].aligned, shape=ranked_memref[0].shape) + strided_arr = np.lib.stride_tricks.as_strided( + np_arr, + np.ctypeslib.as_array(ranked_memref[0].shape), + np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize, + ) + if strided_arr.dtype == C128: + return strided_arr.view("complex128") + if strided_arr.dtype == C64: + return strided_arr.view("complex64") + return strided_arr From f6f3051bd5bb362d6b6460838bca6576aaf7dc81 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Tue, 31 May 2022 15:49:02 +0200 Subject: [PATCH 293/915] [mlir] add decompose and generalize to structured transform ops These ops complement the tiling/padding transformations by transforming higher-level named structured operations such as depthwise convolutions into lower-level and/or generic equivalents that are better handled by some downstream transformations. 
Differential Revision: https://reviews.llvm.org/D126698 --- .../dialects/_structured_transform_ops_ext.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 70e39be52..e5a2a4731 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -69,6 +69,28 @@ def _get_int_int_array_attr( return ArrayAttr.get([_get_int_array_attr(value) for value in values]) +class DecomposeOp: + """Specialization for DecomposeOp class.""" + + def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): + super().__init__( + pdl.OperationType.get(), + _get_op_result_or_value(target), + loc=loc, + ip=ip) + + +class GeneralizeOp: + """Specialization for GeneralizeOp class.""" + + def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): + super().__init__( + pdl.OperationType.get(), + _get_op_result_or_value(target), + loc=loc, + ip=ip) + + class InterchangeOp: """Specialization for InterchangeOp class.""" From 126e88439d65c97c97a4fe0eba785891da8f9f7f Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Thu, 2 Jun 2022 15:11:02 -0700 Subject: [PATCH 294/915] [mlir][python][f16] add ctype python binding support for f16 Similar to complex128/complex64, float16 has no direct support in the ctypes implementation. 
This fixes the issue by using a custom F16 type to change the view in and out of MLIR code Reviewed By: wrengr Differential Revision: https://reviews.llvm.org/D126928 --- mlir/python/mlir/runtime/np_to_memref.py | 30 ++++++++++++++++-------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py index de5b8d6f7..5b3c3c4ae 100644 --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -18,15 +18,33 @@ class C64(ctypes.Structure): _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)] +class F16(ctypes.Structure): + """A ctype representation for MLIR's Float16.""" + _fields_ = [("f16", ctypes.c_int16)] + + def as_ctype(dtp): """Converts dtype to ctype.""" if dtp is np.dtype(np.complex128): return C128 if dtp is np.dtype(np.complex64): return C64 + if dtp is np.dtype(np.float16): + return F16 return np.ctypeslib.as_ctypes_type(dtp) +def to_numpy(array): + """Converts ctypes array back to numpy dtype array.""" + if array.dtype == C128: + return array.view("complex128") + if array.dtype == C64: + return array.view("complex64") + if array.dtype == F16: + return array.view("float16") + return array + + def make_nd_memref_descriptor(rank, dtype): class MemRefDescriptor(ctypes.Structure): @@ -105,11 +123,7 @@ def unranked_memref_to_numpy(unranked_memref, np_dtype): np.ctypeslib.as_array(val[0].shape), np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize, ) - if strided_arr.dtype == C128: - return strided_arr.view("complex128") - if strided_arr.dtype == C64: - return strided_arr.view("complex64") - return strided_arr + return to_numpy(strided_arr) def ranked_memref_to_numpy(ranked_memref): @@ -121,8 +135,4 @@ def ranked_memref_to_numpy(ranked_memref): np.ctypeslib.as_array(ranked_memref[0].shape), np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize, ) - if strided_arr.dtype == C128: - return 
strided_arr.view("complex128") - if strided_arr.dtype == C64: - return strided_arr.view("complex64") - return strided_arr + return to_numpy(strided_arr) From 41cf88fc711af50702a45eff6539807ce9daa474 Mon Sep 17 00:00:00 2001 From: dime10 Date: Wed, 8 Jun 2022 19:50:12 +0200 Subject: [PATCH 295/915] Add Python bindings for the OpaqueType Implement the C-API and Python bindings for the builtin opaque type, which was previously missing. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D127303 --- mlir/include/mlir-c/BuiltinTypes.h | 23 +++++++++++++ mlir/lib/Bindings/Python/IRTypes.cpp | 42 ++++++++++++++++++++++++ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 22 +++++++++++++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 12 +++++++ 4 files changed, 99 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index 2983627a5..495591464 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -325,6 +325,29 @@ MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGetInput(MlirType type, MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos); +//===----------------------------------------------------------------------===// +// Opaque type. +//===----------------------------------------------------------------------===// + +/// Checks whether the given type is an opaque type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAOpaque(MlirType type); + +/// Creates an opaque type in the given context associated with the dialect +/// identified by its namespace. The type contains opaque byte data of the +/// specified length (data need not be null-terminated). +MLIR_CAPI_EXPORTED MlirType mlirOpaqueTypeGet(MlirContext ctx, + MlirStringRef dialectNamespace, + MlirStringRef typeData); + +/// Returns the namespace of the dialect with which the given opaque type +/// is associated. The namespace string is owned by the context. 
+MLIR_CAPI_EXPORTED MlirStringRef +mlirOpaqueTypeGetDialectNamespace(MlirType type); + +/// Returns the raw data as a string reference. The data remains live as long as +/// the context in which the type lives. +MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueTypeGetData(MlirType type); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 380aa36d7..d93d9f66b 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -608,6 +608,47 @@ class PyFunctionType : public PyConcreteType { } }; +static MlirStringRef toMlirStringRef(const std::string &s) { + return mlirStringRefCreate(s.data(), s.size()); +} + +/// Opaque Type subclass - OpaqueType. +class PyOpaqueType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque; + static constexpr const char *pyClassName = "OpaqueType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::string dialectNamespace, std::string typeData, + DefaultingPyMlirContext context) { + MlirType type = mlirOpaqueTypeGet(context->get(), + toMlirStringRef(dialectNamespace), + toMlirStringRef(typeData)); + return PyOpaqueType(context->getRef(), type); + }, + py::arg("dialect_namespace"), py::arg("buffer"), + py::arg("context") = py::none(), + "Create an unregistered (opaque) dialect type."); + c.def_property_readonly( + "dialect_namespace", + [](PyOpaqueType &self) { + MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self); + return py::str(stringRef.data, stringRef.length); + }, + "Returns the dialect namespace for the Opaque type as a string."); + c.def_property_readonly( + "data", + [](PyOpaqueType &self) { + MlirStringRef stringRef = mlirOpaqueTypeGetData(self); + return py::str(stringRef.data, stringRef.length); + }, + "Returns the data for the Opaque type as a string."); + } +}; + } // namespace void 
mlir::python::populateIRTypes(py::module &m) { @@ -627,4 +668,5 @@ void mlir::python::populateIRTypes(py::module &m) { PyUnrankedMemRefType::bind(m); PyTupleType::bind(m); PyFunctionType::bind(m); + PyOpaqueType::bind(m); } diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 318b8eb10..446f9c4d4 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -11,6 +11,7 @@ #include "mlir-c/IR.h" #include "mlir/CAPI/AffineMap.h" #include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Types.h" @@ -357,3 +358,24 @@ MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) { return wrap( unwrap(type).cast().getResult(static_cast(pos))); } + +//===----------------------------------------------------------------------===// +// Opaque type. +//===----------------------------------------------------------------------===// + +bool mlirTypeIsAOpaque(MlirType type) { return unwrap(type).isa(); } + +MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace, + MlirStringRef typeData) { + return wrap( + OpaqueType::get(StringAttr::get(unwrap(ctx), unwrap(dialectNamespace)), + unwrap(typeData))); +} + +MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type) { + return wrap(unwrap(type).cast().getDialectNamespace().strref()); +} + +MlirStringRef mlirOpaqueTypeGetData(MlirType type) { + return wrap(unwrap(type).cast().getTypeData()); +} diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 20d9c919c..cd7eb0a55 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -69,6 +69,7 @@ __all__ = [ "Module", "NamedAttribute", "NoneType", + "OpaqueType", "OpAttributeMap", "OpOperandList", "OpResult", @@ -820,6 +821,17 @@ class NoneType(Type): @staticmethod def isinstance(arg: Any) -> bool: ... 
+class OpaqueType(Type): + def __init__(self, cast_from_type: Type) -> None: ... + @staticmethod + def get(*args, **kwargs) -> OpaqueType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... + @property + def dialect_namespace(self) -> str: ... + @property + def data(self) -> str: ... + class OpAttributeMap: def __contains__(self, arg0: str) -> bool: ... def __delitem__(self, arg0: str) -> None: ... From 69d8f2f8343959cfd43adc68b521ec953fb5c5da Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 9 Jun 2022 11:10:32 +0200 Subject: [PATCH 296/915] [mlir] Introduce Transform ops for loops Introduce transform ops for "for" loops, in particular for peeling, software pipelining and unrolling, along with a couple of "IR navigation" ops. These ops are intended to be generalized to different kinds of loops when possible and therefore use the "loop" prefix. They currently live in the SCF dialect as there is no clear place to put transform ops that may span across several dialects, this decision is postponed until the ops actually need to handle non-SCF loops. Additionally refactor some common utilities for transform ops into trait or interface methods, and change the loop pipelining to be a returning pattern. 
Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D127300 --- mlir/python/CMakeLists.txt | 10 ++ .../mlir/dialects/SCFLoopTransformOps.td | 21 ++++ .../mlir/dialects/_loop_transform_ops_ext.py | 113 ++++++++++++++++++ mlir/python/mlir/dialects/transform/loop.py | 5 + 4 files changed, 149 insertions(+) create mode 100644 mlir/python/mlir/dialects/SCFLoopTransformOps.td create mode 100644 mlir/python/mlir/dialects/_loop_transform_ops_ext.py create mode 100644 mlir/python/mlir/dialects/transform/loop.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 17048e8cb..13b35f15b 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -125,6 +125,16 @@ declare_mlir_dialect_python_bindings( dialects/transform/__init__.py DIALECT_NAME transform) +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/SCFLoopTransformOps.td + SOURCES + dialects/_loop_transform_ops_ext.py + dialects/transform/loop.py + DIALECT_NAME transform + EXTENSION_NAME loop_transform) + declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/SCFLoopTransformOps.td b/mlir/python/mlir/dialects/SCFLoopTransformOps.td new file mode 100644 index 000000000..5ef07fc7a --- /dev/null +++ b/mlir/python/mlir/dialects/SCFLoopTransformOps.td @@ -0,0 +1,21 @@ +//===-- SCFLoopTransformOps.td -----------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the loop transform ops +// provided by the SCF (and other) dialects. +// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS +#define PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.td" + +#endif // PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py new file mode 100644 index 000000000..7452c4243 --- /dev/null +++ b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py @@ -0,0 +1,113 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * + from ._ods_common import get_op_result_or_value as _get_op_result_or_value + from ..dialects import pdl +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Union + + +def _get_int64_attr(arg: Optional[Union[int, IntegerAttr]], + default_value: int = None): + if isinstance(arg, IntegerAttr): + return arg + + if arg is None: + assert default_value is not None, "must provide default value" + arg = default_value + + return IntegerAttr.get(IntegerType.get_signless(64), arg) + + +class GetParentForOp: + """Extension for GetParentForOp.""" + + def __init__(self, + target: Union[Operation, Value], + *, + num_loops: int = 1, + ip=None, + loc=None): + super().__init__( + pdl.OperationType.get(), + _get_op_result_or_value(target), + num_loops=_get_int64_attr(num_loops, default_value=1), + ip=ip, + loc=loc) + + +class LoopOutlineOp: + """Extension for LoopOutlineOp.""" + + def __init__(self, + target: Union[Operation, Value], + *, + func_name: Union[str, StringAttr], + ip=None, + loc=None): + super().__init__( + pdl.OperationType.get(), + _get_op_result_or_value(target), + func_name=(func_name if isinstance(func_name, StringAttr) else + StringAttr.get(func_name)), + ip=ip, + loc=loc) + + +class LoopPeelOp: + """Extension for LoopPeelOp.""" + + def __init__(self, + target: Union[Operation, Value], + *, + fail_if_already_divisible: Union[bool, BoolAttr] = False, + ip=None, + loc=None): + super().__init__( + pdl.OperationType.get(), + _get_op_result_or_value(target), + fail_if_already_divisible=(fail_if_already_divisible if isinstance( + fail_if_already_divisible, BoolAttr) else + BoolAttr.get(fail_if_already_divisible)), + ip=ip, + loc=loc) + + +class LoopPipelineOp: + """Extension for LoopPipelineOp.""" + + def __init__(self, + target: Union[Operation, Value], + *, + iteration_interval: Optional[Union[int, 
IntegerAttr]] = None, + read_latency: Optional[Union[int, IntegerAttr]] = None, + ip=None, + loc=None): + super().__init__( + pdl.OperationType.get(), + _get_op_result_or_value(target), + iteration_interval=_get_int64_attr(iteration_interval, default_value=1), + read_latency=_get_int64_attr(read_latency, default_value=10), + ip=ip, + loc=loc) + + +class LoopUnrollOp: + """Extension for LoopUnrollOp.""" + + def __init__(self, + target: Union[Operation, Value], + *, + factor: Union[int, IntegerAttr], + ip=None, + loc=None): + super().__init__( + _get_op_result_or_value(target), + factor=_get_int64_attr(factor), + ip=ip, + loc=loc) diff --git a/mlir/python/mlir/dialects/transform/loop.py b/mlir/python/mlir/dialects/transform/loop.py new file mode 100644 index 000000000..86f72788d --- /dev/null +++ b/mlir/python/mlir/dialects/transform/loop.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._loop_transform_ops_gen import * From 36dd69fc61f1193549e95c51f7979e77f9cf7c9a Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 9 Jun 2022 21:36:39 +0200 Subject: [PATCH 297/915] [mlir][bufferization] Add optional `copy` operand to AllocTensorOp If `copy` is specified, the newly allocated buffer is initialized with the given contents. Also add an optional `escape` attribute to indicate whether the buffer of the tensor may be returned from the parent block (aka. "escape") after bufferization. This change is in preparation of connecting One-Shot Bufferize to the sparse compiler. 
Differential Revision: https://reviews.llvm.org/D126570 --- mlir/python/mlir/dialects/_bufferization_ops_ext.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/_bufferization_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_ops_ext.py index c720844af..23f78fc80 100644 --- a/mlir/python/mlir/dialects/_bufferization_ops_ext.py +++ b/mlir/python/mlir/dialects/_bufferization_ops_ext.py @@ -18,15 +18,20 @@ class AllocTensorOp: def __init__(self, tensor_type: Type, dynamic_sizes: Sequence[Value], + copy: Value, + escape: BoolAttr, *, loc=None, ip=None): """Constructs an `alloc_tensor` with static and/or dynamic sizes.""" context = get_default_loc_context(loc) + attributes = {} + if escape: + attributes["escape"] = escape op = self.build_generic( results=[tensor_type], - operands=dynamic_sizes, - attributes={}, + operands=[dynamic_sizes, copy], + attributes=attributes, loc=loc, ip=ip) OpView.__init__(self, op) From f1cba83efe8f2ed6e53a25730195222f09749abb Mon Sep 17 00:00:00 2001 From: Mogball Date: Thu, 9 Jun 2022 21:33:41 +0000 Subject: [PATCH 298/915] [mlir][gpu] Move GPU headers into IR/ and Transforms/ Depends on D127350 Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D127352 --- mlir/include/mlir-c/Dialect/GPU.h | 2 +- mlir/lib/CAPI/Dialect/GPU.cpp | 2 +- mlir/lib/CAPI/Dialect/GPUPasses.cpp | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/GPU.h b/mlir/include/mlir-c/Dialect/GPU.h index e4797a7ee..cf4e899c5 100644 --- a/mlir/include/mlir-c/Dialect/GPU.h +++ b/mlir/include/mlir-c/Dialect/GPU.h @@ -23,6 +23,6 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(GPU, gpu); } #endif -#include "mlir/Dialect/GPU/Passes.capi.h.inc" +#include "mlir/Dialect/GPU/Transforms/Passes.capi.h.inc" #endif // MLIR_C_DIALECT_GPU_H diff --git a/mlir/lib/CAPI/Dialect/GPU.cpp b/mlir/lib/CAPI/Dialect/GPU.cpp index 0de2cfa33..cd58f0e24 100644 --- 
a/mlir/lib/CAPI/Dialect/GPU.cpp +++ b/mlir/lib/CAPI/Dialect/GPU.cpp @@ -8,6 +8,6 @@ #include "mlir-c/Dialect/GPU.h" #include "mlir/CAPI/Registration.h" -#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(GPU, gpu, mlir::gpu::GPUDialect) diff --git a/mlir/lib/CAPI/Dialect/GPUPasses.cpp b/mlir/lib/CAPI/Dialect/GPUPasses.cpp index 4ec167f88..5128c63ec 100644 --- a/mlir/lib/CAPI/Dialect/GPUPasses.cpp +++ b/mlir/lib/CAPI/Dialect/GPUPasses.cpp @@ -7,11 +7,11 @@ //===----------------------------------------------------------------------===// #include "mlir/CAPI/Pass.h" -#include "mlir/Dialect/GPU/Passes.h" +#include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Pass/Pass.h" // Must include the declarations as they carry important visibility attributes. -#include "mlir/Dialect/GPU/Passes.capi.h.inc" +#include "mlir/Dialect/GPU/Transforms/Passes.capi.h.inc" using namespace mlir; @@ -19,7 +19,7 @@ using namespace mlir; extern "C" { #endif -#include "mlir/Dialect/GPU/Passes.capi.cpp.inc" +#include "mlir/Dialect/GPU/Transforms/Passes.capi.cpp.inc" #ifdef __cplusplus } From 7a78d23999b514ccbbe0ee81c2640ce5fdfbccbb Mon Sep 17 00:00:00 2001 From: agostini01 Date: Fri, 10 Jun 2022 20:42:16 +0000 Subject: [PATCH 299/915] [mlir][py-bindings] Fix include issue introduced by D127352 Using: -DMLIR_ENABLE_BINDINGS_PYTHON=ON Resulted in a failed build due to changes implemented by https://reviews.llvm.org/D127352 This updates the include line Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D127523 --- mlir/python/mlir/dialects/GPUOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/GPUOps.td b/mlir/python/mlir/dialects/GPUOps.td index bf0980f29..4e23d322f 100644 --- a/mlir/python/mlir/dialects/GPUOps.td +++ b/mlir/python/mlir/dialects/GPUOps.td @@ -10,6 +10,6 @@ #define PYTHON_BINDINGS_GPU_OPS include "mlir/Bindings/Python/Attributes.td" 
-include "mlir/Dialect/GPU/GPUOps.td" +include "mlir/Dialect/GPU/IR/GPUOps.td" #endif From ac78d7042393cc4aff4e92d744858b832576a4ac Mon Sep 17 00:00:00 2001 From: Mogball Date: Mon, 13 Jun 2022 06:50:55 +0000 Subject: [PATCH 300/915] [mlir] (NFC) Clean up bazel and CMake target names All dialect targets in bazel have been named *Dialect and all dialect targets in CMake have been named MLIR*Dialect. --- mlir/lib/CAPI/Dialect/CMakeLists.txt | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index 37bc121f2..e5173ffd3 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -8,7 +8,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIAsync LINK_LIBS PUBLIC MLIRCAPIIR - MLIRAsync + MLIRAsyncDialect MLIRAsyncTransforms MLIRPass ) @@ -19,7 +19,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIControlFlow PARTIAL_SOURCES_INTENDED LINK_LIBS PUBLIC MLIRCAPIIR - MLIRControlFlow + MLIRControlFlowDialect ) add_mlir_upstream_c_api_library(MLIRCAPIGPU @@ -42,7 +42,7 @@ add_mlir_upstream_c_api_library(MLIRCAPILLVM PARTIAL_SOURCES_INTENDED LINK_LIBS PUBLIC MLIRCAPIIR - MLIRLLVMIR + MLIRLLVMDialect ) add_mlir_upstream_c_api_library(MLIRCAPILinalg @@ -55,7 +55,7 @@ add_mlir_upstream_c_api_library(MLIRCAPILinalg LINK_LIBS PUBLIC MLIRCAPIIR - MLIRLinalg + MLIRLinalgDialect MLIRPass MLIRLinalgTransforms ) @@ -66,7 +66,7 @@ add_mlir_upstream_c_api_library(MLIRCAPISCF PARTIAL_SOURCES_INTENDED LINK_LIBS PUBLIC MLIRCAPIIR - MLIRSCF + MLIRSCFDialect ) add_mlir_upstream_c_api_library(MLIRCAPIShape @@ -75,7 +75,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIShape PARTIAL_SOURCES_INTENDED LINK_LIBS PUBLIC MLIRCAPIIR - MLIRShape + MLIRShapeDialect ) add_mlir_upstream_c_api_library(MLIRCAPISparseTensor @@ -85,7 +85,7 @@ add_mlir_upstream_c_api_library(MLIRCAPISparseTensor PARTIAL_SOURCES_INTENDED LINK_LIBS PUBLIC MLIRCAPIIR - MLIRSparseTensor + 
MLIRSparseTensorDialect MLIRSparseTensorTransforms ) @@ -95,7 +95,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIFunc PARTIAL_SOURCES_INTENDED LINK_LIBS PUBLIC MLIRCAPIIR - MLIRFunc + MLIRFuncDialect ) add_mlir_upstream_c_api_library(MLIRCAPITensor @@ -104,7 +104,7 @@ add_mlir_upstream_c_api_library(MLIRCAPITensor PARTIAL_SOURCES_INTENDED LINK_LIBS PUBLIC MLIRCAPIIR - MLIRTensor + MLIRTensorDialect ) add_mlir_upstream_c_api_library(MLIRCAPIQuant @@ -113,7 +113,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIQuant PARTIAL_SOURCES_INTENDED LINK_LIBS PUBLIC MLIRCAPIIR - MLIRQuant + MLIRQuantDialect ) add_mlir_upstream_c_api_library(MLIRCAPIPDL @@ -122,5 +122,5 @@ add_mlir_upstream_c_api_library(MLIRCAPIPDL PARTIAL_SOURCES_INTENDED LINK_LIBS PUBLIC MLIRCAPIIR - MLIRPDL + MLIRPDLDialect ) From d16e552caf3b788ffd5ba5350296f1a380e912ce Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Mon, 13 Jun 2022 20:37:44 +0200 Subject: [PATCH 301/915] [mlir][linalg] Add conv_2d_nhwc_fhwc to core_named_ops.py So it doesn't disappear when running the generator. --- .../linalg/opdsl/ops/core_named_ops.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 023a95dc7..291a4be22 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -293,6 +293,29 @@ def conv_2d_nhwc_hwcf(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, D.c]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) +@linalg_structured_op +def conv_2d_nhwc_fhwc(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.F, S.KH, S.KW, S.C), + O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): + """Performs 2-D convolution. + + Layout: + * Input: NHWC. 
+ * Kernel: FHWC. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, + D.c]) * TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) + + @linalg_structured_op def conv_2d_nhwc_hwcf_q(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), From 4d5d42fb78b4e0a5e04232955ea0d2aacc72cc27 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Mon, 13 Jun 2022 22:03:56 +0200 Subject: [PATCH 302/915] [mlir][linalg] Add named ops for depthwise 3d convolution Also complete the set by adding a variant of depthwise 1d convolution with the multiplier != 1. Differential Revision: https://reviews.llvm.org/D127687 --- .../linalg/opdsl/ops/core_named_ops.py | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 291a4be22..f553c3880 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -414,6 +414,26 @@ def depthwise_conv_1d_nwc_wc(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, TypeFn.cast_signed(U, K[D.kw, D.ic]) +@linalg_structured_op +def depthwise_conv_1d_nwc_wcm(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, + S.IC), + K=TensorDef(T2, S.KW, S.IC, S.CM), + O=TensorDef(U, S.N, S.OW, S.IC, S.CM, + output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1])): + """Performs depth-wise 1-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.ic, D.cm, D.kw) + O[D.n, D.ow, D.ic, D.cm] += \ + TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \ + TypeFn.cast_signed(U, K[D.kw, D.ic, D.cm]) + + @linalg_structured_op def depthwise_conv_2d_nhwc_hwc(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), @@ -536,6 +556,64 @@ def depthwise_conv_2d_nhwc_hwcm_q(I=TensorDef(T1, S.N, TypeFn.cast_signed(U, KZp))) +@linalg_structured_op +def depthwise_conv_3d_ndhwc_dhwc(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, + output=True), + strides=IndexAttrDef(S.SD, + S.SH, + S.SW, + default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, + S.DH, + S.DW, + default=[1, 1, 1])): + """Performs depth-wise 3-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic) + O[D.n, D.od, D.oh, D.ow, D.ic] += TypeFn.cast_signed( + U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, D.ic]) * TypeFn.cast_signed( + U, K[D.kd, D.kh, D.kw, D.ic]) + + +@linalg_structured_op +def depthwise_conv_3d_ndhwc_dhwcm(I=TensorDef(T1, + S.N, S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC, S.CM), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.CM, + output=True), + strides=IndexAttrDef(S.SD, + S.SH, + S.SW, + default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, + S.DH, + S.DW, + default=[1, 1, 1])): + """Performs depth-wise 3-D convolution. 
+ + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.cm, D.kd, D.kh, D.kw, D.ic) + O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed( + U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, D.ic]) * TypeFn.cast_signed( + U, K[D.kd, D.kh, D.kw, D.ic, D.cm]) + + @linalg_structured_op def pooling_nhwc_sum(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), From 6a6d311bf94bbed80bfb86abc8f8639777cba926 Mon Sep 17 00:00:00 2001 From: Mark Browning Date: Wed, 15 Jun 2022 21:59:57 -0700 Subject: [PATCH 303/915] [mlir][python] Actually set UseLocalScope printing flag The useLocalScope printing flag has been passed around between pybind methods, but doesn't actually enable the corresponding printing flag. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D127907 --- mlir/lib/Bindings/Python/IRCore.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 37ca3b953..98566ddea 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1002,6 +1002,8 @@ void PyOperationBase::print(py::object fileObject, bool binary, mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); if (printGenericOpForm) mlirOpPrintingFlagsPrintGenericOpForm(flags); + if (useLocalScope) + mlirOpPrintingFlagsUseLocalScope(flags); PyFileAccumulator accum(fileObject, binary); mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), From e5a312d10ebe9e5b3498c470b9abb31d319ae5c4 Mon Sep 17 00:00:00 2001 From: bixia1 Date: Wed, 15 Jun 2022 16:00:51 -0700 Subject: [PATCH 304/915] [mlir][complex] Add Python bindings for complex ops. 
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D127916 --- mlir/python/CMakeLists.txt | 8 ++++++++ mlir/python/mlir/dialects/ComplexOps.td | 15 +++++++++++++++ mlir/python/mlir/dialects/complex.py | 5 +++++ 3 files changed, 28 insertions(+) create mode 100644 mlir/python/mlir/dialects/ComplexOps.td create mode 100644 mlir/python/mlir/dialects/complex.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 13b35f15b..4581dd667 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -81,6 +81,14 @@ declare_mlir_dialect_python_bindings( dialects/_builtin_ops_ext.py DIALECT_NAME builtin) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/ComplexOps.td + SOURCES + dialects/complex.py + DIALECT_NAME complex) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/ComplexOps.td b/mlir/python/mlir/dialects/ComplexOps.td new file mode 100644 index 000000000..6fd846ba6 --- /dev/null +++ b/mlir/python/mlir/dialects/ComplexOps.td @@ -0,0 +1,15 @@ +//===-- ComplexOps.td - Entry point for ComplexOps bindings ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_COMPLEX_OPS +#define PYTHON_BINDINGS_COMPLEX_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/Complex/IR/ComplexOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/complex.py b/mlir/python/mlir/dialects/complex.py new file mode 100644 index 000000000..ca81173cf --- /dev/null +++ b/mlir/python/mlir/dialects/complex.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._complex_ops_gen import * From 55a9de0fb71392de2de8eb025a4540b47437244b Mon Sep 17 00:00:00 2001 From: bixia1 Date: Thu, 16 Jun 2022 14:27:26 -0700 Subject: [PATCH 305/915] [mlir][linalg] Extend opdsl to support operations on complex types. Linalg opdsl now supports negf/add/sub/mul on complex types. Add a test. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D128010 --- .../mlir/dialects/linalg/opdsl/lang/emitter.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 2e71e561a..cc99081b4 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -10,6 +10,7 @@ from .... import linalg from .... import math from .... import arith +from .... 
import complex from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values from .scalar_expr import * @@ -408,6 +409,8 @@ def _unary_floor(self, x: Value) -> Value: def _unary_negf(self, x: Value) -> Value: if _is_floating_point_type(x.type): return arith.NegFOp(x).result + if _is_complex_type(x.type): + return complex.NegOp(x).result raise NotImplementedError("Unsupported 'negf' operand: {x}") def _binary_add(self, lhs: Value, rhs: Value) -> Value: @@ -415,6 +418,8 @@ def _binary_add(self, lhs: Value, rhs: Value) -> Value: return arith.AddFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.AddIOp(lhs, rhs).result + if _is_complex_type(lhs.type): + return complex.AddOp(lhs, rhs).result raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}") def _binary_sub(self, lhs: Value, rhs: Value) -> Value: @@ -422,6 +427,8 @@ def _binary_sub(self, lhs: Value, rhs: Value) -> Value: return arith.SubFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.SubIOp(lhs, rhs).result + if _is_complex_type(lhs.type): + return complex.SubOp(lhs, rhs).result raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}") def _binary_mul(self, lhs: Value, rhs: Value) -> Value: @@ -429,6 +436,8 @@ def _binary_mul(self, lhs: Value, rhs: Value) -> Value: return arith.MulFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MulIOp(lhs, rhs).result + if _is_complex_type(lhs.type): + return complex.MulOp(lhs, rhs).result raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}") def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value: @@ -512,6 +521,10 @@ def _add_type_mapping(operand_config: OperandDefConfig, operand_type: Type, block_arg_types.append(element_or_self_type) +def _is_complex_type(t: Type) -> bool: + return ComplexType.isinstance(t) + + def 
_is_floating_point_type(t: Type) -> bool: # TODO: Create a FloatType in the Python API and implement the switch # there. From f35621049201dff81f517f6e836cd57fd80c1b42 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 17 Jun 2022 15:47:15 +0200 Subject: [PATCH 306/915] [mlir] move SCF headers to SCF/{IR,Transforms} respectively This aligns the SCF dialect file layout with the majority of the dialects. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D128049 --- mlir/lib/CAPI/Dialect/SCF.cpp | 2 +- mlir/python/mlir/dialects/SCFOps.td | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/CAPI/Dialect/SCF.cpp b/mlir/lib/CAPI/Dialect/SCF.cpp index c1dca6d21..17751b1c9 100644 --- a/mlir/lib/CAPI/Dialect/SCF.cpp +++ b/mlir/lib/CAPI/Dialect/SCF.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir-c/Dialect/SCF.h" #include "mlir/CAPI/Registration.h" diff --git a/mlir/python/mlir/dialects/SCFOps.td b/mlir/python/mlir/dialects/SCFOps.td index 855482d4a..58f337e23 100644 --- a/mlir/python/mlir/dialects/SCFOps.td +++ b/mlir/python/mlir/dialects/SCFOps.td @@ -10,6 +10,6 @@ #define PYTHON_BINDINGS_SCF_OPS include "mlir/Bindings/Python/Attributes.td" -include "mlir/Dialect/SCF/SCFOps.td" +include "mlir/Dialect/SCF/IR/SCFOps.td" #endif From 90c422dfda264bc51f20c85413376f757d6d24aa Mon Sep 17 00:00:00 2001 From: gpetters94 Date: Tue, 21 Jun 2022 16:51:18 +0000 Subject: [PATCH 307/915] Adding a named op for grouped convolutions --- .../linalg/opdsl/ops/core_named_ops.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index f553c3880..7ffe13c75 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ 
b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -366,6 +366,27 @@ def conv_2d_nchw_fchw(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) +@linalg_structured_op +def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW), + O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): + """Performs 2-D grouped convolution. + + Layout: + * Input: NGCHW. + * Kernel: FGCHW. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + + D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.fg, D.g, D.c, D.kh, D.kw]) @linalg_structured_op def conv_3d_ndhwc_dhwcf(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, From e0753e0b2c731b226de7fe5ceb1ba3b88cf1a085 Mon Sep 17 00:00:00 2001 From: Stella Stamenova Date: Tue, 28 Jun 2022 10:39:13 -0700 Subject: [PATCH 308/915] [mlir] Leverage CMake interface libraries for mlir python This is already partially the case, but we can rely more heavily on interface libraries and how they are imported/exported in other to simplify the implementation of the mlir python functions in Cmake. This change also makes a couple of other changes: 1) Add a new CMake function which handles "pure" sources. This was done inline previously 2) Moves the headers associated with CAPI libraries to the libraries themselves. These were previously managed in a separate source target. 
They can now be added directly to the CAPI libraries using DECLARED_HEADERS. 3) Cleanup some dependencies that showed up as an issue during the refactor This is a big CMake change that should produce no impact on the build of mlir and on the produced *build tree*. However, this change fixes an issue with the *install tree* of mlir which was previously unusable for projects like torch-mlir because both the "pure" and "extension" targets were pointing to either the build or source trees. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D128230 --- mlir/python/CMakeLists.txt | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 4581dd667..dc831c9cb 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -46,10 +46,9 @@ declare_mlir_python_sources(MLIRPythonSources.Passes transforms/*.py ) -declare_mlir_python_sources(MLIRPythonCAPIHeaderSources - ROOT_DIR "${MLIR_SOURCE_DIR}/include" +declare_mlir_python_sources(MLIRPythonCAPI.HeaderSources + ROOT_DIR "${MLIR_MAIN_INCLUDE_DIR}" SOURCES_GLOB "mlir-c/*.h" - DEST_PREFIX "_mlir_libs/include" ) ################################################################################ @@ -503,6 +502,8 @@ add_mlir_python_common_capi_library(MLIRPythonCAPI INSTALL_DESTINATION python_packages/mlir_core/mlir/_mlir_libs OUTPUT_DIRECTORY "${MLIR_BINARY_DIR}/python_packages/mlir_core/mlir/_mlir_libs" RELATIVE_INSTALL_ROOT "../../../.." 
+ DECLARED_HEADERS + MLIRPythonCAPI.HeaderSources DECLARED_SOURCES MLIRPythonSources MLIRPythonExtension.AllPassesRegistration @@ -520,8 +521,7 @@ add_mlir_python_modules(MLIRPythonModules DECLARED_SOURCES MLIRPythonSources MLIRPythonExtension.AllPassesRegistration - MLIRPythonCAPIHeaderSources ${_ADDL_TEST_SOURCES} COMMON_CAPI_LINK_LIBS MLIRPythonCAPI - ) +) From b4e0e1a0b749b0fdadfc19bcf0b25b4af62177f9 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 7 Jul 2022 13:10:40 +0200 Subject: [PATCH 309/915] [mlir] Structured transforms: introduce op splitting Introduce a new transformation on structured ops that splits the iteration space into two parts along the specified dimension. The index at which the splitting happens may be static or dynamic. This transformation can be seen as a rudimentary form of index-set splitting that only supports the splitting along hyperplanes parallel to the iteration space hyperplanes, and is therefore decomposable into per-dimension application. It is a key low-level transformation that enables independent scheduling for different parts of the iteration space of the same op, which hasn't been possible previously. It may be used to implement, e.g., multi-sized tiling. In future, peeling can be implemented as a combination of split-off amount computation and splitting. The transformation is conceptually close to tiling in its separation of the iteration and data spaces, but cannot be currently implemented on top of TilingInterface as the latter does not properly support `linalg.index` offsetting. Note that the transformation intentionally bypasses folding of `tensor.extract_slice` operations when creating them as this folding was found to prevent repeated splitting of the same operation because due to internal assumptions about extract/insert_slice combination in dialect utilities. 
Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D129090 --- mlir/include/mlir-c/BuiltinTypes.h | 9 ++++ mlir/lib/Bindings/Python/IRTypes.cpp | 9 ++++ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 6 +++ .../dialects/_structured_transform_ops_ext.py | 47 ++++++++++++++++--- 4 files changed, 64 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index 495591464..d1083f932 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -150,10 +150,19 @@ MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type, /// in shaped types. MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicSize(int64_t size); +/// Returns the value indicating a dynamic size in a shaped type. Prefer +/// mlirShapedTypeIsDynamicSize to direct comparisons with this value. +MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicSize(); + /// Checks whether the given value is used as a placeholder for dynamic strides /// and offsets in shaped types. MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val); +/// Returns the value indicating a dynamic stride or offset in a shaped type. +/// Prefer mlirShapedTypeGetDynamicStrideOrOffset to direct comparisons with +/// this value. +MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(); + //===----------------------------------------------------------------------===// // Vector type. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index d93d9f66b..153664d07 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -301,6 +301,15 @@ class PyShapedType : public PyConcreteType { return shape; }, "Returns the shape of the ranked shaped type as a list of integers."); + c.def_static( + "_get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); }, + "Returns the value used to indicate dynamic dimensions in shaped " + "types."); + c.def_static( + "_get_dynamic_stride_or_offset", + []() { return mlirShapedTypeGetDynamicStrideOrOffset(); }, + "Returns the value used to indicate dynamic strides or offsets in " + "shaped types."); } private: diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 446f9c4d4..be44b76e8 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -149,6 +149,8 @@ int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) { return unwrap(type).cast().getDimSize(static_cast(dim)); } +int64_t mlirShapedTypeGetDynamicSize() { return ShapedType::kDynamicSize; } + bool mlirShapedTypeIsDynamicSize(int64_t size) { return ShapedType::isDynamic(size); } @@ -157,6 +159,10 @@ bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) { return ShapedType::isDynamicStrideOrOffset(val); } +int64_t mlirShapedTypeGetDynamicStrideOrOffset() { + return ShapedType::kDynamicStrideOrOffset; +} + //===----------------------------------------------------------------------===// // Vector type. 
//===----------------------------------------------------------------------===// diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index e5a2a4731..beef9240d 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -15,6 +15,12 @@ OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]] +def _get_int64_attr(value: Union[int, Attribute]) -> IntegerAttr: + if isinstance(value, int): + return IntegerAttr.get(IntegerType.get_signless(64), value) + return value + + def _get_array_attr( values: Optional[Union[ArrayAttr, Sequence[Attribute]]]) -> ArrayAttr: """Creates an array attribute from its operand.""" @@ -41,13 +47,7 @@ def _get_int_array_attr( if isinstance(values, ArrayAttr): return values - attributes = [] - for value in values: - if isinstance(value, IntegerAttr): - attributes.append(value) - else: - attributes.append(IntegerAttr.get(IntegerType.get_signless(64), value)) - return ArrayAttr.get(attributes) + return ArrayAttr.get([_get_int64_attr(v) for v in values]) def _get_int_int_array_attr( @@ -152,6 +152,39 @@ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip) +class SplitOp: + """Specialization for SplitOp class.""" + + def __init__(self, + target: Union[Operation, Value], + dimension: Union[int, Attribute], + split_point: Union[int, Operation, Value, Attribute], + *, + loc=None, + ip=None): + dimension = _get_int64_attr(dimension) + if isinstance(split_point, int): + split_point = _get_int64_attr(split_point) + + if isinstance(split_point, Attribute): + static_split_point = split_point + dynamic_split_point = None + else: + static_split_point = _get_int64_attr(ShapedType._get_dynamic_size()) + dynamic_split_point = _get_op_result_or_value(split_point) + + pdl_operation_type = pdl.OperationType.get() 
+ super().__init__( + pdl_operation_type, + pdl_operation_type, + _get_op_result_or_value(target), + dimension=dimension, + static_split_point=static_split_point, + dynamic_split_point=dynamic_split_point, + loc=loc, + ip=ip) + + class TileOp: """Specialization for TileOp class.""" From fa95c252ba1b8d533c152e474f792b5159dc2728 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 7 Jul 2022 13:11:34 +0200 Subject: [PATCH 310/915] [mlir] Transform dialect: introduce merge_handles op This Transform dialect op allows one to merge the lists of Payload IR operations pointed to by several handles into a single list associated with one handle. This is an important Transform dialect usability improvement for cases where transformations may temporarily diverge for different groups of Payload IR ops before converging back to the same script. Without this op, several copies of the trailing transformations would have to be present in the transformation script. Depends On D129090 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D129110 --- mlir/python/mlir/dialects/_transform_ops_ext.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py index 138195dca..ca45ab7e2 100644 --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -28,6 +28,21 @@ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): ip=ip) +class MergeHandlesOp: + + def __init__(self, + handles: Sequence[Union[Operation, Value]], + *, + deduplicate: bool = False, + loc=None, + ip=None): + super().__init__( + pdl.OperationType.get(), [_get_op_result_or_value(h) for h in handles], + deduplicate=deduplicate, + loc=loc, + ip=ip) + + class PDLMatchOp: def __init__(self, From 91b793b3254ac42b20348be9ade5fef1485fbc8e Mon Sep 17 00:00:00 2001 From: George Petterson Date: Mon, 11 Jul 2022 15:37:03 -0400 
Subject: [PATCH 311/915] Fix an issue with grouped conv2d op --- mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 7ffe13c75..7dd3f9495 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -370,7 +370,7 @@ def conv_2d_nchw_fchw(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW), - O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True), + O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True), strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs 2-D grouped convolution. @@ -386,7 +386,7 @@ def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw) O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed( U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + - D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.fg, D.g, D.c, D.kh, D.kw]) + D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw]) @linalg_structured_op def conv_3d_ndhwc_dhwcf(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, From 5282a5daa81396b6ba7f7c2eb5cb7b68f7fed59a Mon Sep 17 00:00:00 2001 From: Nirvedh Date: Mon, 11 Jul 2022 20:03:16 +0000 Subject: [PATCH 312/915] Revert "Fix an issue with grouped conv2d op" This reverts commit 91b793b3254ac42b20348be9ade5fef1485fbc8e. 
--- mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 7dd3f9495..7ffe13c75 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -370,7 +370,7 @@ def conv_2d_nchw_fchw(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW), - O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True), + O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True), strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs 2-D grouped convolution. @@ -386,7 +386,7 @@ def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw) O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed( U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + - D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw]) + D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.fg, D.g, D.c, D.kh, D.kw]) @linalg_structured_op def conv_3d_ndhwc_dhwcf(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, From 38eeed5424a1dfa78c1c9aa0d9a6c078a16cad1e Mon Sep 17 00:00:00 2001 From: George Petterson Date: Mon, 11 Jul 2022 15:37:03 -0400 Subject: [PATCH 313/915] Fix an issue with grouped conv2d op Reviewed By: silvas Differential Revision: https://reviews.llvm.org/D128880 --- mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 7ffe13c75..7dd3f9495 100644 
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -370,7 +370,7 @@ def conv_2d_nchw_fchw(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW), - O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True), + O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True), strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs 2-D grouped convolution. @@ -386,7 +386,7 @@ def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw) O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed( U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + - D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.fg, D.g, D.c, D.kh, D.kw]) + D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw]) @linalg_structured_op def conv_3d_ndhwc_dhwcf(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, From 7035b6e5e665869bbc68bafc1b1d3caacb4b963b Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 7 Jul 2022 15:55:23 +0200 Subject: [PATCH 314/915] [mlir] Add ReplicateOp to the Transform dialect This handle manipulation operation allows one to define a new handle that is associated with a the same payload IR operations N times, where N can be driven by the size of payload IR operation list associated with another handle. This can be seen as a sort of broadcast that can be used to ensure the lists associated with two handles have equal numbers of payload IR ops as expected by many pairwise transform operations. Introduce an additional "expensive" check that guards against consuming a handle that is assocaited with the same payload IR operation more than once as this is likely to lead to double-free or other undesired effects. 
Depends On D129110 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D129216 --- mlir/python/mlir/dialects/_transform_ops_ext.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py index ca45ab7e2..e75d6b5f9 100644 --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -59,6 +59,22 @@ def __init__(self, ip=ip) +class ReplicateOp: + + def __init__(self, + pattern: Union[Operation, Value], + handles: Sequence[Union[Operation, Value]], + *, + loc=None, + ip=None): + super().__init__( + [pdl.OperationType.get()] * len(handles), + _get_op_result_or_value(pattern), + [_get_op_result_or_value(h) for h in handles], + loc=loc, + ip=ip) + + class SequenceOp: @overload From 623edebfa0c6e7ec9cac2677a816f316f82b9563 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 7 Jul 2022 15:55:44 +0200 Subject: [PATCH 315/915] [mlir] Allow Tile transform op to take dynamic sizes Extend the definition of the Tile structured transform op to enable it accepting handles to operations that produce tile sizes at runtime. This is useful by itself and prepares for more advanced tiling strategies. Note that the changes are relevant only to the transform dialect, the tiling transformation itself already supports dynamic sizes. 
Depends On D129216 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D129217 --- .../dialects/_structured_transform_ops_ext.py | 28 +++++++++++++++++-- 1 file changed, 25 insertions(+), 3 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index beef9240d..b6e078fc7 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -191,18 +191,40 @@ class TileOp: def __init__(self, target: Union[Operation, Value], *, - sizes: OptionalIntList = None, + sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, + Value]], ArrayAttr]] = None, interchange: OptionalIntList = None, loc=None, ip=None): pdl_operation_type = pdl.OperationType.get() - sizes_attr = _get_int_array_attr(sizes) + i64_type = IntegerType.get_signless(64) + + if sizes is None: + sizes = [] + + static_sizes = [] + dynamic_sizes = [] + if isinstance(sizes, ArrayAttr): + sizes_attr = sizes + else: + for size in sizes: + if isinstance(size, int): + static_sizes.append(IntegerAttr.get(i64_type, size)) + elif isinstance(size, IntegerAttr): + static_sizes.append(size) + else: + static_sizes.append( + IntegerAttr.get(i64_type, ShapedType._get_dynamic_size())) + dynamic_sizes.append(_get_op_result_or_value(size)) + sizes_attr = ArrayAttr.get(static_sizes) + num_loops = sum( v if v == 0 else 1 for v in self.__extract_values(sizes_attr)) super().__init__( pdl_operation_type, [pdl_operation_type] * num_loops, _get_op_result_or_value(target), - sizes=sizes_attr, + dynamic_sizes=dynamic_sizes, + static_sizes=sizes_attr, interchange=_get_int_array_attr(interchange) if interchange else None, loc=loc, ip=ip) From 5e9cca31a68726a06c3838e149617ec026d97e16 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 7 Jul 2022 15:56:06 +0200 Subject: [PATCH 316/915] [mlir] Transform op for multitile size generation Introduce 
a structured transform op that emits IR computing the multi-tile sizes with requested parameters (target size and divisor) for the given structured op. The sizes may fold to arithmetic constant operations when the shape is constant. These operations may then be used to call the existing tiling transformation with a single non-zero dynamic size (i.e. perform strip-mining) for each of the dimensions separately, thus achieving multi-size tiling with optional loop interchange. A separate test exercises the entire script. Depends On D129217 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D129287 --- .../dialects/_structured_transform_ops_ext.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index b6e078fc7..95bf2cc99 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -110,6 +110,29 @@ def __init__(self, ip=ip) +class MultiTileSizesOp: + """Specialization for MultitileSizesOp class.""" + + def __init__(self, + target: Union[Operation, Value], + *, + dimension: Union[int, IntegerAttr], + target_size: Union[int, IntegerAttr], + divisor: Optional[Union[int, IntegerAttr]] = None, + loc=None, + ip=None): + super().__init__( + pdl.OperationType.get(), + pdl.OperationType.get(), + pdl.OperationType.get(), + _get_op_result_or_value(target), + dimension=_get_int64_attr(dimension), + target_size=_get_int64_attr(target_size), + divisor=_get_int64_attr(divisor if divisor else 1), + loc=loc, + ip=ip) + + class PadOp: """Specialization for PadOp class.""" From 8e39ee649614ba857a6aaf938c0646fc53d0765b Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Thu, 14 Jul 2022 13:31:47 -0700 Subject: [PATCH 317/915] [mlir] (NFC) run clang-format on all files --- mlir/include/mlir/Bindings/Python/PybindAdaptors.h | 8 ++++---- 1 file changed, 
4 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 582855a33..588721c3f 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -252,7 +252,7 @@ class pure_subclass { } template - pure_subclass &def(const char *name, Func &&f, const Extra &... extra) { + pure_subclass &def(const char *name, Func &&f, const Extra &...extra) { py::cpp_function cf( std::forward(f), py::name(name), py::is_method(thisClass), py::sibling(py::getattr(thisClass, name, py::none())), extra...); @@ -262,7 +262,7 @@ class pure_subclass { template pure_subclass &def_property_readonly(const char *name, Func &&f, - const Extra &... extra) { + const Extra &...extra) { py::cpp_function cf( std::forward(f), py::name(name), py::is_method(thisClass), py::sibling(py::getattr(thisClass, name, py::none())), extra...); @@ -274,7 +274,7 @@ class pure_subclass { template pure_subclass &def_staticmethod(const char *name, Func &&f, - const Extra &... extra) { + const Extra &...extra) { static_assert(!std::is_member_function_pointer::value, "def_staticmethod(...) called with a non-static member " "function pointer"); @@ -287,7 +287,7 @@ class pure_subclass { template pure_subclass &def_classmethod(const char *name, Func &&f, - const Extra &... extra) { + const Extra &...extra) { static_assert(!std::is_member_function_pointer::value, "def_classmethod(...) called with a non-static member " "function pointer"); From b063ff053204a591590a2278c64d0c3eaa5feb91 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sat, 16 Jul 2022 16:09:03 -0700 Subject: [PATCH 318/915] [mlir] Overhaul C/Python registration APIs to properly scope registration/loading activities. Since the very first commits, the Python and C MLIR APIs have had mis-placed registration/load functionality for dialects, extensions, etc. 
This was done pragmatically in order to get bootstrapped and then just grew in. Downstreams largely bypass and do their own thing by providing various APIs to register things they need. Meanwhile, the C++ APIs have stabilized around this and it would make sense to follow suit. The thing we have observed in canonical usage by downstreams is that each downstream tends to have native entry points that configure its installation to its preferences with one-stop APIs. This patch leans in to this approach with `RegisterEverything.h` and `mlir._mlir_libs._mlirRegisterEverything` being the one-stop entry points for the "upstream packages". The `_mlir_libs.__init__.py` now allows customization of the environment and Context by adding "initialization modules" to the `_mlir_libs` package. If present, `_mlirRegisterEverything` is treated as such a module. Others can be added by downstreams by adding a `_site_initialize_{i}.py` module, where '{i}' is a number starting with zero. The number will be incremented and corresponding module loaded until one is not found. Initialization modules can: * Perform load time customization to the global environment (i.e. registering passes, hooks, etc). * Define a `register_dialects(registry: DialectRegistry)` function that can extend the `DialectRegistry` that will be used to bootstrap the `Context`. * Define a `context_init_hook(context: Context)` function that will be added to a list of callbacks which will be invoked after dialect registration during `Context` initialization. Note that the `MLIRPythonExtension.RegisterEverything` is not included by default when building a downstream (its corresponding behavior was prior). For downstreams which need the default MLIR initialization to take place, they must add this back in to their Python CMake build just like they add their own components (i.e. to `add_mlir_python_common_capi_library` and `add_mlir_python_modules`). 
It is perfectly valid to not do this, in which case, only the things explicitly depended on and initialized by downstreams will be built/packaged. If the downstream has not been set up for this, it is recommended to simply add this back for the time being and pay the build time/package size cost. CMake changes: * `MLIRCAPIRegistration` -> `MLIRCAPIRegisterEverything` (renamed to signify what it does and force an evaluation: a number of places were incidentally linking this very expensive target) * `MLIRPythonSoure.Passes` removed (without replacement: just drop) * `MLIRPythonExtension.AllPassesRegistration` removed (without replacement: just drop) * `MLIRPythonExtension.Conversions` removed (without replacement: just drop) * `MLIRPythonExtension.Transforms` removed (without replacement: just drop) Header changes: * `mlir-c/Registration.h` is deleted. Dialect registration functionality is now in `IR.h`. Registration of upstream features are in `mlir-c/RegisterEverything.h`. When updating MLIR and a couple of downstreams, I found that proper usage was commingled so required making a choice vs just blind S&R. Python APIs removed: * mlir.transforms and mlir.conversions (previously only had an __init__.py which indirectly triggered `mlirRegisterTransformsPasses()` and `mlirRegisterConversionPasses()` respectively). Downstream impact: Remove these imports if present (they now happen as part of default initialization). * mlir._mlir_libs._all_passes_registration, mlir._mlir_libs._mlirTransforms, mlir._mlir_libs._mlirConversions. Downstream impact: None expected (these were internally used). C-APIs changed: * mlirRegisterAllDialects(MlirContext) now takes an MlirDialectRegistry instead. It also used to trigger loading of all dialects, which was already marked with a TODO to remove -- it no longer does, and for direct use, dialects must be explicitly loaded. 
Downstream impact: Direct C-API users must ensure that needed dialects are loaded or call `mlirContextLoadAllAvailableDialects(MlirContext)` to emulate the prior behavior. Also see the `ir.c` test case (e.g. ` mlirContextGetOrLoadDialect(ctx, mlirStringRefCreateFromCString("func"));`). * mlirDialectHandle* APIs were moved from Registration.h (which now is restricted to just global/upstream registration) to IR.h, arguably where it should have been. Downstream impact: include correct header (likely already doing so). C-APIs added: * mlirContextLoadAllAvailableDialects(MlirContext): Corresponds to C++ API with the same purpose. Python APIs added: * mlir.ir.DialectRegistry: Mapping for an MlirDialectRegistry. * mlir.ir.Context.append_dialect_registry(MlirDialectRegistry) * mlir.ir.Context.load_all_available_dialects() * mlir._mlir_libs._mlirRegisterEverything: New native extension that exposes a `register_dialects(MlirDialectRegistry)` entry point and performs all upstream pass/conversion/transforms registration on init. In this first step, we eagerly load this as part of the __init__.py and use it to monkey patch the Context to emulate prior behavior. * Type caster and capsule support for MlirDialectRegistry This should make it possible to build downstream Python dialects that only depend on a subset of MLIR. See: https://github.com/llvm/llvm-project/issues/56037 Here is an example PR, minimally adapting IREE to these changes: https://github.com/iree-org/iree/pull/9638/files In this situation, IREE is opting to not link everything, since it is already configuring the Context to its liking. For projects that would just like to not think about it and pull in everything, add `MLIRPythonExtension.RegisterEverything` to the list of Python sources getting built, and the old behavior will continue. 
Reviewed By: mehdi_amini, ftynse Differential Revision: https://reviews.llvm.org/D128593 --- mlir/include/mlir-c/Bindings/Python/Interop.h | 24 ++++++ mlir/include/mlir-c/Dialect/Async.h | 2 +- mlir/include/mlir-c/Dialect/ControlFlow.h | 2 +- mlir/include/mlir-c/Dialect/Func.h | 2 +- mlir/include/mlir-c/Dialect/GPU.h | 2 +- mlir/include/mlir-c/Dialect/LLVM.h | 1 - mlir/include/mlir-c/Dialect/Linalg.h | 2 +- mlir/include/mlir-c/Dialect/PDL.h | 1 - mlir/include/mlir-c/Dialect/Quant.h | 1 - mlir/include/mlir-c/Dialect/SCF.h | 2 +- mlir/include/mlir-c/Dialect/Shape.h | 2 +- mlir/include/mlir-c/Dialect/SparseTensor.h | 2 +- mlir/include/mlir-c/Dialect/Tensor.h | 2 +- mlir/include/mlir-c/IR.h | 46 +++++++++++ mlir/include/mlir-c/Pass.h | 1 - mlir/include/mlir-c/RegisterEverything.h | 38 +++++++++ mlir/include/mlir-c/Registration.h | 75 ----------------- .../mlir/Bindings/Python/PybindAdaptors.h | 19 +++++ mlir/include/mlir/CAPI/Registration.h | 1 - .../Bindings/Python/AllPassesRegistration.cpp | 22 ----- .../Python/Conversions/Conversions.cpp | 22 ----- mlir/lib/Bindings/Python/IRCore.cpp | 43 ++++++++-- mlir/lib/Bindings/Python/IRModule.h | 26 ++++++ .../Bindings/Python/RegisterEverything.cpp | 26 ++++++ .../Bindings/Python/Transforms/Transforms.cpp | 22 ----- mlir/lib/CAPI/CMakeLists.txt | 2 +- mlir/lib/CAPI/Conversion/CMakeLists.txt | 3 + mlir/lib/CAPI/IR/IR.cpp | 4 + .../CMakeLists.txt | 10 ++- .../RegisterEverything.cpp} | 10 +-- mlir/python/CMakeLists.txt | 72 ++++++----------- mlir/python/mlir/_mlir_libs/__init__.py | 80 +++++++++++++++++-- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 5 ++ .../mlir/all_passes_registration/__init__.py | 5 -- mlir/python/mlir/conversions/__init__.py | 7 -- mlir/python/mlir/transforms/__init__.py | 7 -- 36 files changed, 346 insertions(+), 245 deletions(-) create mode 100644 mlir/include/mlir-c/RegisterEverything.h delete mode 100644 mlir/include/mlir-c/Registration.h delete mode 100644 
mlir/lib/Bindings/Python/AllPassesRegistration.cpp delete mode 100644 mlir/lib/Bindings/Python/Conversions/Conversions.cpp create mode 100644 mlir/lib/Bindings/Python/RegisterEverything.cpp delete mode 100644 mlir/lib/Bindings/Python/Transforms/Transforms.cpp rename mlir/lib/CAPI/{Registration => RegisterEverything}/CMakeLists.txt (76%) rename mlir/lib/CAPI/{Registration/Registration.cpp => RegisterEverything/RegisterEverything.cpp} (69%) delete mode 100644 mlir/python/mlir/all_passes_registration/__init__.py delete mode 100644 mlir/python/mlir/conversions/__init__.py delete mode 100644 mlir/python/mlir/transforms/__init__.py diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h index 86b57cd73..b877f94aa 100644 --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -64,6 +64,8 @@ MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute._CAPIPtr") #define MLIR_PYTHON_CAPSULE_CONTEXT \ MAKE_MLIR_PYTHON_QUALNAME("ir.Context._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_DIALECT_REGISTRY \ + MAKE_MLIR_PYTHON_QUALNAME("ir.DialectRegistry._CAPIPtr") #define MLIR_PYTHON_CAPSULE_EXECUTION_ENGINE \ MAKE_MLIR_PYTHON_QUALNAME("execution_engine.ExecutionEngine._CAPIPtr") #define MLIR_PYTHON_CAPSULE_INTEGER_SET \ @@ -172,6 +174,28 @@ static inline MlirContext mlirPythonCapsuleToContext(PyObject *capsule) { return context; } +/** Creates a capsule object encapsulating the raw C-API MlirDialectRegistry. + * The returned capsule does not extend or affect ownership of any Python + * objects that reference the context in any way. + */ +static inline PyObject * +mlirPythonDialectRegistryToCapsule(MlirDialectRegistry registry) { + return PyCapsule_New(registry.ptr, MLIR_PYTHON_CAPSULE_DIALECT_REGISTRY, + NULL); +} + +/** Extracts an MlirDialectRegistry from a capsule as produced from + * mlirPythonDialectRegistryToCapsule. 
If the capsule is not of the right type, + * then a null registry is returned (as checked via mlirDialectRegistryIsNull). In such a + * case, the Python APIs will have already set an error. */ +static inline MlirDialectRegistry +mlirPythonCapsuleToDialectRegistry(PyObject *capsule) { + void *ptr = + PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_DIALECT_REGISTRY); + MlirDialectRegistry registry = {ptr}; + return registry; +} + /** Creates a capsule object encapsulating the raw C-API MlirLocation. * The returned capsule does not extend or affect ownership of any Python * objects that reference the location in any way. */ diff --git a/mlir/include/mlir-c/Dialect/Async.h b/mlir/include/mlir-c/Dialect/Async.h index 50b6413ef..e4e32f86a 100644 --- a/mlir/include/mlir-c/Dialect/Async.h +++ b/mlir/include/mlir-c/Dialect/Async.h @@ -10,7 +10,7 @@ #ifndef MLIR_C_DIALECT_ASYNC_H #define MLIR_C_DIALECT_ASYNC_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #include "mlir-c/Support.h" #ifdef __cplusplus diff --git a/mlir/include/mlir-c/Dialect/ControlFlow.h b/mlir/include/mlir-c/Dialect/ControlFlow.h index 1ca7054d6..6d5ff8c3d 100644 --- a/mlir/include/mlir-c/Dialect/ControlFlow.h +++ b/mlir/include/mlir-c/Dialect/ControlFlow.h @@ -10,7 +10,7 @@ #ifndef MLIR_C_DIALECT_CONTROLFLOW_H #define MLIR_C_DIALECT_CONTROLFLOW_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/include/mlir-c/Dialect/Func.h b/mlir/include/mlir-c/Dialect/Func.h index 4bdac4268..eeb6dfe05 100644 --- a/mlir/include/mlir-c/Dialect/Func.h +++ b/mlir/include/mlir-c/Dialect/Func.h @@ -18,7 +18,7 @@ #ifndef MLIR_C_DIALECT_FUNC_H #define MLIR_C_DIALECT_FUNC_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/include/mlir-c/Dialect/GPU.h b/mlir/include/mlir-c/Dialect/GPU.h index cf4e899c5..1a18d82c0 100644 --- a/mlir/include/mlir-c/Dialect/GPU.h +++ b/mlir/include/mlir-c/Dialect/GPU.h @@ -10,7 
+10,7 @@ #ifndef MLIR_C_DIALECT_GPU_H #define MLIR_C_DIALECT_GPU_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #include "mlir-c/Support.h" #ifdef __cplusplus diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index 2cf73a359..ba98c33fd 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -11,7 +11,6 @@ #define MLIR_C_DIALECT_LLVM_H #include "mlir-c/IR.h" -#include "mlir-c/Registration.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h index 2fe1872be..0ab201e15 100644 --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -10,7 +10,7 @@ #ifndef MLIR_C_DIALECT_LINALG_H #define MLIR_C_DIALECT_LINALG_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #include "mlir-c/Support.h" #ifdef __cplusplus diff --git a/mlir/include/mlir-c/Dialect/PDL.h b/mlir/include/mlir-c/Dialect/PDL.h index 8bd7976e2..6ad2e2da6 100644 --- a/mlir/include/mlir-c/Dialect/PDL.h +++ b/mlir/include/mlir-c/Dialect/PDL.h @@ -11,7 +11,6 @@ #define MLIR_C_DIALECT_PDL_H #include "mlir-c/IR.h" -#include "mlir-c/Registration.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/include/mlir-c/Dialect/Quant.h b/mlir/include/mlir-c/Dialect/Quant.h index eb529b845..39a17318c 100644 --- a/mlir/include/mlir-c/Dialect/Quant.h +++ b/mlir/include/mlir-c/Dialect/Quant.h @@ -11,7 +11,6 @@ #define MLIR_C_DIALECT_QUANT_H #include "mlir-c/IR.h" -#include "mlir-c/Registration.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/include/mlir-c/Dialect/SCF.h b/mlir/include/mlir-c/Dialect/SCF.h index c1b256779..75f1b6839 100644 --- a/mlir/include/mlir-c/Dialect/SCF.h +++ b/mlir/include/mlir-c/Dialect/SCF.h @@ -10,7 +10,7 @@ #ifndef MLIR_C_DIALECT_SCF_H #define MLIR_C_DIALECT_SCF_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/include/mlir-c/Dialect/Shape.h 
b/mlir/include/mlir-c/Dialect/Shape.h index f64da8016..3fe3ddf5c 100644 --- a/mlir/include/mlir-c/Dialect/Shape.h +++ b/mlir/include/mlir-c/Dialect/Shape.h @@ -10,7 +10,7 @@ #ifndef MLIR_C_DIALECT_SHAPE_H #define MLIR_C_DIALECT_SHAPE_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index 68d72b917..252ec6864 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -11,7 +11,7 @@ #define MLIR_C_DIALECT_SPARSETENSOR_H #include "mlir-c/AffineMap.h" -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/include/mlir-c/Dialect/Tensor.h b/mlir/include/mlir-c/Dialect/Tensor.h index f74978248..74cbc5a6f 100644 --- a/mlir/include/mlir-c/Dialect/Tensor.h +++ b/mlir/include/mlir-c/Dialect/Tensor.h @@ -10,7 +10,7 @@ #ifndef MLIR_C_DIALECT_TENSOR_H #define MLIR_C_DIALECT_TENSOR_H -#include "mlir-c/Registration.h" +#include "mlir-c/IR.h" #ifdef __cplusplus extern "C" { diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index edd0e44ec..2d38700c2 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -130,6 +130,11 @@ MLIR_CAPI_EXPORTED MlirDialect mlirContextGetOrLoadDialect(MlirContext context, MLIR_CAPI_EXPORTED void mlirContextEnableMultithreading(MlirContext context, bool enable); +/// Eagerly loads all available dialects registered with a context, making +/// them available for use for IR construction. +MLIR_CAPI_EXPORTED void +mlirContextLoadAllAvailableDialects(MlirContext context); + /// Returns whether the given fully-qualified operation (i.e. /// 'dialect.operation') is registered with the context. 
This will return true /// if the dialect is loaded and the operation is registered within the @@ -157,6 +162,47 @@ MLIR_CAPI_EXPORTED bool mlirDialectEqual(MlirDialect dialect1, /// Returns the namespace of the given dialect. MLIR_CAPI_EXPORTED MlirStringRef mlirDialectGetNamespace(MlirDialect dialect); +//===----------------------------------------------------------------------===// +// DialectHandle API. +// Registration entry-points for each dialect are declared using the common +// MLIR_DECLARE_CAPI_DIALECT_REGISTRATION macro, which takes the dialect +// API name (i.e. "Func", "Tensor", "Linalg") and namespace (i.e. "func", +// "tensor", "linalg"). The following declarations are produced: +// +// /// Gets the above hook methods in struct form for a dialect by namespace. +// /// This is intended to facilitate dynamic lookup and registration of +// /// dialects via a plugin facility based on shared library symbol lookup. +// const MlirDialectHandle *mlirGetDialectHandle__{NAMESPACE}__(); +// +// This is done via a common macro to facilitate future expansion to +// registration schemes. +//===----------------------------------------------------------------------===// + +struct MlirDialectHandle { + const void *ptr; +}; +typedef struct MlirDialectHandle MlirDialectHandle; + +#define MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Name, Namespace) \ + MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##Namespace##__() + +/// Returns the namespace associated with the provided dialect handle. +MLIR_CAPI_EXPORTED +MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle); + +/// Inserts the dialect associated with the provided dialect handle into the +/// provided dialect registry +MLIR_CAPI_EXPORTED void mlirDialectHandleInsertDialect(MlirDialectHandle, + MlirDialectRegistry); + +/// Registers the dialect associated with the provided dialect handle. 
+MLIR_CAPI_EXPORTED void mlirDialectHandleRegisterDialect(MlirDialectHandle, + MlirContext); + +/// Loads the dialect associated with the provided dialect handle. +MLIR_CAPI_EXPORTED MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle, + MlirContext); + //===----------------------------------------------------------------------===// // DialectRegistry API. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index cb6112954..b66bdfe02 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -15,7 +15,6 @@ #define MLIR_C_PASS_H #include "mlir-c/IR.h" -#include "mlir-c/Registration.h" #include "mlir-c/Support.h" #ifdef __cplusplus diff --git a/mlir/include/mlir-c/RegisterEverything.h b/mlir/include/mlir-c/RegisterEverything.h new file mode 100644 index 000000000..b98ce154d --- /dev/null +++ b/mlir/include/mlir-c/RegisterEverything.h @@ -0,0 +1,38 @@ +//===-- mlir-c/RegisterEverything.h - Register all MLIR entities --*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This header contains registration entry points for MLIR upstream dialects +// and passes. Downstream projects typically will not want to use this unless +// if they don't care about binary size or build bloat and just wish access +// to the entire set of upstream facilities. For those that do care, they +// should use registration functions specific to their project. 
+//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_REGISTER_EVERYTHING_H +#define MLIR_C_REGISTER_EVERYTHING_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// Appends all upstream dialects and extensions to the dialect registry. +MLIR_CAPI_EXPORTED void mlirRegisterAllDialects(MlirDialectRegistry registry); + +/// Register all translations to LLVM IR for dialects that can support it. +MLIR_CAPI_EXPORTED void mlirRegisterAllLLVMTranslations(MlirContext context); + +/// Register all compiler passes of MLIR. +MLIR_CAPI_EXPORTED void mlirRegisterAllPasses(); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_REGISTER_EVERYTHING_H diff --git a/mlir/include/mlir-c/Registration.h b/mlir/include/mlir-c/Registration.h deleted file mode 100644 index ab37866d9..000000000 --- a/mlir/include/mlir-c/Registration.h +++ /dev/null @@ -1,75 +0,0 @@ -//===-- mlir-c/Registration.h - Registration functions for MLIR ---*- C -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM -// Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_C_REGISTRATION_H -#define MLIR_C_REGISTRATION_H - -#include "mlir-c/IR.h" - -#ifdef __cplusplus -extern "C" { -#endif - -//===----------------------------------------------------------------------===// -// Dialect registration declarations. -// Registration entry-points for each dialect are declared using the common -// MLIR_DECLARE_DIALECT_REGISTRATION_CAPI macro, which takes the dialect -// API name (i.e. "Func", "Tensor", "Linalg") and namespace (i.e. "func", -// "tensor", "linalg"). The following declarations are produced: -// -// /// Gets the above hook methods in struct form for a dialect by namespace. 
-// /// This is intended to facilitate dynamic lookup and registration of -// /// dialects via a plugin facility based on shared library symbol lookup. -// const MlirDialectHandle *mlirGetDialectHandle__{NAMESPACE}__(); -// -// This is done via a common macro to facilitate future expansion to -// registration schemes. -//===----------------------------------------------------------------------===// - -struct MlirDialectHandle { - const void *ptr; -}; -typedef struct MlirDialectHandle MlirDialectHandle; - -#define MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Name, Namespace) \ - MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##Namespace##__() - -/// Returns the namespace associated with the provided dialect handle. -MLIR_CAPI_EXPORTED -MlirStringRef mlirDialectHandleGetNamespace(MlirDialectHandle); - -/// Inserts the dialect associated with the provided dialect handle into the -/// provided dialect registry -MLIR_CAPI_EXPORTED void mlirDialectHandleInsertDialect(MlirDialectHandle, - MlirDialectRegistry); - -/// Registers the dialect associated with the provided dialect handle. -MLIR_CAPI_EXPORTED void mlirDialectHandleRegisterDialect(MlirDialectHandle, - MlirContext); - -/// Loads the dialect associated with the provided dialect handle. -MLIR_CAPI_EXPORTED MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle, - MlirContext); - -/// Registers all dialects known to core MLIR with the provided Context. -/// This is needed before creating IR for these Dialects. -/// TODO: Remove this function once the real registration API is finished. -MLIR_CAPI_EXPORTED void mlirRegisterAllDialects(MlirContext context); - -/// Register all translations to LLVM IR for dialects that can support it. -MLIR_CAPI_EXPORTED void mlirRegisterAllLLVMTranslations(MlirContext context); - -/// Register all compiler passes of MLIR. 
-MLIR_CAPI_EXPORTED void mlirRegisterAllPasses(); - -#ifdef __cplusplus -} -#endif - -#endif // MLIR_C_REGISTRATION_H diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 588721c3f..351fb964e 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -124,6 +124,25 @@ struct type_caster { } }; +/// Casts object <-> MlirDialectRegistry. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirDialectRegistry, _("MlirDialectRegistry")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToDialectRegistry(capsule.ptr()); + return !mlirDialectRegistryIsNull(value); + } + static handle cast(MlirDialectRegistry v, return_value_policy, handle) { + py::object capsule = py::reinterpret_steal( + mlirPythonDialectRegistryToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("DialectRegistry") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + /// Casts object <-> MlirLocation. template <> struct type_caster { diff --git a/mlir/include/mlir/CAPI/Registration.h b/mlir/include/mlir/CAPI/Registration.h index e57023d30..355c4bcfe 100644 --- a/mlir/include/mlir/CAPI/Registration.h +++ b/mlir/include/mlir/CAPI/Registration.h @@ -10,7 +10,6 @@ #define MLIR_CAPI_REGISTRATION_H #include "mlir-c/IR.h" -#include "mlir-c/Registration.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" diff --git a/mlir/lib/Bindings/Python/AllPassesRegistration.cpp b/mlir/lib/Bindings/Python/AllPassesRegistration.cpp deleted file mode 100644 index f595b20ba..000000000 --- a/mlir/lib/Bindings/Python/AllPassesRegistration.cpp +++ /dev/null @@ -1,22 +0,0 @@ -//===- AllPassesRegistration.cpp - Pybind module to register all passes ---===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
-// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir-c/Registration.h" - -#include - -// ----------------------------------------------------------------------------- -// Module initialization. -// ----------------------------------------------------------------------------- - -PYBIND11_MODULE(_mlirAllPassesRegistration, m) { - m.doc() = "MLIR All Passes Convenience Module"; - - // Register all passes on load. - mlirRegisterAllPasses(); -} diff --git a/mlir/lib/Bindings/Python/Conversions/Conversions.cpp b/mlir/lib/Bindings/Python/Conversions/Conversions.cpp deleted file mode 100644 index c9d380178..000000000 --- a/mlir/lib/Bindings/Python/Conversions/Conversions.cpp +++ /dev/null @@ -1,22 +0,0 @@ -//===- Conversions.cpp - Pybind module for the Conversionss library -------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir-c/Conversion.h" - -#include - -// ----------------------------------------------------------------------------- -// Module initialization. -// ----------------------------------------------------------------------------- - -PYBIND11_MODULE(_mlirConversions, m) { - m.doc() = "MLIR Conversions library"; - - // Register all the passes in the Conversions library on load. 
- mlirRegisterConversionPasses(); -} diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 98566ddea..973835182 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -16,7 +16,6 @@ #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Debug.h" #include "mlir-c/IR.h" -#include "mlir-c/Registration.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" @@ -474,7 +474,6 @@ py::object PyMlirContext::createFromCapsule(py::object capsule) { PyMlirContext *PyMlirContext::createNewContextForInit() { MlirContext context = mlirContextCreate(); - mlirRegisterAllDialects(context); return new PyMlirContext(context); } @@ -793,7 +792,7 @@ py::tuple PyDiagnostic::getNotes() { } //------------------------------------------------------------------------------ -// PyDialect, PyDialectDescriptor, PyDialects +// PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry //------------------------------------------------------------------------------ MlirDialect PyDialects::getDialectForKey(const std::string &key, @@ -807,6 +806,19 @@ MlirDialect PyDialects::getDialectForKey(const std::string &key, return dialect; } +py::object PyDialectRegistry::getCapsule() { + return py::reinterpret_steal( + mlirPythonDialectRegistryToCapsule(*this)); +} + +PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) { + MlirDialectRegistry rawRegistry = + mlirPythonCapsuleToDialectRegistry(capsule.ptr()); + if (mlirDialectRegistryIsNull(rawRegistry)) + throw py::error_already_set(); + return PyDialectRegistry(rawRegistry); +} + //------------------------------------------------------------------------------ // PyLocation //------------------------------------------------------------------------------ @@ -2207,8 +2219,11 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // 
Mapping of MlirContext. + // Note that this is exported as _BaseContext. The containing, Python level + // __init__.py will subclass it with site-specific functionality and set a + // "Context" attribute on this module. //---------------------------------------------------------------------------- - py::class_(m, "Context", py::module_local()) + py::class_(m, "_BaseContext", py::module_local()) .def(py::init<>(&PyMlirContext::createNewContextForInit)) .def_static("_get_live_count", &PyMlirContext::getLiveCount) .def("_get_context_again", @@ -2276,7 +2291,16 @@ void mlir::python::populateIRCore(py::module &m) { return mlirContextIsRegisteredOperation( self.get(), MlirStringRef{name.data(), name.size()}); }, - py::arg("operation_name")); + py::arg("operation_name")) + .def( + "append_dialect_registry", + [](PyMlirContext &self, PyDialectRegistry ®istry) { + mlirContextAppendDialectRegistry(self.get(), registry); + }, + py::arg("registry")) + .def("load_all_available_dialects", [](PyMlirContext &self) { + mlirContextLoadAllAvailableDialects(self.get()); + }); //---------------------------------------------------------------------------- // Mapping of PyDialectDescriptor @@ -2331,6 +2355,15 @@ void mlir::python::populateIRCore(py::module &m) { clazz.attr("__name__") + py::str(")>"); }); + //---------------------------------------------------------------------------- + // Mapping of PyDialectRegistry + //---------------------------------------------------------------------------- + py::class_(m, "DialectRegistry", py::module_local()) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyDialectRegistry::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule) + .def(py::init<>()); + //---------------------------------------------------------------------------- // Mapping of Location //---------------------------------------------------------------------------- diff --git a/mlir/lib/Bindings/Python/IRModule.h 
b/mlir/lib/Bindings/Python/IRModule.h index 371157a56..2e2ebaa27 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -390,6 +390,32 @@ class PyDialect { pybind11::object descriptor; }; +/// Wrapper around an MlirDialectRegistry. +/// Upon construction, the Python wrapper takes ownership of the +/// underlying MlirDialectRegistry. +class PyDialectRegistry { +public: + PyDialectRegistry() : registry(mlirDialectRegistryCreate()) {} + PyDialectRegistry(MlirDialectRegistry registry) : registry(registry) {} + ~PyDialectRegistry() { + if (!mlirDialectRegistryIsNull(registry)) + mlirDialectRegistryDestroy(registry); + } + PyDialectRegistry(PyDialectRegistry &) = delete; + PyDialectRegistry(PyDialectRegistry &&other) : registry(other.registry) { + other.registry = {nullptr}; + } + + operator MlirDialectRegistry() const { return registry; } + MlirDialectRegistry get() const { return registry; } + + pybind11::object getCapsule(); + static PyDialectRegistry createFromCapsule(pybind11::object capsule); + +private: + MlirDialectRegistry registry; +}; + /// Wrapper around an MlirLocation. class PyLocation : public BaseContextObject { public: diff --git a/mlir/lib/Bindings/Python/RegisterEverything.cpp b/mlir/lib/Bindings/Python/RegisterEverything.cpp new file mode 100644 index 000000000..fed5c36a6 --- /dev/null +++ b/mlir/lib/Bindings/Python/RegisterEverything.cpp @@ -0,0 +1,26 @@ +//===- RegisterEverything.cpp - API to register all dialects/passes -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/RegisterEverything.h" +#include "mlir-c/Conversion.h" +#include "mlir-c/Transforms.h" + +#include "mlir/Bindings/Python/PybindAdaptors.h" + +PYBIND11_MODULE(_mlirRegisterEverything, m) { + m.doc() = "MLIR All Upstream Dialects and Passes Registration"; + + m.def("register_dialects", [](MlirDialectRegistry registry) { + mlirRegisterAllDialects(registry); + }); + + // Register all passes on load. + mlirRegisterAllPasses(); + mlirRegisterConversionPasses(); + mlirRegisterTransformsPasses(); +} diff --git a/mlir/lib/Bindings/Python/Transforms/Transforms.cpp b/mlir/lib/Bindings/Python/Transforms/Transforms.cpp deleted file mode 100644 index 944b191bc..000000000 --- a/mlir/lib/Bindings/Python/Transforms/Transforms.cpp +++ /dev/null @@ -1,22 +0,0 @@ -//===- Transforms.cpp - Pybind module for the Transforms library ----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir-c/Transforms.h" - -#include - -// ----------------------------------------------------------------------------- -// Module initialization. -// ----------------------------------------------------------------------------- - -PYBIND11_MODULE(_mlirTransforms, m) { - m.doc() = "MLIR Transforms library"; - - // Register all the passes in the Transforms library on load. 
- mlirRegisterTransformsPasses(); -} diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt index 393b49ecb..ffb04c287 100644 --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -12,7 +12,7 @@ add_subdirectory(Dialect) add_subdirectory(Conversion) add_subdirectory(Interfaces) add_subdirectory(IR) -add_subdirectory(Registration) +add_subdirectory(RegisterEverything) add_subdirectory(Transforms) # Only enable the ExecutionEngine if the native target is configured in. diff --git a/mlir/lib/CAPI/Conversion/CMakeLists.txt b/mlir/lib/CAPI/Conversion/CMakeLists.txt index 166e79916..8cafc09d3 100644 --- a/mlir/lib/CAPI/Conversion/CMakeLists.txt +++ b/mlir/lib/CAPI/Conversion/CMakeLists.txt @@ -2,6 +2,9 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) add_mlir_upstream_c_api_library(MLIRCAPIConversion Passes.cpp + DEPENDS + MLIRConversionPassIncGen + LINK_LIBS PUBLIC ${conversion_libs} ) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 6959ad322..c931ea774 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -77,6 +77,10 @@ void mlirContextEnableMultithreading(MlirContext context, bool enable) { return unwrap(context)->enableMultithreading(enable); } +void mlirContextLoadAllAvailableDialects(MlirContext context) { + unwrap(context)->loadAllAvailableDialects(); +} + //===----------------------------------------------------------------------===// // Dialect API. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/Registration/CMakeLists.txt b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt similarity index 76% rename from mlir/lib/CAPI/Registration/CMakeLists.txt rename to mlir/lib/CAPI/RegisterEverything/CMakeLists.txt index 67e26a50f..942bba84e 100644 --- a/mlir/lib/CAPI/Registration/CMakeLists.txt +++ b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt @@ -2,13 +2,15 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -add_mlir_upstream_c_api_library(MLIRCAPIRegistration - Registration.cpp +add_mlir_upstream_c_api_library(MLIRCAPIRegisterEverything + RegisterEverything.cpp LINK_LIBS PUBLIC - MLIRCAPIIR - MLIRLLVMToLLVMIRTranslation ${dialect_libs} ${translation_libs} ${conversion_libs} + + MLIRCAPIIR + MLIRLLVMToLLVMIRTranslation + MLIRCAPITransforms ) diff --git a/mlir/lib/CAPI/Registration/Registration.cpp b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp similarity index 69% rename from mlir/lib/CAPI/Registration/Registration.cpp rename to mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp index 4ac300d1f..25a1a216c 100644 --- a/mlir/lib/CAPI/Registration/Registration.cpp +++ b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp @@ -1,4 +1,4 @@ -//===- Registration.cpp - C Interface for MLIR Registration ---------------===// +//===- RegisterEverything.cpp - Register all MLIR entities ----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,17 +6,15 @@ // //===----------------------------------------------------------------------===// -#include "mlir-c/Registration.h" +#include "mlir-c/RegisterEverything.h" #include "mlir/CAPI/IR.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" -void mlirRegisterAllDialects(MlirContext context) { - mlir::registerAllDialects(*unwrap(context)); - // TODO: we may not want to eagerly load here. - unwrap(context)->loadAllAvailableDialects(); +void mlirRegisterAllDialects(MlirDialectRegistry registry) { + mlir::registerAllDialects(*unwrap(registry)); } void mlirRegisterAllLLVMTranslations(MlirContext context) { diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index dc831c9cb..1fbbadf26 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -37,17 +37,8 @@ declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine runtime/*.py ) -declare_mlir_python_sources(MLIRPythonSources.Passes - ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" - ADD_TO_PARENT MLIRPythonSources - SOURCES_GLOB - all_passes_registration/*.py - conversions/*.py - transforms/*.py -) - declare_mlir_python_sources(MLIRPythonCAPI.HeaderSources - ROOT_DIR "${MLIR_MAIN_INCLUDE_DIR}" + ROOT_DIR "${MLIR_SOURCE_DIR}/include" SOURCES_GLOB "mlir-c/*.h" ) @@ -283,12 +274,31 @@ declare_mlir_python_extension(MLIRPythonExtension.Core MLIRCAPIDebug MLIRCAPIIR MLIRCAPIInterfaces - MLIRCAPIRegistration # TODO: See about dis-aggregating # Dialects MLIRCAPIFunc ) +# This extension exposes an API to register all dialects, extensions, and passes +# packaged in upstream MLIR and it is used for the upstream "mlir" Python +# package. Downstreams will likely want to provide their own and not depend +# on this one, since it links in the world. +# Note that this is not added to any top-level source target for transitive +# inclusion: It must be included explicitly by downstreams if desired. 
Note that +# this has a very large impact on what gets built/packaged. +declare_mlir_python_extension(MLIRPythonExtension.RegisterEverything + MODULE_NAME _mlirRegisterEverything + ROOT_DIR "${PYTHON_SOURCE_DIR}" + SOURCES + RegisterEverything.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIConversion + MLIRCAPITransforms + MLIRCAPIRegisterEverything +) + declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind MODULE_NAME _mlirDialectsLinalg ADD_TO_PARENT MLIRPythonSources.Dialects.linalg @@ -341,18 +351,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind MLIRCAPISparseTensor ) -declare_mlir_python_extension(MLIRPythonExtension.AllPassesRegistration - MODULE_NAME _mlirAllPassesRegistration - ROOT_DIR "${PYTHON_SOURCE_DIR}" - SOURCES - AllPassesRegistration.cpp - PRIVATE_LINK_LIBS - LLVMSupport - EMBED_CAPI_LINK_LIBS - MLIRCAPIConversion - MLIRCAPITransforms -) - declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses MODULE_NAME _mlirAsyncPasses ADD_TO_PARENT MLIRPythonSources.Dialects.async_dialect @@ -365,18 +363,6 @@ declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses MLIRCAPIAsync ) -declare_mlir_python_extension(MLIRPythonExtension.Conversions - MODULE_NAME _mlirConversions - ADD_TO_PARENT MLIRPythonSources.Passes - ROOT_DIR "${PYTHON_SOURCE_DIR}" - SOURCES - Conversions/Conversions.cpp - PRIVATE_LINK_LIBS - LLVMSupport - EMBED_CAPI_LINK_LIBS - MLIRCAPIConversion -) - # Only enable the ExecutionEngine if the native target is configured in. 
if(TARGET ${LLVM_NATIVE_ARCH}) declare_mlir_python_extension(MLIRPythonExtension.ExecutionEngine @@ -428,18 +414,6 @@ declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses MLIRCAPISparseTensor ) -declare_mlir_python_extension(MLIRPythonExtension.Transforms - MODULE_NAME _mlirTransforms - ADD_TO_PARENT MLIRPythonSources.Passes - ROOT_DIR "${PYTHON_SOURCE_DIR}" - SOURCES - Transforms/Transforms.cpp - PRIVATE_LINK_LIBS - LLVMSupport - EMBED_CAPI_LINK_LIBS - MLIRCAPITransforms -) - # TODO: Figure out how to put this in the test tree. # This should not be included in the main Python extension. However, # putting it into MLIRPythonTestSources along with the dialect declaration @@ -506,7 +480,7 @@ add_mlir_python_common_capi_library(MLIRPythonCAPI MLIRPythonCAPI.HeaderSources DECLARED_SOURCES MLIRPythonSources - MLIRPythonExtension.AllPassesRegistration + MLIRPythonExtension.RegisterEverything ${_ADDL_TEST_SOURCES} ) @@ -520,7 +494,7 @@ add_mlir_python_modules(MLIRPythonModules INSTALL_PREFIX "python_packages/mlir_core/mlir" DECLARED_SOURCES MLIRPythonSources - MLIRPythonExtension.AllPassesRegistration + MLIRPythonExtension.RegisterEverything ${_ADDL_TEST_SOURCES} COMMON_CAPI_LINK_LIBS MLIRPythonCAPI diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index 23bc50267..add8d92ee 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -9,12 +9,6 @@ _this_dir = os.path.dirname(__file__) -# These submodules have no type stubs and are thus opaque to the type checker. -_mlirConversions: Any -_mlirTransforms: Any -_mlirAllPassesRegistration: Any - - def get_lib_dirs() -> Sequence[str]: """Gets the lib directory for linking to shared libraries. @@ -31,3 +25,77 @@ def get_include_dirs() -> Sequence[str]: not be present. """ return [os.path.join(_this_dir, "include")] + + +# Perform Python level site initialization. This involves: +# 1. 
Attempting to load initializer modules, specific to the distribution. +# 2. Defining the concrete mlir.ir.Context that does site specific +# initialization. +# +# Aside from just being far more convenient to do this at the Python level, +# it is actually quite hard/impossible to have such __init__ hooks, given +# the pybind memory model (i.e. there is not a Python reference to the object +# in the scope of the base class __init__). +# +# For #1, we: +# a. Probe for modules named '_mlirRegisterEverything' and +# '_site_initialize_{i}', where 'i' is a number starting at zero and +# proceeding so long as a module with the name is found. +# b. If the module has a 'register_dialects' attribute, it will be called +# immediately with a DialectRegistry to populate. +# c. If the module has a 'context_init_hook', it will be added to a list +# of callbacks that are invoked as the last step of Context +# initialization (and passed the Context under construction). +# +# This facility allows downstreams to customize Context creation to their +# needs. +def _site_initialize(): + import importlib + import itertools + import logging + from ._mlir import ir + registry = ir.DialectRegistry() + post_init_hooks = [] + + def process_initializer_module(module_name): + try: + m = importlib.import_module(f".{module_name}", __name__) + except ModuleNotFoundError: + return False + + logging.debug("Initializing MLIR with module: %s", module_name) + if hasattr(m, "register_dialects"): + logging.debug("Registering dialects from initializer %r", m) + m.register_dialects(registry) + if hasattr(m, "context_init_hook"): + logging.debug("Adding context init hook from %r", m) + post_init_hooks.append(m.context_init_hook) + return True + + + # If _mlirRegisterEverything is built, then include it as an initializer + # module. + process_initializer_module("_mlirRegisterEverything") + + # Load all _site_initialize_{i} modules, where 'i' is a number starting + # at 0. 
+ for i in itertools.count(): + module_name = f"_site_initialize_{i}" + if not process_initializer_module(module_name): + break + + class Context(ir._BaseContext): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.append_dialect_registry(registry) + for hook in post_init_hooks: + hook(self) + # TODO: There is some debate about whether we should eagerly load + # all dialects. It is being done here in order to preserve existing + # behavior. See: https://github.com/llvm/llvm-project/issues/56037 + self.load_all_available_dialects() + + ir.Context = Context + + +_site_initialize() diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index cd7eb0a55..60bc3676f 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -479,6 +479,11 @@ class Context: def d(self) -> Dialects: ... @property def dialects(self) -> Dialects: ... + def append_dialect_registry(self, registry: "DialectRegistry") -> None: ... + def load_all_available_dialects(self) -> None: ... + +class DialectRegistry: + def __init__(self) -> None: ... # TODO: Auto-generated. Audit and fix. class DenseElementsAttr(Attribute): diff --git a/mlir/python/mlir/all_passes_registration/__init__.py b/mlir/python/mlir/all_passes_registration/__init__.py deleted file mode 100644 index aca557ab9..000000000 --- a/mlir/python/mlir/all_passes_registration/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -from .._mlir_libs import _mlirAllPassesRegistration as _cextAllPasses diff --git a/mlir/python/mlir/conversions/__init__.py b/mlir/python/mlir/conversions/__init__.py deleted file mode 100644 index a6a9eb821..000000000 --- a/mlir/python/mlir/conversions/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Expose the corresponding C-Extension module with a well-known name at this -# level. -from .._mlir_libs import _mlirConversions as _cextConversions diff --git a/mlir/python/mlir/transforms/__init__.py b/mlir/python/mlir/transforms/__init__.py deleted file mode 100644 index 71ea17d7f..000000000 --- a/mlir/python/mlir/transforms/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -# Expose the corresponding C-Extension module with a well-known name at this -# level. -from .._mlir_libs import _mlirTransforms as _cextTransforms From 32188551e0a083ddddce5c4cd1ca0069c8eccffc Mon Sep 17 00:00:00 2001 From: Tanyo Kwok Date: Wed, 20 Jul 2022 19:12:41 +0000 Subject: [PATCH 319/915] [MLIR] Add function to create Float16 array attribute This patch adds a new function mlirDenseElementsAttrFloat16Get(), which accepts the shaped type, the number of Float16 values, and a pointer to an array of Float16 values, each of which is a uint16_t value. 
This commit is repeating https://reviews.llvm.org/D123981 + #761 but for Float16 Differential Revision: https://reviews.llvm.org/D130069 --- mlir/include/mlir-c/BuiltinAttributes.h | 2 ++ mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index ce4514094..050408d1f 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -381,6 +381,8 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrDoubleGet( MlirType shapedType, intptr_t numElements, const double *elements); MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrBFloat16Get( MlirType shapedType, intptr_t numElements, const uint16_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrFloat16Get( + MlirType shapedType, intptr_t numElements, const uint16_t *elements); /// Creates a dense elements attribute with the given shaped type from string /// elements. 
diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 759b70895..ba3481ae1 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -479,6 +479,13 @@ MlirAttribute mlirDenseElementsAttrBFloat16Get(MlirType shapedType, const void *buffer = static_cast(elements); return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer); } +MlirAttribute mlirDenseElementsAttrFloat16Get(MlirType shapedType, + intptr_t numElements, + const uint16_t *elements) { + size_t bufferSize = numElements * 2; + const void *buffer = static_cast(elements); + return mlirDenseElementsAttrRawBufferGet(shapedType, bufferSize, buffer); +} MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType, intptr_t numElements, From 9dcc9b83b9b8846abd8107263e9a8e977c6dc503 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 21 Jul 2022 14:00:37 +0000 Subject: [PATCH 320/915] [mlir][python] Fix issues with block argument slices The type extraction helper function for block argument and op result list objects was ignoring the slice entirely. So was the slice addition. Both are caused by a misleading naming convention to implement slices via CRTP. Make the convention more explicit and hide the helper functions so users have harder time calling them directly. Closes #56540. 
Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D130271 --- mlir/lib/Bindings/Python/IRAffine.cpp | 18 ++++--- mlir/lib/Bindings/Python/IRCore.cpp | 75 ++++++++++++++------------ mlir/lib/Bindings/Python/PybindUtils.h | 47 +++++++++++----- 3 files changed, 88 insertions(+), 52 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index 0da936e85..fc7133b43 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -385,9 +385,13 @@ class PyAffineMapExprList step), affineMap(map) {} - intptr_t getNumElements() { return mlirAffineMapGetNumResults(affineMap); } +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + intptr_t getRawNumElements() { return mlirAffineMapGetNumResults(affineMap); } - PyAffineExpr getElement(intptr_t pos) { + PyAffineExpr getRawElement(intptr_t pos) { return PyAffineExpr(affineMap.getContext(), mlirAffineMapGetResult(affineMap, pos)); } @@ -397,7 +401,6 @@ class PyAffineMapExprList return PyAffineMapExprList(affineMap, startIndex, length, step); } -private: PyAffineMap affineMap; }; } // namespace @@ -460,9 +463,13 @@ class PyIntegerSetConstraintList step), set(set) {} - intptr_t getNumElements() { return mlirIntegerSetGetNumConstraints(set); } +private: + /// Give the parent CRTP class access to hook implementations below. 
+ friend class Sliceable; + + intptr_t getRawNumElements() { return mlirIntegerSetGetNumConstraints(set); } - PyIntegerSetConstraint getElement(intptr_t pos) { + PyIntegerSetConstraint getRawElement(intptr_t pos) { return PyIntegerSetConstraint(set, pos); } @@ -471,7 +478,6 @@ class PyIntegerSetConstraintList return PyIntegerSetConstraintList(set, startIndex, length, step); } -private: PyIntegerSet set; }; } // namespace diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 973835182..fea26ec66 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1968,8 +1968,8 @@ template static std::vector getValueTypes(Container &container, PyMlirContextRef &context) { std::vector result; - result.reserve(container.getNumElements()); - for (int i = 0, e = container.getNumElements(); i < e; ++i) { + result.reserve(container.size()); + for (int i = 0, e = container.size(); i < e; ++i) { result.push_back( PyType(context, mlirValueGetType(container.getElement(i).get()))); } @@ -1993,14 +1993,24 @@ class PyBlockArgumentList step), operation(std::move(operation)), block(block) {} + static void bindDerived(ClassTy &c) { + c.def_property_readonly("types", [](PyBlockArgumentList &self) { + return getValueTypes(self, self.operation->getContext()); + }); + } + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + /// Returns the number of arguments in the list. - intptr_t getNumElements() { + intptr_t getRawNumElements() { operation->checkValid(); return mlirBlockGetNumArguments(block); } - /// Returns `pos`-the element in the list. Asserts on out-of-bounds. - PyBlockArgument getElement(intptr_t pos) { + /// Returns `pos`-the element in the list. 
+ PyBlockArgument getRawElement(intptr_t pos) { MlirValue argument = mlirBlockGetArgument(block, pos); return PyBlockArgument(operation, argument); } @@ -2011,13 +2021,6 @@ class PyBlockArgumentList return PyBlockArgumentList(operation, block, startIndex, length, step); } - static void bindDerived(ClassTy &c) { - c.def_property_readonly("types", [](PyBlockArgumentList &self) { - return getValueTypes(self, self.operation->getContext()); - }); - } - -private: PyOperationRef operation; MlirBlock block; }; @@ -2038,12 +2041,25 @@ class PyOpOperandList : public Sliceable { step), operation(operation) {} - intptr_t getNumElements() { + void dunderSetItem(intptr_t index, PyValue value) { + index = wrapIndex(index); + mlirOperationSetOperand(operation->get(), index, value.get()); + } + + static void bindDerived(ClassTy &c) { + c.def("__setitem__", &PyOpOperandList::dunderSetItem); + } + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + intptr_t getRawNumElements() { operation->checkValid(); return mlirOperationGetNumOperands(operation->get()); } - PyValue getElement(intptr_t pos) { + PyValue getRawElement(intptr_t pos) { MlirValue operand = mlirOperationGetOperand(operation->get(), pos); MlirOperation owner; if (mlirValueIsAOpResult(operand)) @@ -2061,16 +2077,6 @@ class PyOpOperandList : public Sliceable { return PyOpOperandList(operation, startIndex, length, step); } - void dunderSetItem(intptr_t index, PyValue value) { - index = wrapIndex(index); - mlirOperationSetOperand(operation->get(), index, value.get()); - } - - static void bindDerived(ClassTy &c) { - c.def("__setitem__", &PyOpOperandList::dunderSetItem); - } - -private: PyOperationRef operation; }; @@ -2090,12 +2096,22 @@ class PyOpResultList : public Sliceable { step), operation(operation) {} - intptr_t getNumElements() { + static void bindDerived(ClassTy &c) { + c.def_property_readonly("types", [](PyOpResultList &self) { + return getValueTypes(self, 
self.operation->getContext()); + }); + } + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + intptr_t getRawNumElements() { operation->checkValid(); return mlirOperationGetNumResults(operation->get()); } - PyOpResult getElement(intptr_t index) { + PyOpResult getRawElement(intptr_t index) { PyValue value(operation, mlirOperationGetResult(operation->get(), index)); return PyOpResult(value); } @@ -2104,13 +2120,6 @@ class PyOpResultList : public Sliceable { return PyOpResultList(operation, startIndex, length, step); } - static void bindDerived(ClassTy &c) { - c.def_property_readonly("types", [](PyOpResultList &self) { - return getValueTypes(self, self.operation->getContext()); - }); - } - -private: PyOperationRef operation; }; diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h index e791ba8e2..5356cbd54 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -199,15 +199,17 @@ struct PySinglePartStringAccumulator { /// A derived class must provide the following: /// - a `static const char *pyClassName ` field containing the name of the /// Python class to bind; -/// - an instance method `intptr_t getNumElements()` that returns the number +/// - an instance method `intptr_t getRawNumElements()` that returns the +/// number /// of elements in the backing container (NOT that of the slice); -/// - an instance method `ElementTy getElement(intptr_t)` that returns a -/// single element at the given index. +/// - an instance method `ElementTy getRawElement(intptr_t)` that returns a +/// single element at the given linear index (NOT slice index); /// - an instance method `Derived slice(intptr_t, intptr_t, intptr_t)` that /// constructs a new instance of the derived pseudo-container with the /// given slice parameters (to be forwarded to the Sliceable constructor). 
/// -/// The getNumElements() and getElement(intptr_t) callbacks must not throw. +/// The getRawNumElements() and getRawElement(intptr_t) callbacks must not +/// throw. /// /// A derived class may additionally define: /// - a `static void bindDerived(ClassTy &)` method to bind additional methods @@ -217,8 +219,8 @@ class Sliceable { protected: using ClassTy = pybind11::class_; - // Transforms `index` into a legal value to access the underlying sequence. - // Returns <0 on failure. + /// Transforms `index` into a legal value to access the underlying sequence. + /// Returns <0 on failure. intptr_t wrapIndex(intptr_t index) { if (index < 0) index = length + index; @@ -227,6 +229,15 @@ class Sliceable { return index; } + /// Computes the linear index given the current slice properties. + intptr_t linearizeIndex(intptr_t index) { + intptr_t linearIndex = index * step + startIndex; + assert(linearIndex >= 0 && + linearIndex < static_cast(this)->getRawNumElements() && + "linear index out of bounds, the slice is ill-formed"); + return linearIndex; + } + /// Returns the element at the given slice index. Supports negative indices /// by taking elements in inverse order. Returns a nullptr object if out /// of bounds. @@ -238,13 +249,8 @@ class Sliceable { return {}; } - // Compute the linear index given the current slice properties. - int linearIndex = index * step + startIndex; - assert(linearIndex >= 0 && - linearIndex < static_cast(this)->getNumElements() && - "linear index out of bounds, the slice is ill-formed"); return pybind11::cast( - static_cast(this)->getElement(linearIndex)); + static_cast(this)->getRawElement(linearizeIndex(index))); } /// Returns a new instance of the pseudo-container restricted to the given @@ -266,6 +272,21 @@ class Sliceable { assert(length >= 0 && "expected non-negative slice length"); } + /// Returns the `index`-th element in the slice, supports negative indices. + /// Throws if the index is out of bounds. 
+ ElementTy getElement(intptr_t index) { + // Negative indices mean we count from the end. + index = wrapIndex(index); + if (index < 0) { + throw pybind11::index_error("index out of range"); + } + + return static_cast(this)->getRawElement(linearizeIndex(index)); + } + + /// Returns the size of slice. + intptr_t size() { return length; } + /// Returns a new vector (mapped to Python list) containing elements from two /// slices. The new vector is necessary because slices may not be contiguous /// or even come from the same original sequence. @@ -276,7 +297,7 @@ class Sliceable { elements.push_back(static_cast(this)->getElement(i)); } for (intptr_t i = 0; i < other.length; ++i) { - elements.push_back(static_cast(this)->getElement(i)); + elements.push_back(static_cast(&other)->getElement(i)); } return elements; } From 64d0c560c35c31ffde6e9e19879415e5b087b515 Mon Sep 17 00:00:00 2001 From: George Petterson Date: Thu, 21 Jul 2022 14:36:47 -0400 Subject: [PATCH 321/915] Adding a new variant of DepthwiseConv2D This is the same as the existing multiplier-1 variant of DepthwiseConv2D, but in PyTorch dim order. 
Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D128575 --- .../linalg/opdsl/ops/core_named_ops.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 7dd3f9495..b22e6c3b1 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -482,6 +482,32 @@ def depthwise_conv_2d_nhwc_hwc(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) +@linalg_structured_op +def depthwise_conv_2d_nchw_chw(I=TensorDef(T1, S.N, S.IC, S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.IC, S.KH, S.KW), + O=TensorDef(U, + S.N, + S.IC, + S.OH, + S.OW, + output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, + S.DW, + default=[1, 1])): + """Performs depth-wise 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) + O[D.n, D.ic, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw]) + + @linalg_structured_op def depthwise_conv_2d_nhwc_hwc_q(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), From d4e4baa8a998eeb387365ac537bdcfa96be220d6 Mon Sep 17 00:00:00 2001 From: rkayaith Date: Sat, 16 Jul 2022 16:41:33 -0400 Subject: [PATCH 322/915] [mlir][python] Fix issue in diagnostic note initialization Previously the elements of the notes tuple would be invalid objects when accessed from a diagnostic handler, resulting in a segfault when used. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D129943 --- mlir/lib/Bindings/Python/IRCore.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index fea26ec66..beb0c6cb9 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -785,8 +785,7 @@ py::tuple PyDiagnostic::getNotes() { materializedNotes = py::tuple(numNotes); for (intptr_t i = 0; i < numNotes; ++i) { MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i); - py::object pyNoteDiag = py::cast(PyDiagnostic(noteDiag)); - PyTuple_SET_ITEM(materializedNotes->ptr(), i, pyNoteDiag.ptr()); + materializedNotes.value()[i] = PyDiagnostic(noteDiag); } return *materializedNotes; } From a85d2099cdc6218af93754b044ffb50be6616dad Mon Sep 17 00:00:00 2001 From: River Riddle Date: Sun, 10 Jul 2022 01:00:21 -0700 Subject: [PATCH 323/915] [mlir] Refactor the Parser library in preparation for an MLIR binary format The current Parser library is solely focused on providing API for the textual MLIR format, but MLIR will soon also provide a binary format. 
This commit renames the current Parser library to AsmParser to better correspond to what the library is actually intended for. A new Parser library is added which will act as a unified parser interface between both text and binary formats. Most parser clients are unaffected, given that the unified interface is essentially the same as the current interface. Only clients that rely on utilizing the AsmParserState, or those that want to parse Attributes/Types need to be updated to point to the AsmParser library. Differential Revision: https://reviews.llvm.org/D129605 --- mlir/lib/CAPI/IR/IR.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index c931ea774..da43da1af 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -9,6 +9,7 @@ #include "mlir-c/IR.h" #include "mlir-c/Support.h" +#include "mlir/AsmParser/AsmParser.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Utils.h" From 11e0418f3a8865bdeaf2dcea96163540fcf39d18 Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Mon, 18 Jul 2022 21:32:38 -0700 Subject: [PATCH 324/915] [mlir] Remove types from attributes This patch removes the `type` field from `Attribute` along with the `Attribute::getType` accessor. Going forward, this means that attributes in MLIR will no longer have types as a first-class concept. This patch lays the groundwork to incrementally remove or refactor code that relies on generic attributes being typed. The immediate impact will be on attributes that rely on `Attribute` containing a type, such as `IntegerAttr`, `DenseElementsAttr`, and `ml_program::ExternAttr`, which will now need to define a type parameter on their storage classes. This will save memory as all other attribute kinds will no longer contain a type. Moreover, it will not be possible to generically query the type of an attribute directly. 
This patch provides an attribute interface `TypedAttr` that implements only one method, `getType`, which can be used to generically query the types of attributes that implement the interface. This interface can be used to retain the concept of a "typed attribute". The ODS-generated accessor for a `type` parameter automatically implements this method. Next steps will be to refactor the assembly formats of certain operations that rely on `parseAttribute(type)` and `printAttributeWithoutType` to remove special handling of type elision until `type` can be removed from the dialect parsing hook entirely; and incrementally remove uses of `TypedAttr`. Reviewed By: lattner, rriddle, jpienaar Differential Revision: https://reviews.llvm.org/D130092 --- mlir/lib/CAPI/IR/IR.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index da43da1af..435c974e7 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -753,7 +753,10 @@ MlirContext mlirAttributeGetContext(MlirAttribute attribute) { } MlirType mlirAttributeGetType(MlirAttribute attribute) { - return wrap(unwrap(attribute).getType()); + Attribute attr = unwrap(attribute); + if (auto typedAttr = attr.dyn_cast()) + return wrap(typedAttr.getType()); + return wrap(NoneType::get(attr.getContext())); } MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) { From b779ddec33dc996d0c9c4d3a7ef6d5d2e6c6a9ce Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Mon, 1 Aug 2022 08:52:41 +0000 Subject: [PATCH 325/915] Fix MLIR Python binding for arith.constant after argument has been changed to an interface 11e0418f3a88 removed the Type field from attributes and arith::ConstantOp argument is now a TypedAttrInterface which isn't supported by the python generator. This patch temporarily restore the functionality for arith.constant but won't generalize: we need to work on the generator instead. 
Differential Revision: https://reviews.llvm.org/D130878 --- mlir/python/mlir/dialects/_arith_ops_ext.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mlir/python/mlir/dialects/_arith_ops_ext.py b/mlir/python/mlir/dialects/_arith_ops_ext.py index c755df255..240859352 100644 --- a/mlir/python/mlir/dialects/_arith_ops_ext.py +++ b/mlir/python/mlir/dialects/_arith_ops_ext.py @@ -60,6 +60,10 @@ def create_index(cls, value: int, *, loc=None, ip=None): def type(self): return self.results[0].type + @property + def value(self): + return Attribute(self.operation.attributes["value"]) + @property def literal_value(self) -> Union[int, float]: if _is_integer_like_type(self.type): From 53c61beaccf8cc7f1e319efc120b835b97e156b7 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 15 Jul 2022 20:08:32 -0700 Subject: [PATCH 326/915] [mlir] Remove OpaqueElementsAttr This attribute is technical debt from the early stages of MLIR, before ElementsAttr was an interface and when it was more difficult for dialects to define their own types of attributes. At present it isn't used at all in tree (aside from being convenient for eliding other ElementsAttr), and has had little to no evolution in the past three years. Differential Revision: https://reviews.llvm.org/D129917 --- mlir/include/mlir-c/BuiltinAttributes.h | 9 --------- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 8 -------- 2 files changed, 17 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 050408d1f..62ee31904 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -455,15 +455,6 @@ mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED const void * mlirDenseElementsAttrGetRawData(MlirAttribute attr); -//===----------------------------------------------------------------------===// -// Opaque elements attribute. 
-//===----------------------------------------------------------------------===// - -// TODO: expose Dialect to the bindings and implement accessors here. - -/// Checks whether the given attribute is an opaque elements attribute. -MLIR_CAPI_EXPORTED bool mlirAttributeIsAOpaqueElements(MlirAttribute attr); - //===----------------------------------------------------------------------===// // Sparse elements attribute. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index ba3481ae1..afab9458d 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -598,14 +598,6 @@ const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) { unwrap(attr).cast().getRawData().data()); } -//===----------------------------------------------------------------------===// -// Opaque elements attribute. -//===----------------------------------------------------------------------===// - -bool mlirAttributeIsAOpaqueElements(MlirAttribute attr) { - return unwrap(attr).isa(); -} - //===----------------------------------------------------------------------===// // Sparse elements attribute. //===----------------------------------------------------------------------===// From 6aa73505c8e513d4f66b2bdf95ccb454cad18a7f Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Tue, 2 Aug 2022 10:36:05 -0700 Subject: [PATCH 327/915] [mlir][sparse] remove singleton dimension level type (for now) Although we have plans to support this, and many other, dimension level type(s), currently the tag is not supported. It will be easy to add this back once support is added. 
NOTE: based on discussion in https://discourse.llvm.org/t/overcoming-sparsification-limitation-on-level-types/62585 https://github.com/llvm/llvm-project/issues/51658 Reviewed By: Peiming Differential Revision: https://reviews.llvm.org/D131002 --- mlir/include/mlir-c/Dialect/SparseTensor.h | 7 ++----- mlir/lib/Bindings/Python/DialectSparseTensor.cpp | 3 +-- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 4 +--- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index 252ec6864..ac2b8b60f 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -20,12 +20,10 @@ extern "C" { MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor); /// Dimension level types that define sparse tensors: -/// - MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE - dimension is dense, every +/// - MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE - dimension is dense, every /// entry is stored /// - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED - dimension is sparse, -/// only nonzeros are stored. -/// - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON - dimension contains single -/// coordinate, no siblings. +/// only nonzeros are stored (no duplicates). /// /// These correspond to SparseTensorEncodingAttr::DimLevelType in the C++ API. 
/// If updating, keep them in sync and update the static_assert in the impl @@ -33,7 +31,6 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor); enum MlirSparseTensorDimLevelType { MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE, MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED, - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON, }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index b24d024d1..49b4a8998 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -18,8 +18,7 @@ using namespace mlir::python::adaptors; static void populateDialectSparseTensorSubmodule(const py::module &m) { py::enum_(m, "DimLevelType", py::module_local()) .value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) - .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) - .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON); + .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED); mlir_attribute_subclass(m, "EncodingAttr", mlirAttributeIsASparseTensorEncodingAttr) diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index f35c14af2..b7b2fd5d8 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -25,9 +25,7 @@ static_assert( static_cast(SparseTensorEncodingAttr::DimLevelType::Dense) && static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) == static_cast( - SparseTensorEncodingAttr::DimLevelType::Compressed) && - static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) == - static_cast(SparseTensorEncodingAttr::DimLevelType::Singleton), + SparseTensorEncodingAttr::DimLevelType::Compressed), "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch"); bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { From b3c99f568efc0158c5ea9ce7f9398a4090e571e9 Mon Sep 17 00:00:00 2001 
From: Nikita Popov Date: Wed, 3 Aug 2022 15:28:49 +0200 Subject: [PATCH 328/915] [MLIR] Fix checks for native arch Using if (TARGET ${LLVM_NATIVE_ARCH}) only works if MLIR is built together with LLVM, but not for standalone builds of MLIR. The correct way to check this is if (${LLVM_NATIVE_ARCH} IN_LIST LLVM_TARGETS_TO_BUILD), as the LLVM build system exports LLVM_TARGETS_TO_BUILD. To avoid repeating the same check many times, add a MLIR_ENABLE_EXECUTION_ENGINE variable. Differential Revision: https://reviews.llvm.org/D131071 --- mlir/lib/CAPI/CMakeLists.txt | 3 +-- mlir/python/CMakeLists.txt | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt index ffb04c287..052eff327 100644 --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -15,8 +15,7 @@ add_subdirectory(IR) add_subdirectory(RegisterEverything) add_subdirectory(Transforms) -# Only enable the ExecutionEngine if the native target is configured in. -if(TARGET ${LLVM_NATIVE_ARCH}) +if(MLIR_ENABLE_EXECUTION_ENGINE) add_subdirectory(ExecutionEngine) endif() diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 1fbbadf26..7eb6e05e4 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -363,8 +363,7 @@ declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses MLIRCAPIAsync ) -# Only enable the ExecutionEngine if the native target is configured in. -if(TARGET ${LLVM_NATIVE_ARCH}) +if(MLIR_ENABLE_EXECUTION_ENGINE) declare_mlir_python_extension(MLIRPythonExtension.ExecutionEngine MODULE_NAME _mlirExecutionEngine ADD_TO_PARENT MLIRPythonSources.ExecutionEngine From 68d00db477f246113fa4db1b41ec85adb2bdc6f5 Mon Sep 17 00:00:00 2001 From: John Demme Date: Sat, 6 Aug 2022 21:58:46 -0700 Subject: [PATCH 329/915] [MLIR] Add MlirValue to PybindAdapters Allows out-of-tree users to automatically cast to/from MlirValue. 
--- .../mlir/Bindings/Python/PybindAdaptors.h | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 351fb964e..564425b9b 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -208,6 +208,27 @@ struct type_caster { }; }; +/// Casts object <-> MlirValue. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirValue, _("MlirValue")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToValue(capsule.ptr()); + return !mlirValueIsNull(value); + } + static handle cast(MlirValue v, return_value_policy, handle) { + if (v.ptr == nullptr) + return py::none(); + py::object capsule = + py::reinterpret_steal(mlirPythonValueToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Value") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + /// Casts object -> MlirPassManager. template <> struct type_caster { From 53352bee687cc87c70e307b1ad41a4f6c7c281c3 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Sun, 7 Aug 2022 15:28:18 -0700 Subject: [PATCH 330/915] [mlir][python] Address deprecation warning for hasValue --- mlir/lib/Bindings/Python/IRModule.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 2e2ebaa27..246b244e1 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -330,7 +330,7 @@ class PyDiagnosticHandler { PyDiagnosticHandler(MlirContext context, pybind11::object callback); ~PyDiagnosticHandler(); - bool isAttached() { return registeredID.hasValue(); } + bool isAttached() { return registeredID.has_value(); } bool getHadError() { return hadError; } /// Detaches the handler. Does nothing if not attached. 
From ab6dee9ee64a458a49588a1f260ac7bb2b9e172f Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Mon, 8 Aug 2022 19:14:44 +0200 Subject: [PATCH 331/915] [mlir][math] Fix python bindings after 00f7096d31cc7896ffd490e65104d264923f0df5 --- mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index cc99081b4..e7493617a 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -393,7 +393,7 @@ def _unary_log(self, x: Value) -> Value: def _unary_abs(self, x: Value) -> Value: if _is_floating_point_type(x.type): - return math.AbsOp(x).result + return math.AbsFOp(x).result raise NotImplementedError("Unsupported 'abs' operand: {x}") def _unary_ceil(self, x: Value) -> Value: From d777f09c438a976ecfced7f25cedcdb13828bddd Mon Sep 17 00:00:00 2001 From: John Demme Date: Tue, 9 Aug 2022 19:37:04 -0700 Subject: [PATCH 332/915] [MLIR] [Python] Fix `Value.owner` to handle BlockArgs Previously, calling `Value.owner()` would C++ assert in debug builds if `Value` was a block argument. Additionally, the behavior was just wrong in release builds. This patch adds support for BlockArg Values.
--- mlir/lib/Bindings/Python/IRCore.cpp | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index beb0c6cb9..db199b38b 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3118,11 +3118,22 @@ void mlir::python::populateIRCore(py::module &m) { .def_property_readonly( "owner", [](PyValue &self) { - assert(mlirOperationEqual(self.getParentOperation()->get(), - mlirOpResultGetOwner(self.get())) && - "expected the owner of the value in Python to match that in " - "the IR"); - return self.getParentOperation().getObject(); + MlirValue v = self.get(); + if (mlirValueIsAOpResult(v)) { + assert( + mlirOperationEqual(self.getParentOperation()->get(), + mlirOpResultGetOwner(self.get())) && + "expected the owner of the value in Python to match that in " + "the IR"); + return self.getParentOperation().getObject(); + } + + if (mlirValueIsABlockArgument(v)) { + MlirBlock block = mlirBlockArgumentGetOwner(self.get()); + return py::cast(PyBlock(self.getParentOperation(), block)); + } + + assert(false && "Value must be a block argument or an op result"); }) .def("__eq__", [](PyValue &self, PyValue &other) { From 4f1fc07c17f82222f81dfa8476901de5c8638216 Mon Sep 17 00:00:00 2001 From: John Demme Date: Tue, 9 Aug 2022 20:07:33 -0700 Subject: [PATCH 333/915] [MLIR] [Python] Fix the Windows build broken by d777f09 Windows builds require all control paths return. Since we don't have `llvm_unreachable` in the Python bindings, just return `None`. 
--- mlir/lib/Bindings/Python/IRCore.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index db199b38b..033edbdfd 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3117,7 +3117,7 @@ void mlir::python::populateIRCore(py::module &m) { kDumpDocstring) .def_property_readonly( "owner", - [](PyValue &self) { + [](PyValue &self) -> py::object { MlirValue v = self.get(); if (mlirValueIsAOpResult(v)) { assert( @@ -3134,6 +3134,7 @@ void mlir::python::populateIRCore(py::module &m) { } assert(false && "Value must be a block argument or an op result"); + return py::none(); }) .def("__eq__", [](PyValue &self, PyValue &other) { From b31ea75de65fea3a1d246123487d52fbf5bfa198 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 12 Aug 2022 13:53:46 +0000 Subject: [PATCH 334/915] [mlir][transform] failure propagation mode in sequence Introduce two different failure propagation mode in the Transform dialect's Sequence operation. These modes specify whether silenceable errors produced by nested ops are immediately propagated, thus stopping the sequence, or suppressed. The latter is useful in end-to-end transform application scenarios where the user cannot correct the transformation, but it is robust enough to silenceable failures. It can be combined with the "alternatives" operation. There is intentionally no default value to avoid favoring one mode over the other. 
Downstreams can update their tests using: S='s/sequence \(%.*\) {/sequence \1 failures(propagate) {/' T='s/sequence {/sequence failures(propagate) {/' git grep -l transform.sequence | xargs sed -i -e "$S" git grep -l transform.sequence | xargs sed -i -e "$T" Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D131774 --- .../mlir/dialects/_transform_ops_ext.py | 20 ++++++++++++++----- .../mlir/dialects/transform/__init__.py | 15 ++++++++++++++ 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py index e75d6b5f9..992139f72 100644 --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -9,6 +9,7 @@ except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e +from argparse import SUPPRESS from typing import Optional, overload, Sequence, Union @@ -78,22 +79,31 @@ def __init__(self, class SequenceOp: @overload - def __init__(self, resultsOrRoot: Sequence[Type], + def __init__(self, failure_propagation_mode, + resultsOrRoot: Sequence[Type], optionalRoot: Optional[Union[Operation, Value]]): ... @overload - def __init__(self, resultsOrRoot: Optional[Union[Operation, Value]], - optionalRoot: NoneType): + def __init__(self, failure_propagation_mode, + resultsOrRoot: Optional[Union[Operation, + Value]], optionalRoot: NoneType): ... 
- def __init__(self, resultsOrRoot=None, optionalRoot=None): + def __init__(self, failure_propagation_mode, resultsOrRoot=None, optionalRoot=None): results = resultsOrRoot if isinstance(resultsOrRoot, Sequence) else [] root = ( resultsOrRoot if not isinstance(resultsOrRoot, Sequence) else optionalRoot) root = _get_op_result_or_value(root) if root else None - super().__init__(results_=results, root=root) + if not isinstance(failure_propagation_mode, Attribute): + failure_propagation_mode_attr = IntegerAttr.get( + IntegerType.get_signless(32), failure_propagation_mode._as_int()) + else: + failure_propagation_mode = failure_propagation_mode + super().__init__(results_=results, + failure_propagation_mode=failure_propagation_mode_attr, + root=root) self.regions[0].blocks.append(pdl.OperationType.get()) @property diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index ab4fa5631..d4d71274c 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -2,4 +2,19 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from enum import Enum + + +class FailurePropagationMode(Enum): + """Propagation mode for silenceable errors.""" + PROPAGATE = 1 + SUPPRESS = 2 + + def _as_int(self): + if self is FailurePropagationMode.PROPAGATE: + return 1 + + assert self is FailurePropagationMode.SUPPRESS + return 2 + from .._transform_ops_gen import * From 9666200357c1d0681ae6c15c8c0f25d12aa5f33b Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Fri, 12 Aug 2022 15:41:42 -0400 Subject: [PATCH 335/915] [mlir][python] Add python bindings for DenseArrayAttr This patch adds python bindings for the dense array variants. 
Fixes #56975 Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D131801 --- mlir/include/mlir-c/BuiltinAttributes.h | 55 +++++++ mlir/lib/Bindings/Python/IRAttributes.cpp | 171 ++++++++++++++++++++++ mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 100 +++++++++++++ 3 files changed, 326 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 62ee31904..c75db95b4 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -296,6 +296,61 @@ mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, uint64_t *idxs); /// shaped type and use its sizes to build a multi-dimensional index. MLIR_CAPI_EXPORTED int64_t mlirElementsAttrGetNumElements(MlirAttribute attr); +//===----------------------------------------------------------------------===// +// Dense array attribute. +//===----------------------------------------------------------------------===// + +/// Checks whether the given attribute is a dense array attribute. +MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseBoolArray(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI8Array(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI16Array(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI32Array(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI64Array(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseF32Array(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseF64Array(MlirAttribute attr); + +/// Create a dense array attribute with the given elements. 
+MLIR_CAPI_EXPORTED MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, + intptr_t size, + int const *values); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI8ArrayGet(MlirContext ctx, + intptr_t size, + int8_t const *values); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI16ArrayGet(MlirContext ctx, + intptr_t size, + int16_t const *values); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, + intptr_t size, + int32_t const *values); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseI64ArrayGet(MlirContext ctx, + intptr_t size, + int64_t const *values); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseF32ArrayGet(MlirContext ctx, + intptr_t size, + float const *values); +MLIR_CAPI_EXPORTED MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, + intptr_t size, + double const *values); + +/// Get the size of a dense array. +MLIR_CAPI_EXPORTED intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr); + +/// Get an element of a dense array. +MLIR_CAPI_EXPORTED bool mlirDenseBoolArrayGetElement(MlirAttribute attr, + intptr_t pos); +MLIR_CAPI_EXPORTED int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, + intptr_t pos); +MLIR_CAPI_EXPORTED int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, + intptr_t pos); +MLIR_CAPI_EXPORTED int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, + intptr_t pos); +MLIR_CAPI_EXPORTED int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, + intptr_t pos); +MLIR_CAPI_EXPORTED float mlirDenseF32ArrayGetElement(MlirAttribute attr, + intptr_t pos); +MLIR_CAPI_EXPORTED double mlirDenseF64ArrayGetElement(MlirAttribute attr, + intptr_t pos); + //===----------------------------------------------------------------------===// // Dense elements attribute. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 1093d50c8..d8fc568b7 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -110,6 +110,161 @@ static T pyTryCast(py::handle object) { } } +/// A python-wrapped dense array attribute with an element type and a derived +/// implementation class. +template +class PyDenseArrayAttribute + : public PyConcreteAttribute> { +public: + static constexpr typename PyConcreteAttribute< + PyDenseArrayAttribute>::IsAFunctionTy isaFunction = + DerivedT::isaFunction; + static constexpr const char *pyClassName = DerivedT::pyClassName; + using PyConcreteAttribute< + PyDenseArrayAttribute>::PyConcreteAttribute; + + /// Iterator over the integer elements of a dense array. + class PyDenseArrayIterator { + public: + PyDenseArrayIterator(PyAttribute attr) : attr(attr) {} + + /// Return a copy of the iterator. + PyDenseArrayIterator dunderIter() { return *this; } + + /// Return the next element. + EltTy dunderNext() { + // Throw if the index has reached the end. + if (nextIndex >= mlirDenseArrayGetNumElements(attr.get())) + throw py::stop_iteration(); + return DerivedT::getElement(attr.get(), nextIndex++); + } + + /// Bind the iterator class. + static void bind(py::module &m) { + py::class_(m, DerivedT::pyIteratorName, + py::module_local()) + .def("__iter__", &PyDenseArrayIterator::dunderIter) + .def("__next__", &PyDenseArrayIterator::dunderNext); + } + + private: + /// The referenced dense array attribute. + PyAttribute attr; + /// The next index to read. + int nextIndex = 0; + }; + + /// Get the element at the given index. + EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); } + + /// Bind the attribute class. + static void bindDerived(typename PyConcreteAttribute< + PyDenseArrayAttribute>::ClassTy &c) { + // Bind the constructor. 
+ c.def_static( + "get", + [](const std::vector &values, DefaultingPyMlirContext ctx) { + MlirAttribute attr = + DerivedT::getAttribute(ctx->get(), values.size(), values.data()); + return PyDenseArrayAttribute(ctx->getRef(), attr); + }, + py::arg("values"), py::arg("context") = py::none(), + "Gets a uniqued dense array attribute"); + // Bind the array methods. + c.def("__getitem__", + [](PyDenseArrayAttribute &arr, intptr_t i) { + if (i >= mlirDenseArrayGetNumElements(arr)) + throw py::index_error("DenseArray index out of range"); + return arr.getItem(i); + }); + c.def("__len__", [](const PyDenseArrayAttribute &arr) { + return mlirDenseArrayGetNumElements(arr); + }); + c.def("__iter__", [](const PyDenseArrayAttribute &arr) { + return PyDenseArrayIterator(arr); + }); + // Bind a concat. + c.def("__add__", [](PyDenseArrayAttribute &arr, + py::list extras) { + std::vector values; + intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); + values.reserve(numOldElements + py::len(extras)); + for (intptr_t i = 0; i < numOldElements; ++i) + values.push_back(arr.getItem(i)); + for (py::handle attr : extras) + values.push_back(pyTryCast(attr)); + MlirAttribute attr = DerivedT::getAttribute(arr.getContext()->get(), + values.size(), values.data()); + return PyDenseArrayAttribute(arr.getContext(), attr); + }); + } +}; + +/// Instantiate the python dense array classes. 
+struct PyDenseBoolArrayAttribute + : public PyDenseArrayAttribute { + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray; + static constexpr auto getAttribute = mlirDenseBoolArrayGet; + static constexpr auto getElement = mlirDenseBoolArrayGetElement; + static constexpr const char *pyClassName = "DenseBoolArrayAttr"; + static constexpr const char *pyIteratorName = "DenseBoolArrayIterator"; + using PyDenseArrayAttribute::PyDenseArrayAttribute; +}; +struct PyDenseI8ArrayAttribute + : public PyDenseArrayAttribute { + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI8Array; + static constexpr auto getAttribute = mlirDenseI8ArrayGet; + static constexpr auto getElement = mlirDenseI8ArrayGetElement; + static constexpr const char *pyClassName = "DenseI8ArrayAttr"; + static constexpr const char *pyIteratorName = "DenseI8ArrayIterator"; + using PyDenseArrayAttribute::PyDenseArrayAttribute; +}; +struct PyDenseI16ArrayAttribute + : public PyDenseArrayAttribute { + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI16Array; + static constexpr auto getAttribute = mlirDenseI16ArrayGet; + static constexpr auto getElement = mlirDenseI16ArrayGetElement; + static constexpr const char *pyClassName = "DenseI16ArrayAttr"; + static constexpr const char *pyIteratorName = "DenseI16ArrayIterator"; + using PyDenseArrayAttribute::PyDenseArrayAttribute; +}; +struct PyDenseI32ArrayAttribute + : public PyDenseArrayAttribute { + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseI32Array; + static constexpr auto getAttribute = mlirDenseI32ArrayGet; + static constexpr auto getElement = mlirDenseI32ArrayGetElement; + static constexpr const char *pyClassName = "DenseI32ArrayAttr"; + static constexpr const char *pyIteratorName = "DenseI32ArrayIterator"; + using PyDenseArrayAttribute::PyDenseArrayAttribute; +}; +struct PyDenseI64ArrayAttribute + : public PyDenseArrayAttribute { + static constexpr IsAFunctionTy isaFunction = 
mlirAttributeIsADenseI64Array; + static constexpr auto getAttribute = mlirDenseI64ArrayGet; + static constexpr auto getElement = mlirDenseI64ArrayGetElement; + static constexpr const char *pyClassName = "DenseI64ArrayAttr"; + static constexpr const char *pyIteratorName = "DenseI64ArrayIterator"; + using PyDenseArrayAttribute::PyDenseArrayAttribute; +}; +struct PyDenseF32ArrayAttribute + : public PyDenseArrayAttribute { + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF32Array; + static constexpr auto getAttribute = mlirDenseF32ArrayGet; + static constexpr auto getElement = mlirDenseF32ArrayGetElement; + static constexpr const char *pyClassName = "DenseF32ArrayAttr"; + static constexpr const char *pyIteratorName = "DenseF32ArrayIterator"; + using PyDenseArrayAttribute::PyDenseArrayAttribute; +}; +struct PyDenseF64ArrayAttribute + : public PyDenseArrayAttribute { + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseF64Array; + static constexpr auto getAttribute = mlirDenseF64ArrayGet; + static constexpr auto getElement = mlirDenseF64ArrayGetElement; + static constexpr const char *pyClassName = "DenseF64ArrayAttr"; + static constexpr const char *pyIteratorName = "DenseF64ArrayIterator"; + using PyDenseArrayAttribute::PyDenseArrayAttribute; +}; + class PyArrayAttribute : public PyConcreteAttribute { public: static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; @@ -891,6 +1046,22 @@ class PyUnitAttribute : public PyConcreteAttribute { void mlir::python::populateIRAttributes(py::module &m) { PyAffineMapAttribute::bind(m); + + PyDenseBoolArrayAttribute::bind(m); + PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m); + PyDenseI8ArrayAttribute::bind(m); + PyDenseI8ArrayAttribute::PyDenseArrayIterator::bind(m); + PyDenseI16ArrayAttribute::bind(m); + PyDenseI16ArrayAttribute::PyDenseArrayIterator::bind(m); + PyDenseI32ArrayAttribute::bind(m); + PyDenseI32ArrayAttribute::PyDenseArrayIterator::bind(m); + 
PyDenseI64ArrayAttribute::bind(m); + PyDenseI64ArrayAttribute::PyDenseArrayIterator::bind(m); + PyDenseF32ArrayAttribute::bind(m); + PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m); + PyDenseF64ArrayAttribute::bind(m); + PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m); + PyArrayAttribute::bind(m); PyArrayAttribute::PyArrayAttributeIterator::bind(m); PyBoolAttribute::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index afab9458d..c50096bb1 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -311,6 +311,106 @@ int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) { return unwrap(attr).cast().getNumElements(); } +//===----------------------------------------------------------------------===// +// Dense array attribute. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// IsA support. + +bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) { + return unwrap(attr).isa(); +} +bool mlirAttributeIsADenseI8Array(MlirAttribute attr) { + return unwrap(attr).isa(); +} +bool mlirAttributeIsADenseI16Array(MlirAttribute attr) { + return unwrap(attr).isa(); +} +bool mlirAttributeIsADenseI32Array(MlirAttribute attr) { + return unwrap(attr).isa(); +} +bool mlirAttributeIsADenseI64Array(MlirAttribute attr) { + return unwrap(attr).isa(); +} +bool mlirAttributeIsADenseF32Array(MlirAttribute attr) { + return unwrap(attr).isa(); +} +bool mlirAttributeIsADenseF64Array(MlirAttribute attr) { + return unwrap(attr).isa(); +} + +//===----------------------------------------------------------------------===// +// Constructors. 
+ +MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size, + int const *values) { + SmallVector elements(values, values + size); + return wrap(DenseBoolArrayAttr::get(unwrap(ctx), elements)); +} +MlirAttribute mlirDenseI8ArrayGet(MlirContext ctx, intptr_t size, + int8_t const *values) { + return wrap( + DenseI8ArrayAttr::get(unwrap(ctx), ArrayRef(values, size))); +} +MlirAttribute mlirDenseI16ArrayGet(MlirContext ctx, intptr_t size, + int16_t const *values) { + return wrap( + DenseI16ArrayAttr::get(unwrap(ctx), ArrayRef(values, size))); +} +MlirAttribute mlirDenseI32ArrayGet(MlirContext ctx, intptr_t size, + int32_t const *values) { + return wrap( + DenseI32ArrayAttr::get(unwrap(ctx), ArrayRef(values, size))); +} +MlirAttribute mlirDenseI64ArrayGet(MlirContext ctx, intptr_t size, + int64_t const *values) { + return wrap( + DenseI64ArrayAttr::get(unwrap(ctx), ArrayRef(values, size))); +} +MlirAttribute mlirDenseF32ArrayGet(MlirContext ctx, intptr_t size, + float const *values) { + return wrap( + DenseF32ArrayAttr::get(unwrap(ctx), ArrayRef(values, size))); +} +MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size, + double const *values) { + return wrap( + DenseF64ArrayAttr::get(unwrap(ctx), ArrayRef(values, size))); +} + +//===----------------------------------------------------------------------===// +// Accessors. + +intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) { + return unwrap(attr).cast().size(); +} + +//===----------------------------------------------------------------------===// +// Indexed accessors. 
+ +bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) { + return unwrap(attr).cast()[pos]; +} +int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, intptr_t pos) { + return unwrap(attr).cast()[pos]; +} +int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, intptr_t pos) { + return unwrap(attr).cast()[pos]; +} +int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, intptr_t pos) { + return unwrap(attr).cast()[pos]; +} +int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, intptr_t pos) { + return unwrap(attr).cast()[pos]; +} +float mlirDenseF32ArrayGetElement(MlirAttribute attr, intptr_t pos) { + return unwrap(attr).cast()[pos]; +} +double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) { + return unwrap(attr).cast()[pos]; +} + //===----------------------------------------------------------------------===// // Dense elements attribute. //===----------------------------------------------------------------------===// From 3cd8ef796a63afca95035d4dbaa0a76be95b1813 Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Fri, 12 Aug 2022 15:43:03 -0400 Subject: [PATCH 336/915] (Reland) [mlir] Switch segment size attributes to DenseI32ArrayAttr This reland includes changes to the Python bindings. Switch variadic operand and result segment size attributes to use the dense i32 array. Dense integer arrays were introduced primarily to represent index lists. They are a better fit for segment sizes than dense elements attrs. 
Depends on D131801 Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D131803 --- mlir/lib/Bindings/Python/IRCore.cpp | 18 ++++++++---------- mlir/python/mlir/dialects/_ods_common.py | 4 ++-- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 033edbdfd..e83e99305 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1285,8 +1285,8 @@ py::object PyOpView::buildGeneric( py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); - std::vector operandSegmentLengths; - std::vector resultSegmentLengths; + std::vector operandSegmentLengths; + std::vector resultSegmentLengths; // Validate/determine region count. auto opRegionSpec = py::cast>(cls.attr("_ODS_REGIONS")); @@ -1497,20 +1497,18 @@ py::object PyOpView::buildGeneric( // Add result_segment_sizes attribute. if (!resultSegmentLengths.empty()) { - int64_t size = resultSegmentLengths.size(); - MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( - mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), - resultSegmentLengths.size(), resultSegmentLengths.data()); + MlirAttribute segmentLengthAttr = + mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(), + resultSegmentLengths.data()); (*attributes)["result_segment_sizes"] = PyAttribute(context, segmentLengthAttr); } // Add operand_segment_sizes attribute. 
if (!operandSegmentLengths.empty()) { - int64_t size = operandSegmentLengths.size(); - MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get( - mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)), - operandSegmentLengths.size(), operandSegmentLengths.data()); + MlirAttribute segmentLengthAttr = + mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(), + operandSegmentLengths.data()); (*attributes)["operand_segment_sizes"] = PyAttribute(context, segmentLengthAttr); } diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 0c66593ce..51b900819 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -78,11 +78,11 @@ def segmented_accessor(elements, raw_segments, idx): Returns a slice of elements corresponding to the idx-th segment. elements: a sliceable container (operands or results). - raw_segments: an mlir.ir.Attribute, of DenseIntElements subclass containing + raw_segments: an mlir.ir.Attribute, of DenseI32Array subclass containing sizes of the segments. idx: index of the segment. 
""" - segments = _cext.ir.DenseIntElementsAttr(raw_segments) + segments = _cext.ir.DenseI32ArrayAttr(raw_segments) start = sum(segments[i] for i in range(idx)) end = start + segments[idx] return elements[start:end] From dbdc1856051024618905af96d3a2bfade7528205 Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Fri, 12 Aug 2022 23:35:38 -0400 Subject: [PATCH 337/915] [mlir][python] add a todo to replace throw in dense array iterator --- mlir/lib/Bindings/Python/IRAttributes.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index d8fc568b7..cb59893b4 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -183,7 +183,6 @@ class PyDenseArrayAttribute c.def("__iter__", [](const PyDenseArrayAttribute &arr) { return PyDenseArrayIterator(arr); }); - // Bind a concat. c.def("__add__", [](PyDenseArrayAttribute &arr, py::list extras) { std::vector values; @@ -278,9 +277,9 @@ class PyArrayAttribute : public PyConcreteAttribute { PyArrayAttributeIterator &dunderIter() { return *this; } PyAttribute dunderNext() { - if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) { + // TODO: Throw is an inefficient way to stop iteration. + if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) throw py::stop_iteration(); - } return PyAttribute(attr.getContext(), mlirArrayAttrGetElement(attr.get(), nextIndex++)); } From 0da987299a84efcefa8b63b33648dc0f79b59a33 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 15 Aug 2022 13:02:17 +0200 Subject: [PATCH 338/915] [mlir][linalg][python] Add named constructor for MatchOp This constructor makes it easier to match for ops by their name. 
Differential Revision: https://reviews.llvm.org/D131882 --- .../dialects/_structured_transform_ops_ext.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 95bf2cc99..faf98ef49 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -110,6 +110,24 @@ def __init__(self, ip=ip) +class MatchOp: + """Specialization for MatchOp class.""" + + @classmethod + def match_op_names(MatchOp, + target: Union[Operation, Value], + names: Sequence[str], + loc=None, + ip=None): + pdl_operation_type = pdl.OperationType.get() + return MatchOp( + pdl_operation_type, + _get_op_result_or_value(target), + ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))), + loc=loc, + ip=ip) + + class MultiTileSizesOp: """Specialization for MultitileSizesOp class.""" From 973ef3a36cf9eb8d8a00e7d58797e0163f20b799 Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Mon, 15 Aug 2022 13:14:15 -0400 Subject: [PATCH 339/915] [mlir][python] Fix build on windows Reviewed By: stella.stamenova, ashay-github Differential Revision: https://reviews.llvm.org/D131906 --- mlir/lib/Bindings/Python/IRAttributes.cpp | 38 +++++++++-------------- 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index cb59893b4..f9c7f6fe5 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -113,15 +113,9 @@ static T pyTryCast(py::handle object) { /// A python-wrapped dense array attribute with an element type and a derived /// implementation class. 
template -class PyDenseArrayAttribute - : public PyConcreteAttribute> { +class PyDenseArrayAttribute : public PyConcreteAttribute { public: - static constexpr typename PyConcreteAttribute< - PyDenseArrayAttribute>::IsAFunctionTy isaFunction = - DerivedT::isaFunction; - static constexpr const char *pyClassName = DerivedT::pyClassName; - using PyConcreteAttribute< - PyDenseArrayAttribute>::PyConcreteAttribute; + using PyConcreteAttribute::PyConcreteAttribute; /// Iterator over the integer elements of a dense array. class PyDenseArrayIterator { @@ -158,33 +152,29 @@ class PyDenseArrayAttribute EltTy getItem(intptr_t i) { return DerivedT::getElement(*this, i); } /// Bind the attribute class. - static void bindDerived(typename PyConcreteAttribute< - PyDenseArrayAttribute>::ClassTy &c) { + static void bindDerived(typename PyConcreteAttribute::ClassTy &c) { // Bind the constructor. c.def_static( "get", [](const std::vector &values, DefaultingPyMlirContext ctx) { MlirAttribute attr = DerivedT::getAttribute(ctx->get(), values.size(), values.data()); - return PyDenseArrayAttribute(ctx->getRef(), attr); + return DerivedT(ctx->getRef(), attr); }, py::arg("values"), py::arg("context") = py::none(), "Gets a uniqued dense array attribute"); // Bind the array methods. 
- c.def("__getitem__", - [](PyDenseArrayAttribute &arr, intptr_t i) { - if (i >= mlirDenseArrayGetNumElements(arr)) - throw py::index_error("DenseArray index out of range"); - return arr.getItem(i); - }); - c.def("__len__", [](const PyDenseArrayAttribute &arr) { - return mlirDenseArrayGetNumElements(arr); + c.def("__getitem__", [](DerivedT &arr, intptr_t i) { + if (i >= mlirDenseArrayGetNumElements(arr)) + throw py::index_error("DenseArray index out of range"); + return arr.getItem(i); }); - c.def("__iter__", [](const PyDenseArrayAttribute &arr) { - return PyDenseArrayIterator(arr); + c.def("__len__", [](const DerivedT &arr) { + return mlirDenseArrayGetNumElements(arr); }); - c.def("__add__", [](PyDenseArrayAttribute &arr, - py::list extras) { + c.def("__iter__", + [](const DerivedT &arr) { return PyDenseArrayIterator(arr); }); + c.def("__add__", [](DerivedT &arr, py::list extras) { std::vector values; intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); values.reserve(numOldElements + py::len(extras)); @@ -194,7 +184,7 @@ class PyDenseArrayAttribute values.push_back(pyTryCast(attr)); MlirAttribute attr = DerivedT::getAttribute(arr.getContext()->get(), values.size(), values.data()); - return PyDenseArrayAttribute(arr.getContext(), attr); + return DerivedT(arr.getContext(), attr); }); } }; From 2f8bd041ebbc491f83f883b9455bb1ff1e5fcdf2 Mon Sep 17 00:00:00 2001 From: Kazu Hirata Date: Sat, 20 Aug 2022 21:18:28 -0700 Subject: [PATCH 340/915] Remove redundant initialization of Optional (NFC) --- mlir/lib/CAPI/Interfaces/Interfaces.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp index 3f1155fe1..66d3bb187 100644 --- a/mlir/lib/CAPI/Interfaces/Interfaces.cpp +++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp @@ -46,7 +46,7 @@ MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( if (!info) return mlirLogicalResultFailure(); - llvm::Optional maybeLocation = 
llvm::None; + llvm::Optional maybeLocation; if (!mlirLocationIsNull(location)) maybeLocation = unwrap(location); SmallVector unwrappedOperands; From e8ae43db1f542b814ed563c737489ab9c2275829 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Mon, 29 Aug 2022 10:06:17 +0000 Subject: [PATCH 341/915] Apply clang-tidy fixes for performance-unnecessary-value-param in IRAttributes.cpp (NFC) --- mlir/lib/Bindings/Python/IRAttributes.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index f9c7f6fe5..8d8cea395 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -120,7 +120,7 @@ class PyDenseArrayAttribute : public PyConcreteAttribute { /// Iterator over the integer elements of a dense array. class PyDenseArrayIterator { public: - PyDenseArrayIterator(PyAttribute attr) : attr(attr) {} + PyDenseArrayIterator(PyAttribute attr) : attr(std::move(attr)) {} /// Return a copy of the iterator. PyDenseArrayIterator dunderIter() { return *this; } @@ -174,7 +174,7 @@ class PyDenseArrayAttribute : public PyConcreteAttribute { }); c.def("__iter__", [](const DerivedT &arr) { return PyDenseArrayIterator(arr); }); - c.def("__add__", [](DerivedT &arr, py::list extras) { + c.def("__add__", [](DerivedT &arr, const py::list &extras) { std::vector values; intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); values.reserve(numOldElements + py::len(extras)); From 57843684eb50cb49d476380a43b08f3f51c626a4 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Mon, 29 Aug 2022 15:43:20 -0700 Subject: [PATCH 342/915] [mlir][sparse] add more dimension level types and properties We recently removed the singleton dimension level type (see the revision https://reviews.llvm.org/D131002) since it was unimplemented but also incomplete (properties were missing). 
This revision add singleton back as extra dimension level type, together with properties ordered/not-ordered and unique/not-unique. Even though still not lowered to actual code, this provides a complete way of defining many more sparse storage schemes (in the long run, we want to support even dimension level types and properties using the additional extensions proposed in [Chou]). Note that the current solution of using suffixes for the properties is not ideal, but keeps the extension relatively simple with respect to parsing and printing. Furthermore, it is rather consistent with the TACO implementation which uses things like Compressed-Unique as well. Nevertheless, we probably want to separate dimension level types from properties when we add more types and properties. Reviewed By: Peiming Differential Revision: https://reviews.llvm.org/D132897 --- mlir/include/mlir-c/Dialect/SparseTensor.h | 14 +++++++---- .../Bindings/Python/DialectSparseTensor.cpp | 9 +++++++- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 23 ++++++++++++++++++- 3 files changed, 39 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index ac2b8b60f..9465f36c3 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -19,11 +19,8 @@ extern "C" { MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor); -/// Dimension level types that define sparse tensors: -/// - MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE - dimension is dense, every -/// entry is stored -/// - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED - dimension is sparse, -/// only nonzeros are stored (no duplicates). +/// Dimension level types (and properties) that define sparse tensors. +/// See the documentation in SparseTensorAttrDefs.td for their meaning. /// /// These correspond to SparseTensorEncodingAttr::DimLevelType in the C++ API. 
/// If updating, keep them in sync and update the static_assert in the impl @@ -31,6 +28,13 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor); enum MlirSparseTensorDimLevelType { MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE, MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED, + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU, + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO, + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO, + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON, + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU, + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO, + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO, }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 49b4a8998..ae9cfbb67 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -18,7 +18,14 @@ using namespace mlir::python::adaptors; static void populateDialectSparseTensorSubmodule(const py::module &m) { py::enum_(m, "DimLevelType", py::module_local()) .value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) - .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED); + .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) + .value("compressed-nu", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU) + .value("compressed-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO) + .value("compressed-nu-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO) + .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) + .value("singleton-nu", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU) + .value("singleton-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO) + .value("singleton-nu-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO); mlir_attribute_subclass(m, "EncodingAttr", mlirAttributeIsASparseTensorEncodingAttr) diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index b7b2fd5d8..e3e32900c 
100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -25,7 +25,28 @@ static_assert( static_cast(SparseTensorEncodingAttr::DimLevelType::Dense) && static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) == static_cast( - SparseTensorEncodingAttr::DimLevelType::Compressed), + SparseTensorEncodingAttr::DimLevelType::Compressed) && + static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU) == + static_cast( + SparseTensorEncodingAttr::DimLevelType::CompressedNu) && + static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO) == + static_cast( + SparseTensorEncodingAttr::DimLevelType::CompressedNo) && + static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO) == + static_cast( + SparseTensorEncodingAttr::DimLevelType::CompressedNuNo) && + static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) == + static_cast( + SparseTensorEncodingAttr::DimLevelType::Singleton) && + static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU) == + static_cast( + SparseTensorEncodingAttr::DimLevelType::SingletonNu) && + static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO) == + static_cast( + SparseTensorEncodingAttr::DimLevelType::SingletonNo) && + static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO) == + static_cast( + SparseTensorEncodingAttr::DimLevelType::SingletonNuNo), "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch"); bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { From a89208212d99ecdbb830e7fec7ad58befa4eaf0c Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Thu, 25 Aug 2022 16:21:28 -0700 Subject: [PATCH 343/915] [mlir] Make DenseArrayAttr generic This patch turns `DenseArrayBaseAttr` into a fully-functional attribute by adding a generic parser and printer, supporting bool or integer and floating point element types with bitwidths divisible by 8. It has been renamed to `DenseArrayAttr`. The patch maintains the specialized subclasses, e.g. 
`DenseI32ArrayAttr`, which remain the preferred API for accessing elements in C++. This allows `DenseArrayAttr` to hold signed and unsigned integer elements: ``` array array ``` "Exotic" floating point elements: ``` array ``` And integers of other bitwidths: ``` array ``` Reviewed By: rriddle, lattner Differential Revision: https://reviews.llvm.org/D132758 --- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index c50096bb1..b02484ab9 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -383,7 +383,7 @@ MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size, // Accessors. intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) { - return unwrap(attr).cast().size(); + return unwrap(attr).cast().size(); } //===----------------------------------------------------------------------===// From 39aa492328c432128f33035c2904c876b815e136 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Mon, 5 Sep 2022 11:54:19 +0000 Subject: [PATCH 344/915] Plumb write_bytecode to the Python API This adds a `write_bytecode` method to the Operation class. The method takes a file handle and writes the binary blob to it. 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D133210 --- mlir/include/mlir-c/IR.h | 5 +++++ mlir/lib/Bindings/Python/IRCore.cpp | 17 +++++++++++++++++ mlir/lib/Bindings/Python/IRModule.h | 3 +++ mlir/lib/CAPI/IR/CMakeLists.txt | 1 + mlir/lib/CAPI/IR/IR.cpp | 8 +++++++- 5 files changed, 33 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 2d38700c2..daf097da2 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -521,6 +521,11 @@ MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op, MlirStringCallback callback, void *userData); +/// Same as mlirOperationPrint but writing the bytecode format out. +MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op, + MlirStringCallback callback, + void *userData); + /// Prints an operation to stderr. MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index e83e99305..389969baa 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -119,6 +119,13 @@ static const char kOperationGetAsmDocstring[] = argument. )"; +static const char kOperationPrintBytecodeDocstring[] = + R"(Write the bytecode form of the operation to a file like object. + +Args: + file: The file like object to write to. +)"; + static const char kOperationStrDunderDocstring[] = R"(Gets the assembly form of the operation with default options. 
@@ -1022,6 +1029,14 @@ void PyOperationBase::print(py::object fileObject, bool binary, mlirOpPrintingFlagsDestroy(flags); } +void PyOperationBase::writeBytecode(py::object fileObject) { + PyOperation &operation = getOperation(); + operation.checkValid(); + PyFileAccumulator accum(fileObject, /*binary=*/true); + mlirOperationWriteBytecode(operation, accum.getCallback(), + accum.getUserData()); +} + py::object PyOperationBase::getAsm(bool binary, llvm::Optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, @@ -2627,6 +2642,8 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("print_generic_op_form") = false, py::arg("use_local_scope") = false, py::arg("assume_verified") = false, kOperationPrintDocstring) + .def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"), + kOperationPrintBytecodeDocstring) .def("get_asm", &PyOperationBase::getAsm, // Careful: Lots of arguments must match up with get_asm method. py::arg("binary") = false, diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 246b244e1..ad783c6c3 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -512,6 +512,9 @@ class PyOperationBase { bool printGenericOpForm, bool useLocalScope, bool assumeVerified); + // Implement the bound 'writeBytecode' method. + void writeBytecode(pybind11::object fileObject); + /// Moves the operation before or after the other operation. 
void moveAfter(PyOperationBase &other); void moveBefore(PyOperationBase &other); diff --git a/mlir/lib/CAPI/IR/CMakeLists.txt b/mlir/lib/CAPI/IR/CMakeLists.txt index 320ed0718..36f28520d 100644 --- a/mlir/lib/CAPI/IR/CMakeLists.txt +++ b/mlir/lib/CAPI/IR/CMakeLists.txt @@ -12,6 +12,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIIR Support.cpp LINK_LIBS PUBLIC + MLIRBytecodeWriter MLIRIR MLIRParser MLIRSupport diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 435c974e7..98a3ff348 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -10,6 +10,7 @@ #include "mlir-c/Support.h" #include "mlir/AsmParser/AsmParser.h" +#include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Utils.h" @@ -23,7 +24,6 @@ #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Parser/Parser.h" -#include "llvm/Support/Debug.h" #include using namespace mlir; @@ -485,6 +485,12 @@ void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, unwrap(op)->print(stream, *unwrap(flags)); } +void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, + void *userData) { + detail::CallbackOstream stream(callback, userData); + writeBytecodeToFile(unwrap(op), stream); +} + void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } bool mlirOperationVerify(MlirOperation op) { From 7b9b194ac97780991c312c80a7952e81829f756b Mon Sep 17 00:00:00 2001 From: Stanley Winata Date: Thu, 8 Sep 2022 19:44:30 -0400 Subject: [PATCH 345/915] [mlir][linalg] add conv_1d_ncw_fcw Reviewed By: hanchung, antiagainst Differential Revision: https://reviews.llvm.org/D133465 --- .../linalg/opdsl/ops/core_named_ops.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index b22e6c3b1..983842cde 100644 --- 
a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -269,6 +269,26 @@ def conv_1d_nwc_wcf(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed( U, K[D.kw, D.c, D.f]) +@linalg_structured_op +def conv_1d_ncw_fcw(I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.F, S.C, S.KW), + O=TensorDef(U, S.N, S.F, S.OW, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1])): + """Performs 1-D convolution. + + Layout: + * Input: NCW. + * Kernel: FCW. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.f, D.ow, D.c, D.kw) + O[D.n, D.f, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed( + U, K[D.f, D.c, D.kw]) @linalg_structured_op def conv_2d_nhwc_hwcf(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, From 87b40eb0039a34125aa3c46e2d6e235a7f30c20d Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 15 Sep 2022 04:12:58 -0700 Subject: [PATCH 346/915] [mlir][Linalg] Post submit addressed comments missed in f0cdc5bcd3f25192f12bfaff072ce02497b59c3c Differential Revision: https://reviews.llvm.org/D133936 --- mlir/python/mlir/dialects/_structured_transform_ops_ext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index faf98ef49..eddc384f6 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -287,7 +287,7 @@ def __init__(self, ip=None): pdl_operation_type = pdl.OperationType.get() if isinstance(vectorize_padding, bool): - vectorize_padding = 
BoolAttr.get(vectorize_padding) + vectorize_padding = UnitAttr.get() super().__init__( pdl_operation_type, _get_op_result_or_value(target), From 6d1ff5ef26a49095d4e180caae2875e33d793a69 Mon Sep 17 00:00:00 2001 From: lorenzo chelini Date: Mon, 19 Sep 2022 12:11:04 +0200 Subject: [PATCH 347/915] [MLIR][Linalg] introduce batch-reduce GEMM The batch-reduce GEMM kernel essentially multiplies a sequence of input tensor blocks (which form a batch) and the partial multiplication results are reduced into a single output tensor block. See: https://ieeexplore.ieee.org/document/9139809 for more details. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D134163 --- .../dialects/linalg/opdsl/ops/core_named_ops.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 983842cde..b9b292d84 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -150,6 +150,20 @@ def quantized_batch_matmul(A=TensorDef(T1, Batch, S.M, S.K), TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed( U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp)) +@linalg_structured_op +def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True)): + """Performs a batch-reduce matrix multiplication of two 3D inputs. + The partial multiplication results are reduced into a 2D output. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + domain(D.b, D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k] * TypeFn.cast_signed( + U, B[D.b, D.k, D.n]) @linalg_structured_op def matvec(A=TensorDef(T1, S.M, S.N), From d8fd7437933936aad110159745e2208cf141efff Mon Sep 17 00:00:00 2001 From: Lorenzo Chelini Date: Mon, 19 Sep 2022 12:17:30 +0200 Subject: [PATCH 348/915] Revert "[MLIR][Linalg] introduce batch-reduce GEMM" This reverts commit 6d1ff5ef26a49095d4e180caae2875e33d793a69. --- .../dialects/linalg/opdsl/ops/core_named_ops.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index b9b292d84..983842cde 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -150,20 +150,6 @@ def quantized_batch_matmul(A=TensorDef(T1, Batch, S.M, S.K), TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed( U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp)) -@linalg_structured_op -def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K), - B=TensorDef(T2, Batch, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True)): - """Performs a batch-reduce matrix multiplication of two 3D inputs. - The partial multiplication results are reduced into a 2D output. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. 
- """ - domain(D.b, D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k] * TypeFn.cast_signed( - U, B[D.b, D.k, D.n]) @linalg_structured_op def matvec(A=TensorDef(T1, S.M, S.N), From ebf099a4fd4b31dc48731b5e696d7185c8586188 Mon Sep 17 00:00:00 2001 From: Lorenzo Chelini Date: Mon, 19 Sep 2022 12:34:49 +0200 Subject: [PATCH 349/915] [MLIR][Linalg] introduce batch-reduce GEMM The batch-reduce GEMM kernel essentially multiplies a sequence of input tensor blocks (which form a batch) and the partial multiplication results are reduced into a single output tensor block. See: https://ieeexplore.ieee.org/document/9139809 for more details. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D134163 --- .../dialects/linalg/opdsl/ops/core_named_ops.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 983842cde..1aa112dcf 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -150,6 +150,20 @@ def quantized_batch_matmul(A=TensorDef(T1, Batch, S.M, S.K), TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed( U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp)) +@linalg_structured_op +def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True)): + """Performs a batch-reduce matrix multiplication of two 3D inputs. + The partial multiplication results are reduced into a 2D output. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + domain(D.b, D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k] * TypeFn.cast_signed( + U, B[D.b, D.k, D.n])) @linalg_structured_op def matvec(A=TensorDef(T1, S.M, S.N), From e92b4ae141e64fca143e48e8d39c8cad7182589d Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 21 Sep 2022 13:40:47 -0700 Subject: [PATCH 350/915] [mlir] Flip PDL to use Both accessors This allows for incrementally updating the old API usages without needing to update everything at once. PDL will be left on Both for a little bit and then flipped to prefixed when all APIs have been updated. Differential Revision: https://reviews.llvm.org/D134387 --- mlir/python/mlir/dialects/_pdl_ops_ext.py | 38 ++++++++++++----------- 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py index bb63fe64d..428301b18 100644 --- a/mlir/python/mlir/dialects/_pdl_ops_ext.py +++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py @@ -86,14 +86,14 @@ class AttributeOp: """Specialization for PDL attribute op class.""" def __init__(self, - type: Optional[Union[OpView, Operation, Value]] = None, + valueType: Optional[Union[OpView, Operation, Value]] = None, value: Optional[Attribute] = None, *, loc=None, ip=None): - type = type if type is None else _get_value(type) + valueType = valueType if valueType is None else _get_value(valueType) result = pdl.AttributeType.get() - super().__init__(result, type=type, value=value, loc=loc, ip=ip) + super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip) class EraseOp: @@ -118,7 +118,7 @@ def __init__(self, ip=None): type = type if type is None else _get_value(type) result = pdl.ValueType.get() - super().__init__(result, type=type, loc=loc, ip=ip) + super().__init__(result, valueType=type, loc=loc, ip=ip) class OperandsOp: @@ -131,7 +131,7 @@ def __init__(self, ip=None): types = types if types is None else 
_get_value(types) result = pdl.RangeType.get(pdl.ValueType.get()) - super().__init__(result, type=types, loc=loc, ip=ip) + super().__init__(result, valueType=types, loc=loc, ip=ip) class OperationOp: @@ -147,15 +147,15 @@ def __init__(self, ip=None): name = name if name is None else _get_str_attr(name) args = _get_values(args) - attributeNames = [] - attributeValues = [] + attrNames = [] + attrValues = [] for attrName, attrValue in attributes.items(): - attributeNames.append(StringAttr.get(attrName)) - attributeValues.append(_get_value(attrValue)) - attributeNames = ArrayAttr.get(attributeNames) + attrNames.append(StringAttr.get(attrName)) + attrValues.append(_get_value(attrValue)) + attrNames = ArrayAttr.get(attrNames) types = _get_values(types) result = pdl.OperationType.get() - super().__init__(result, args, attributeValues, attributeNames, types, name=name, loc=loc, ip=ip) + super().__init__(result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip) class PatternOp: @@ -255,24 +255,26 @@ class TypeOp: """Specialization for PDL type op class.""" def __init__(self, - type: Optional[Union[TypeAttr, Type]] = None, + constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None): - type = type if type is None else _get_type_attr(type) + constantType = constantType if constantType is None else _get_type_attr( + constantType) result = pdl.TypeType.get() - super().__init__(result, type=type, loc=loc, ip=ip) + super().__init__(result, constantType=constantType, loc=loc, ip=ip) class TypesOp: """Specialization for PDL types op class.""" def __init__(self, - types: Sequence[Union[TypeAttr, Type]] = [], + constantTypes: Sequence[Union[TypeAttr, Type]] = [], *, loc=None, ip=None): - types = _get_array_attr([_get_type_attr(ty) for ty in types]) - types = None if not types else types + constantTypes = _get_array_attr( + [_get_type_attr(ty) for ty in constantTypes]) + constantTypes = None if not constantTypes else constantTypes result = 
pdl.RangeType.get(pdl.TypeType.get()) - super().__init__(result, types=types, loc=loc, ip=ip) + super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip) From cea5732af3357d453c10f5cd08c6d759b0ac7a93 Mon Sep 17 00:00:00 2001 From: Ashay Rane Date: Wed, 28 Sep 2022 14:53:36 +0000 Subject: [PATCH 351/915] [mlir][python] stop initialization on ImportError An `_mlirRegisterEverything.*.so` file from an old build that referenced `MLIRPythonExtension.RegisterEverything`, but which no longer references that extension in a new build, causes runtime errors in the new build like: ImportError: _mlirRegisterEverything.cpython-38-x86_64-linux-gnu.so: undefined symbol: mlirRegisterAllPasses The error occurs because the MLIR Python binding tries to dynamically import the `_mlirRegisterEverything` module but the dynamic importer fails since the new build no longer references `MLIRPythonExtension.RegisterEverything`. One possible solution is for the user to manually remove the `_mlirRegisterEverything.*.so` file. This patch instead resolves the problem in code by printing a waning if the module cannot be imported. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D133450 --- mlir/python/mlir/_mlir_libs/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index add8d92ee..b140ad64e 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -62,6 +62,11 @@ def process_initializer_module(module_name): m = importlib.import_module(f".{module_name}", __name__) except ModuleNotFoundError: return False + except ImportError: + message = (f"Error importing mlir initializer {module_name}. 
This may " + "happen in unclean incremental builds but is likely a real bug if " + "encountered otherwise and the MLIR Python API may not function.") + logging.warning(message, exc_info=True) logging.debug("Initializing MLIR with module: %s", module_name) if hasattr(m, "register_dialects"): From a8f29c4cec13a9ee44e41fb5dff323531f191aaf Mon Sep 17 00:00:00 2001 From: Denys Shabalin Date: Wed, 28 Sep 2022 13:40:31 +0000 Subject: [PATCH 352/915] [mlir] Add C bindings for StridedArrayAttr Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D134808 --- mlir/include/mlir-c/BuiltinAttributes.h | 24 ++++++++++++++++++++++ mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 27 +++++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index c75db95b4..b2e32f6d5 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -535,6 +535,30 @@ mlirSparseElementsAttrGetIndices(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr); +//===----------------------------------------------------------------------===// +// Strided layout attribute. +//===----------------------------------------------------------------------===// + +// Checks wheather the given attribute is a strided layout attribute. +MLIR_CAPI_EXPORTED bool mlirAttributeIsAStridedLayout(MlirAttribute attr); + +// Creates a strided layout attribute from given strides and offset. +MLIR_CAPI_EXPORTED MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, + int64_t offset, + intptr_t numStrides, + int64_t *strides); + +// Returns the offset in the given strided layout layout attribute. +MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr); + +// Returns the number of strides in the given strided layout attribute. 
+MLIR_CAPI_EXPORTED intptr_t +mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr); + +// Returns the pos-th stride stored in the given strided layout attribute. +MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, + intptr_t pos); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index b02484ab9..1ae2a2b54 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -722,3 +722,30 @@ MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) { MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) { return wrap(unwrap(attr).cast().getValues()); } + +//===----------------------------------------------------------------------===// +// Strided layout attribute. +//===----------------------------------------------------------------------===// + +bool mlirAttributeIsAStridedLayout(MlirAttribute attr) { + return unwrap(attr).isa(); +} + +MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, + intptr_t numStrides, int64_t *strides) { + return wrap(StridedLayoutAttr::get(unwrap(ctx), offset, + ArrayRef(strides, numStrides))); +} + +int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr) { + return unwrap(attr).cast().getOffset(); +} + +intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr) { + return static_cast( + unwrap(attr).cast().getStrides().size()); +} + +int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) { + return unwrap(attr).cast().getStrides()[pos]; +} From 2aa2f2c4f36249de2bc90c36c5b7db6dc3f62fda Mon Sep 17 00:00:00 2001 From: Denys Shabalin Date: Thu, 29 Sep 2022 09:41:42 +0000 Subject: [PATCH 353/915] [mlir] Add Python bindings for StridedLayoutAttr Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D134869 --- mlir/include/mlir-c/BuiltinAttributes.h | 7 ++-- mlir/lib/Bindings/Python/IRAttributes.cpp | 41 +++++++++++++++++++ 
mlir/lib/Bindings/Python/IRTypes.cpp | 4 +- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 3 +- .../dialects/_structured_transform_ops_ext.py | 4 +- 5 files changed, 50 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index b2e32f6d5..79f22376e 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -543,10 +543,9 @@ mlirSparseElementsAttrGetValues(MlirAttribute attr); MLIR_CAPI_EXPORTED bool mlirAttributeIsAStridedLayout(MlirAttribute attr); // Creates a strided layout attribute from given strides and offset. -MLIR_CAPI_EXPORTED MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, - int64_t offset, - intptr_t numStrides, - int64_t *strides); +MLIR_CAPI_EXPORTED MlirAttribute +mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, intptr_t numStrides, + const int64_t *strides); // Returns the offset in the given strided layout layout attribute. MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr); diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 8d8cea395..e62f1550c 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -1031,6 +1031,45 @@ class PyUnitAttribute : public PyConcreteAttribute { } }; +/// Strided layout attribute subclass. 
+class PyStridedLayoutAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout; + static constexpr const char *pyClassName = "StridedLayoutAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](int64_t offset, const std::vector strides, + DefaultingPyMlirContext ctx) { + MlirAttribute attr = mlirStridedLayoutAttrGet( + ctx->get(), offset, strides.size(), strides.data()); + return PyStridedLayoutAttribute(ctx->getRef(), attr); + }, + py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(), + "Gets a strided layout attribute."); + c.def_property_readonly( + "offset", + [](PyStridedLayoutAttribute &self) { + return mlirStridedLayoutAttrGetOffset(self); + }, + "Returns the offset of the strided layout attribute"); + c.def_property_readonly( + "strides", + [](PyStridedLayoutAttribute &self) { + intptr_t size = mlirStridedLayoutAttrGetNumStrides(self); + std::vector strides(size); + for (intptr_t i = 0; i < size; i++) { + strides[i] = mlirStridedLayoutAttrGetStride(self, i); + } + return strides; + }, + "Returns the strides of the strided layout attribute"); + } +}; + } // namespace void mlir::python::populateIRAttributes(py::module &m) { @@ -1065,4 +1104,6 @@ void mlir::python::populateIRAttributes(py::module &m) { PyStringAttribute::bind(m); PyTypeAttribute::bind(m); PyUnitAttribute::bind(m); + + PyStridedLayoutAttribute::bind(m); } diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 153664d07..379510ce9 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -302,11 +302,11 @@ class PyShapedType : public PyConcreteType { }, "Returns the shape of the ranked shaped type as a list of integers."); c.def_static( - "_get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); }, + "get_dynamic_size", []() { return 
mlirShapedTypeGetDynamicSize(); }, "Returns the value used to indicate dynamic dimensions in shaped " "types."); c.def_static( - "_get_dynamic_stride_or_offset", + "get_dynamic_stride_or_offset", []() { return mlirShapedTypeGetDynamicStrideOrOffset(); }, "Returns the value used to indicate dynamic strides or offsets in " "shaped types."); diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 1ae2a2b54..05ecb0fe8 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -732,7 +732,8 @@ bool mlirAttributeIsAStridedLayout(MlirAttribute attr) { } MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, - intptr_t numStrides, int64_t *strides) { + intptr_t numStrides, + const int64_t *strides) { return wrap(StridedLayoutAttr::get(unwrap(ctx), offset, ArrayRef(strides, numStrides))); } diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index eddc384f6..527a8656f 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -211,7 +211,7 @@ def __init__(self, static_split_point = split_point dynamic_split_point = None else: - static_split_point = _get_int64_attr(ShapedType._get_dynamic_size()) + static_split_point = _get_int64_attr(ShapedType.get_dynamic_size()) dynamic_split_point = _get_op_result_or_value(split_point) pdl_operation_type = pdl.OperationType.get() @@ -255,7 +255,7 @@ def __init__(self, static_sizes.append(size) else: static_sizes.append( - IntegerAttr.get(i64_type, ShapedType._get_dynamic_size())) + IntegerAttr.get(i64_type, ShapedType.get_dynamic_size())) dynamic_sizes.append(_get_op_result_or_value(size)) sizes_attr = ArrayAttr.get(static_sizes) From 806ffe8fbd2aad2521f6fd697509fca7c5d99af5 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Thu, 29 Sep 2022 11:14:47 -0400 Subject: [PATCH 
354/915] [mlir][arith] Change dialect name from Arithmetic to Arith Suggested by @lattner in https://discourse.llvm.org/t/rfc-define-precise-arith-semantics/65507/22. Tested with: `ninja check-mlir check-mlir-integration check-mlir-mlir-spirv-cpu-runner check-mlir-mlir-vulkan-runner check-mlir-examples` and `bazel build --config=generic_clang @llvm-project//mlir:all`. Reviewed By: lattner, Mogball, rriddle, jpienaar, mehdi_amini Differential Revision: https://reviews.llvm.org/D134762 --- mlir/python/CMakeLists.txt | 2 +- .../mlir/dialects/{ArithmeticOps.td => ArithOps.td} | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) rename mlir/python/mlir/dialects/{ArithmeticOps.td => ArithOps.td} (61%) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 7eb6e05e4..fe28cb44b 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -153,7 +153,7 @@ declare_mlir_dialect_python_bindings( declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" - TD_FILE dialects/ArithmeticOps.td + TD_FILE dialects/ArithOps.td SOURCES dialects/arith.py dialects/_arith_ops_ext.py diff --git a/mlir/python/mlir/dialects/ArithmeticOps.td b/mlir/python/mlir/dialects/ArithOps.td similarity index 61% rename from mlir/python/mlir/dialects/ArithmeticOps.td rename to mlir/python/mlir/dialects/ArithOps.td index d14b24a09..aaa9fad21 100644 --- a/mlir/python/mlir/dialects/ArithmeticOps.td +++ b/mlir/python/mlir/dialects/ArithOps.td @@ -1,4 +1,4 @@ -//===-- ArithmeticOps.td - Entry point for ArithmeticOps bindings ---------===// +//===-- ArithOps.td - Entry point for ArithOps bindings ---------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -6,10 +6,10 @@ // //===----------------------------------------------------------------------===// -#ifndef PYTHON_BINDINGS_ARITHMETIC_OPS -#define PYTHON_BINDINGS_ARITHMETIC_OPS +#ifndef PYTHON_BINDINGS_ARITH_OPS +#define PYTHON_BINDINGS_ARITH_OPS include "mlir/Bindings/Python/Attributes.td" -include "mlir/Dialect/Arithmetic/IR/ArithmeticOps.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" #endif From cee073e4c259299309bf31cd868814ecb6958b28 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 26 Jul 2022 19:02:37 -0700 Subject: [PATCH 355/915] Add APFloat and MLIR type support for fp8 (e5m2). This is a first step towards high level representation for fp8 types that have been built in to hardware with near term roadmaps. Like the BFLOAT16 type, the family of fp8 types are inspired by IEEE-754 binary floating point formats but, due to the size limits, have been tweaked in various ways in order to maximally use the range/precision in various scenarios. The list of variants is small/finite and bounded by real hardware. This patch introduces the E5M2 FP8 format as proposed by Nvidia, ARM, and Intel in the paper: https://arxiv.org/pdf/2209.05433.pdf As the more conformant of the two implemented datatypes, we are plumbing it through LLVM's APFloat type and MLIR's type system first as a template. It will be followed by the range optimized E4M3 FP8 format described in the paper. Since that format deviates further from the IEEE-754 norms, it may require more debate and implementation complexity. Given that we see two parts of the FP8 implementation space represented by these cases, we are recommending naming of: * `F8M` : For FP8 types that can be conceived of as following the same rules as FP16 but with a smaller number of mantissa/exponent bits. Including the number of mantissa bits in the type name is enough to fully specify the type. This naming scheme is used to represent the E5M2 type described in the paper. 
* `F8MF` : For FP8 types such as E4M3 which only support finite values. The first of these (this patch) seems fairly non-controversial. The second is previewed here to illustrate options for extending to the other known variant (but can be discussed in detail in the patch which implements it). Many conversations about these types focus on the Machine-Learning ecosystem where they are used to represent mixed-datatype computations at a high level. At that level (which is why we also expose them in MLIR), it is important to retain the actual type definition so that when lowering to actual kernels or target specific code, the correct promotions, casts and rescalings can be done as needed. We expect that most LLVM backends will only experience these types as opaque `I8` values that are applicable to some instructions. MLIR does not make it particularly easy to add new floating point types (i.e. the FloatType hierarchy is not open). Given the need to fully model FloatTypes and make them interop with tooling, such types will always be "heavy-weight" and it is not expected that a highly open type system will be particularly helpful. There are also a bounded number of floating point types in use for current and upcoming hardware, and we can just implement them like this (perhaps looking for some cosmetic ways to reduce the number of places that need to change). Creating a more generic mechanism for extending floating point types seems like it wouldn't be worth it and we should just deal with defining them one by one on an as-needed basis when real hardware implements a new scheme. Hopefully, with some additional production use and complete software stacks, hardware makers will converge on a set of such types that is not terribly divergent at the level that the compiler cares about. 
(I cleaned up some old formatting and sorted some items for this case: If we converge on landing this in some form, I will NFC commit format only changes as a separate commit) Differential Revision: https://reviews.llvm.org/D133823 --- mlir/include/mlir-c/BuiltinTypes.h | 7 +++++++ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index d1083f932..9bd3d510b 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -67,6 +67,13 @@ MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx); // Floating-point types. //===----------------------------------------------------------------------===// +/// Checks whether the given type is an f8E5M2 type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type); + +/// Creates an f8E5M2 type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx); + /// Checks whether the given type is a bf16 type. MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index be44b76e8..ad9a5bc66 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -68,6 +68,14 @@ MlirType mlirIndexTypeGet(MlirContext ctx) { // Floating-point types. 
//===----------------------------------------------------------------------===// +bool mlirTypeIsAFloat8E5M2(MlirType type) { + return unwrap(type).isFloat8E5M2(); +} + +MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) { + return wrap(FloatType::getFloat8E5M2(unwrap(ctx))); +} + bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); } MlirType mlirBF16TypeGet(MlirContext ctx) { From 1ad4ef4deedec9aa3a58b904fcb35323ab456a90 Mon Sep 17 00:00:00 2001 From: Vitaly Buka Date: Sun, 2 Oct 2022 21:21:51 -0700 Subject: [PATCH 356/915] Revert "Add APFloat and MLIR type support for fp8 (e5m2)." Breaks bots https://lab.llvm.org/buildbot/#/builders/37/builds/17086 This reverts commit cee073e4c259299309bf31cd868814ecb6958b28. --- mlir/include/mlir-c/BuiltinTypes.h | 7 ------- mlir/lib/CAPI/IR/BuiltinTypes.cpp | 8 -------- 2 files changed, 15 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index 9bd3d510b..d1083f932 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -67,13 +67,6 @@ MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx); // Floating-point types. //===----------------------------------------------------------------------===// -/// Checks whether the given type is an f8E5M2 type. -MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type); - -/// Creates an f8E5M2 type in the given context. The type is owned by the -/// context. -MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx); - /// Checks whether the given type is a bf16 type. MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index ad9a5bc66..be44b76e8 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -68,14 +68,6 @@ MlirType mlirIndexTypeGet(MlirContext ctx) { // Floating-point types. 
//===----------------------------------------------------------------------===// -bool mlirTypeIsAFloat8E5M2(MlirType type) { - return unwrap(type).isFloat8E5M2(); -} - -MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) { - return wrap(FloatType::getFloat8E5M2(unwrap(ctx))); -} - bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); } MlirType mlirBF16TypeGet(MlirContext ctx) { From 714f7a91fcd8539084f2dffbb7f72609dcfaf798 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Mon, 3 Oct 2022 09:38:17 -0700 Subject: [PATCH 357/915] [mlir][c] Init MLProgram C API Add MLIR upstream C api library definition. Differential Revision: https://reviews.llvm.org/D135083 --- mlir/include/mlir-c/Dialect/MLProgram.h | 25 +++++++++++++++++++++++++ mlir/lib/CAPI/Dialect/CMakeLists.txt | 9 +++++++++ mlir/lib/CAPI/Dialect/MLProgram.cpp | 14 ++++++++++++++ 3 files changed, 48 insertions(+) create mode 100644 mlir/include/mlir-c/Dialect/MLProgram.h create mode 100644 mlir/lib/CAPI/Dialect/MLProgram.cpp diff --git a/mlir/include/mlir-c/Dialect/MLProgram.h b/mlir/include/mlir-c/Dialect/MLProgram.h new file mode 100644 index 000000000..0874955e3 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/MLProgram.h @@ -0,0 +1,25 @@ +//===-- mlir-c/Dialect/MLProgram.h - C API for MLProgram dialect --*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_MLPROGRAM_H +#define MLIR_C_DIALECT_MLPROGRAM_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(MLProgram, ml_program); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_MLPROGRAM_H diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index e5173ffd3..2f36040a2 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -60,6 +60,15 @@ add_mlir_upstream_c_api_library(MLIRCAPILinalg MLIRLinalgTransforms ) +add_mlir_upstream_c_api_library(MLIRCAPIMLProgram + MLProgram.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRMLProgramDialect +) + add_mlir_upstream_c_api_library(MLIRCAPISCF SCF.cpp diff --git a/mlir/lib/CAPI/Dialect/MLProgram.cpp b/mlir/lib/CAPI/Dialect/MLProgram.cpp new file mode 100644 index 000000000..525b958d9 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/MLProgram.cpp @@ -0,0 +1,14 @@ +//===- MLProgram.cpp - C Interface for MLProgram dialect ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MLProgram/IR/MLProgram.h" +#include "mlir-c/Dialect/MLProgram.h" +#include "mlir/CAPI/Registration.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(MLProgram, ml_program, + mlir::ml_program::MLProgramDialect) From feb5c1bc8c8e43ba8ba26b4c79334ebd06641b07 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 4 Oct 2022 17:06:00 +0900 Subject: [PATCH 358/915] [mlir][tensor][NFC] Rename linalg.init_tensor to tensor.empty tensor.empty/linalg.init_tensor produces an uninititalized tensor that can be used as a destination operand for destination-style ops (ops that implement `DestinationStyleOpInterface`). This change makes it possible to implement `TilingInterface` for non-destination-style ops without depending on the Linalg dialect. RFC: https://discourse.llvm.org/t/rfc-add-tensor-from-shape-operation/65101 Differential Revision: https://reviews.llvm.org/D135129 --- mlir/python/CMakeLists.txt | 4 +- mlir/python/mlir/dialects/_linalg_ops_ext.py | 38 ------------------ mlir/python/mlir/dialects/_tensor_ops_ext.py | 42 ++++++++++++++++++++ 3 files changed, 45 insertions(+), 39 deletions(-) create mode 100644 mlir/python/mlir/dialects/_tensor_ops_ext.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index fe28cb44b..ecff102fe 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -222,7 +222,9 @@ declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/TensorOps.td - SOURCES dialects/tensor.py + SOURCES + dialects/tensor.py + dialects/_tensor_ops_ext.py DIALECT_NAME tensor) declare_mlir_dialect_python_bindings( diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py index e3fb46055..eb9e969f3 100644 --- 
a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -20,44 +20,6 @@ def isa(cls: Type, ty: Type): return False -class InitTensorOp: - """Extends the linalg.init_tensor op.""" - - def __init__(self, - sizes: Union[Sequence[int], Sequence[Value]], - element_type: Type, - *, - loc=None, - ip=None): - """Constructs an `init_tensor` with either static or dynamic sizes.""" - context = get_default_loc_context(loc) - operands = [] - attributes = {} - # TODO: Refactor the InitTensorOp to take an element type attribute and - # then use normal result type inference, unifying the Python and C++ side - # with a standard mechanism (versus stashing that in builders). - if sizes and isinstance(sizes[0], Value): - # Dynamic sizes. - operands.extend(sizes) - static_size_ints = [-1] * len(sizes) - result_type = RankedTensorType.get(static_size_ints, element_type) - else: - # Static sizes. - result_type = RankedTensorType.get(sizes, element_type) - static_size_ints = sizes - - i64_type = IntegerType.get_signless(64) - attributes["static_sizes"] = ArrayAttr.get( - [IntegerAttr.get(i64_type, s) for s in static_size_ints], - context=context) - op = self.build_generic(results=[result_type], - operands=operands, - attributes=attributes, - loc=loc, - ip=ip) - OpView.__init__(self, op) - - class StructuredOpMixin: """All structured ops use the same mixin class.""" diff --git a/mlir/python/mlir/dialects/_tensor_ops_ext.py b/mlir/python/mlir/dialects/_tensor_ops_ext.py new file mode 100644 index 000000000..0f1b26603 --- /dev/null +++ b/mlir/python/mlir/dialects/_tensor_ops_ext.py @@ -0,0 +1,42 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Any, Optional, Sequence, Union +from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values + + +class EmptyOp: + """Extends the tensor.empty op.""" + + def __init__(self, + sizes: Sequence[Union[int, Value]], + element_type: Type, + *, + loc=None, + ip=None): + """Constructs an `empty` with mixed static/dynamic sizes.""" + # TODO: Refactor the EmptyOp to take an element type attribute and + # then use normal result type inference, unifying the Python and C++ side + # with a standard mechanism (versus stashing that in builders). + dynamic_sizes = [] + static_sizes = [] + for s in sizes: + if isinstance(s, int): + static_sizes.append(s) + else: + static_sizes.append(-1) + dynamic_sizes.append(s) + result_type = RankedTensorType.get(static_sizes, element_type) + op = self.build_generic( + results=[result_type], + operands=dynamic_sizes, + attributes={}, + loc=loc, + ip=ip) + OpView.__init__(self, op) From 67cfdad8d643af127c68948b3074d82d02552ddb Mon Sep 17 00:00:00 2001 From: Denys Shabalin Date: Tue, 4 Oct 2022 10:58:38 +0000 Subject: [PATCH 359/915] [mlir] Add fully dynamic constructor to StridedLayoutAttr bindings Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D135139 --- mlir/lib/Bindings/Python/IRAttributes.cpp | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index e62f1550c..0c8c9b8ba 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -1050,6 +1050,19 @@ class PyStridedLayoutAttribute }, py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(), "Gets a strided layout attribute."); + c.def_static( 
+ "get_fully_dynamic", + [](int64_t rank, DefaultingPyMlirContext ctx) { + auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset(); + std::vector strides(rank); + std::fill(strides.begin(), strides.end(), dynamic); + MlirAttribute attr = mlirStridedLayoutAttrGet( + ctx->get(), dynamic, strides.size(), strides.data()); + return PyStridedLayoutAttribute(ctx->getRef(), attr); + }, + py::arg("rank"), py::arg("context") = py::none(), + "Gets a strided layout attribute with dynamic offset and strides of a " + "given rank."); c.def_property_readonly( "offset", [](PyStridedLayoutAttribute &self) { From 3c23d497387fa1f1cabe990c4ddfe7f2742e2d4b Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 26 Jul 2022 19:02:37 -0700 Subject: [PATCH 360/915] Add APFloat and MLIR type support for fp8 (e5m2). (Re-Apply with fixes to clang MicrosoftMangle.cpp) This is a first step towards high level representation for fp8 types that have been built in to hardware with near term roadmaps. Like the BFLOAT16 type, the family of fp8 types are inspired by IEEE-754 binary floating point formats but, due to the size limits, have been tweaked in various ways in order to maximally use the range/precision in various scenarios. The list of variants is small/finite and bounded by real hardware. This patch introduces the E5M2 FP8 format as proposed by Nvidia, ARM, and Intel in the paper: https://arxiv.org/pdf/2209.05433.pdf As the more conformant of the two implemented datatypes, we are plumbing it through LLVM's APFloat type and MLIR's type system first as a template. It will be followed by the range optimized E4M3 FP8 format described in the paper. Since that format deviates further from the IEEE-754 norms, it may require more debate and implementation complexity. 
Given that we see two parts of the FP8 implementation space represented by these cases, we are recommending naming of: * `F8M` : For FP8 types that can be conceived of as following the same rules as FP16 but with a smaller number of mantissa/exponent bits. Including the number of mantissa bits in the type name is enough to fully specify the type. This naming scheme is used to represent the E5M2 type described in the paper. * `F8MF` : For FP8 types such as E4M3 which only support finite values. The first of these (this patch) seems fairly non-controversial. The second is previewed here to illustrate options for extending to the other known variant (but can be discussed in detail in the patch which implements it). Many conversations about these types focus on the Machine-Learning ecosystem where they are used to represent mixed-datatype computations at a high level. At that level (which is why we also expose them in MLIR), it is important to retain the actual type definition so that when lowering to actual kernels or target specific code, the correct promotions, casts and rescalings can be done as needed. We expect that most LLVM backends will only experience these types as opaque `I8` values that are applicable to some instructions. MLIR does not make it particularly easy to add new floating point types (i.e. the FloatType hierarchy is not open). Given the need to fully model FloatTypes and make them interop with tooling, such types will always be "heavy-weight" and it is not expected that a highly open type system will be particularly helpful. There are also a bounded number of floating point types in use for current and upcoming hardware, and we can just implement them like this (perhaps looking for some cosmetic ways to reduce the number of places that need to change). 
Creating a more generic mechanism for extending floating point types seems like it wouldn't be worth it and we should just deal with defining them one by one on an as-needed basis when real hardware implements a new scheme. Hopefully, with some additional production use and complete software stacks, hardware makers will converge on a set of such types that is not terribly divergent at the level that the compiler cares about. (I cleaned up some old formatting and sorted some items for this case: If we converge on landing this in some form, I will NFC commit format only changes as a separate commit) Differential Revision: https://reviews.llvm.org/D133823 --- mlir/include/mlir-c/BuiltinTypes.h | 7 +++++++ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index d1083f932..9bd3d510b 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -67,6 +67,13 @@ MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx); // Floating-point types. //===----------------------------------------------------------------------===// +/// Checks whether the given type is an f8E5M2 type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type); + +/// Creates an f8E5M2 type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx); + /// Checks whether the given type is a bf16 type. MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index be44b76e8..ad9a5bc66 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -68,6 +68,14 @@ MlirType mlirIndexTypeGet(MlirContext ctx) { // Floating-point types. 
//===----------------------------------------------------------------------===// +bool mlirTypeIsAFloat8E5M2(MlirType type) { + return unwrap(type).isFloat8E5M2(); +} + +MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) { + return wrap(FloatType::getFloat8E5M2(unwrap(ctx))); +} + bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); } MlirType mlirBF16TypeGet(MlirContext ctx) { From 6e1e025594dc9e81f13aa58f10c5808a7053248f Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Tue, 4 Oct 2022 14:34:37 -0700 Subject: [PATCH 361/915] [mlir][sparse] introduce a higher-order tensor mapping This extension to the sparse tensor type system in MLIR opens up a whole new set of sparse storage schemes, such as block sparse storage (e.g. BCSR) and ELL (aka jagged diagonals). This revision merely introduces the type extension and initial documentation. The actual interpretation of the type (reading in tensors, lowering to code, etc.) will follow. Reviewed By: Peiming Differential Revision: https://reviews.llvm.org/D135206 --- mlir/include/mlir-c/Dialect/SparseTensor.h | 7 ++++++- .../lib/Bindings/Python/DialectSparseTensor.cpp | 17 ++++++++++++++--- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 15 +++++++++++---- 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index 9465f36c3..765e9c293 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -49,7 +49,8 @@ mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet( MlirContext ctx, intptr_t numDimLevelTypes, enum MlirSparseTensorDimLevelType const *dimLevelTypes, - MlirAffineMap dimOrdering, int pointerBitWidth, int indexBitWidth); + MlirAffineMap dimOrdering, MlirAffineMap higherOrdering, + int pointerBitWidth, int indexBitWidth); /// Returns the number of dim level types in a sparse_tensor.encoding attribute. 
MLIR_CAPI_EXPORTED intptr_t @@ -63,6 +64,10 @@ mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED MlirAffineMap mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr); +/// Returns the higher ordering in a sparse_tensor.encoding attribute. +MLIR_CAPI_EXPORTED MlirAffineMap +mlirSparseTensorEncodingAttrGetHigherOrdering(MlirAttribute attr); + /// Returns the pointer bit width in a sparse_tensor.encoding attribute. MLIR_CAPI_EXPORTED int mlirSparseTensorEncodingAttrGetPointerBitWidth(MlirAttribute attr); diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index ae9cfbb67..af47ac8df 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -33,16 +33,18 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { "get", [](py::object cls, std::vector dimLevelTypes, - llvm::Optional dimOrdering, int pointerBitWidth, + llvm::Optional dimOrdering, + llvm::Optional higherOrdering, int pointerBitWidth, int indexBitWidth, MlirContext context) { return cls(mlirSparseTensorEncodingAttrGet( context, dimLevelTypes.size(), dimLevelTypes.data(), dimOrdering ? *dimOrdering : MlirAffineMap{nullptr}, + higherOrdering ? 
*higherOrdering : MlirAffineMap{nullptr}, pointerBitWidth, indexBitWidth)); }, py::arg("cls"), py::arg("dim_level_types"), py::arg("dim_ordering"), - py::arg("pointer_bit_width"), py::arg("index_bit_width"), - py::arg("context") = py::none(), + py::arg("higher_ordering"), py::arg("pointer_bit_width"), + py::arg("index_bit_width"), py::arg("context") = py::none(), "Gets a sparse_tensor.encoding from parameters.") .def_property_readonly( "dim_level_types", @@ -64,6 +66,15 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { return {}; return ret; }) + .def_property_readonly( + "higher_ordering", + [](MlirAttribute self) -> llvm::Optional { + MlirAffineMap ret = + mlirSparseTensorEncodingAttrGetHigherOrdering(self); + if (mlirAffineMapIsNull(ret)) + return {}; + return ret; + }) .def_property_readonly( "pointer_bit_width", [](MlirAttribute self) { diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index e3e32900c..f0348cecb 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -56,21 +56,28 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { MlirAttribute mlirSparseTensorEncodingAttrGet( MlirContext ctx, intptr_t numDimLevelTypes, MlirSparseTensorDimLevelType const *dimLevelTypes, - MlirAffineMap dimOrdering, int pointerBitWidth, int indexBitWidth) { + MlirAffineMap dimOrdering, MlirAffineMap higherOrdering, + int pointerBitWidth, int indexBitWidth) { SmallVector cppDimLevelTypes; cppDimLevelTypes.resize(numDimLevelTypes); for (intptr_t i = 0; i < numDimLevelTypes; ++i) cppDimLevelTypes[i] = static_cast(dimLevelTypes[i]); - return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppDimLevelTypes, - unwrap(dimOrdering), - pointerBitWidth, indexBitWidth)); + return wrap(SparseTensorEncodingAttr::get( + unwrap(ctx), cppDimLevelTypes, unwrap(dimOrdering), + unwrap(higherOrdering), pointerBitWidth, indexBitWidth)); } MlirAffineMap 
mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr) { return wrap(unwrap(attr).cast().getDimOrdering()); } +MlirAffineMap +mlirSparseTensorEncodingAttrGetHigherOrdering(MlirAttribute attr) { + return wrap( + unwrap(attr).cast().getHigherOrdering()); +} + intptr_t mlirSparseTensorEncodingGetNumDimLevelTypes(MlirAttribute attr) { return unwrap(attr).cast().getDimLevelType().size(); } From 0185bd1fb8892dc9018525ceaf163b400b0bff35 Mon Sep 17 00:00:00 2001 From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Wed, 5 Oct 2022 16:23:14 -0700 Subject: [PATCH 362/915] [mlir][sparse] Adjusting DimLevelType numeric values for faster predicates This differential adjusts the numeric values for DimLevelType values: using the low-order two bits for recording the "No" and "Nu" properties, and the high-order bits for the formats per se. (The choice of encoding may seem a bit peculiar, since the bits are mapped to negative properties rather than positive properties. But this was done in order to preserve the collation order of DimLevelType values. If we don't care about collation order, then we may prefer to flip the semantics of the property bits, so that they're less surprising to readers.) Using distinguished bits for the properties and formats enables faster implementation for the predicates detecting those properties/formats, which matters because this is in the runtime library itself (rather than on the codegen side of things). This differential pushes through the changes to the enum values, and optimizes the basic predicates. However it does not optimize all the places where we check compound predicates (e.g., "is compressed or singleton"), to help reduce rebasing conflict with D134933. Those optimizations will be done after this differential and D134933 are landed. 
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D135004 --- mlir/include/mlir-c/Dialect/SparseTensor.h | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index 765e9c293..8027f319b 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -26,15 +26,15 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor); /// If updating, keep them in sync and update the static_assert in the impl /// file. enum MlirSparseTensorDimLevelType { - MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE, - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED, - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU, - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO, - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO, - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON, - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU, - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO, - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO, + MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE = 4, // 0b001_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED = 8, // 0b010_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU = 9, // 0b010_01 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO = 10, // 0b010_10 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO = 11, // 0b010_11 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON = 16, // 0b100_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU = 17, // 0b100_01 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO = 18, // 0b100_10 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO = 19, // 0b100_11 }; //===----------------------------------------------------------------------===// From 459e3d063208619d2d544d7d6764bc9a047326f5 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Thu, 6 Oct 2022 18:21:47 +0000 Subject: [PATCH 363/915] Apply clang-tidy fixes for performance-unnecessary-value-param in IRCore.cpp (NFC) --- mlir/lib/Bindings/Python/IRCore.cpp | 2 +- mlir/lib/Bindings/Python/IRModule.h | 2 
+- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 389969baa..f706951e6 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1029,7 +1029,7 @@ void PyOperationBase::print(py::object fileObject, bool binary, mlirOpPrintingFlagsDestroy(flags); } -void PyOperationBase::writeBytecode(py::object fileObject) { +void PyOperationBase::writeBytecode(const py::object &fileObject) { PyOperation &operation = getOperation(); operation.checkValid(); PyFileAccumulator accum(fileObject, /*binary=*/true); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index ad783c6c3..4738a6fae 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -513,7 +513,7 @@ class PyOperationBase { bool assumeVerified); // Implement the bound 'writeBytecode' method. - void writeBytecode(pybind11::object fileObject); + void writeBytecode(const pybind11::object &fileObject); /// Moves the operation before or after the other operation. void moveAfter(PyOperationBase &other); From d01c9edd5308aba722e7bcbc2ab6b08aea31d20f Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Wed, 5 Oct 2022 14:23:19 +0000 Subject: [PATCH 364/915] [mlir] switch transform dialect ops to use TransformTypeInterface Use the recently introduced TransformTypeInterface instead of hardcoding the PDLOperationType. This will allow the operations to use more specific transform types to express pre/post-conditions in the future. It requires the syntax and Python op construction API to be updated. Dialect extensions will be switched separately. 
Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D135584 --- .../mlir/dialects/_transform_ops_ext.py | 46 ++++++++----------- 1 file changed, 18 insertions(+), 28 deletions(-) diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py index 992139f72..18cd3adb0 100644 --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -5,7 +5,6 @@ try: from ..ir import * from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values - from ..dialects import pdl except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e @@ -21,9 +20,9 @@ def _get_symbol_ref_attr(value: Union[Attribute, str]): class GetClosestIsolatedParentOp: - def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): + def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None): super().__init__( - pdl.OperationType.get(), + result_type, _get_op_result_or_value(target), loc=loc, ip=ip) @@ -38,7 +37,7 @@ def __init__(self, loc=None, ip=None): super().__init__( - pdl.OperationType.get(), [_get_op_result_or_value(h) for h in handles], + [_get_op_result_or_value(h) for h in handles], deduplicate=deduplicate, loc=loc, ip=ip) @@ -47,13 +46,14 @@ def __init__(self, class PDLMatchOp: def __init__(self, + result_type: Type, target: Union[Operation, Value], pattern_name: Union[Attribute, str], *, loc=None, ip=None): super().__init__( - pdl.OperationType.get(), + result_type, _get_op_result_or_value(target), _get_symbol_ref_attr(pattern_name), loc=loc, @@ -69,7 +69,7 @@ def __init__(self, loc=None, ip=None): super().__init__( - [pdl.OperationType.get()] * len(handles), + [_get_op_result_or_value(h).type for h in handles], _get_op_result_or_value(pattern), [_get_op_result_or_value(h) for h in handles], loc=loc, @@ -78,24 +78,11 @@ def 
__init__(self, class SequenceOp: - @overload - def __init__(self, failure_propagation_mode, - resultsOrRoot: Sequence[Type], - optionalRoot: Optional[Union[Operation, Value]]): - ... - - @overload - def __init__(self, failure_propagation_mode, - resultsOrRoot: Optional[Union[Operation, - Value]], optionalRoot: NoneType): - ... - - def __init__(self, failure_propagation_mode, resultsOrRoot=None, optionalRoot=None): - results = resultsOrRoot if isinstance(resultsOrRoot, Sequence) else [] - root = ( - resultsOrRoot - if not isinstance(resultsOrRoot, Sequence) else optionalRoot) - root = _get_op_result_or_value(root) if root else None + def __init__(self, failure_propagation_mode, results: Sequence[Type], + target: Union[Operation, Value, Type]): + root = _get_op_result_or_value(target) if isinstance( + target, (Operation, Value)) else None + root_type = root.type if not isinstance(target, Type) else target if not isinstance(failure_propagation_mode, Attribute): failure_propagation_mode_attr = IntegerAttr.get( IntegerType.get_signless(32), failure_propagation_mode._as_int()) @@ -104,7 +91,7 @@ def __init__(self, failure_propagation_mode, resultsOrRoot=None, optionalRoot=No super().__init__(results_=results, failure_propagation_mode=failure_propagation_mode_attr, root=root) - self.regions[0].blocks.append(pdl.OperationType.get()) + self.regions[0].blocks.append(root_type) @property def body(self) -> Block: @@ -118,15 +105,18 @@ def bodyTarget(self) -> Value: class WithPDLPatternsOp: def __init__(self, - target: Optional[Union[Operation, Value]] = None, + target: Union[Operation, Value, Type], *, loc=None, ip=None): + root = _get_op_result_or_value(target) if not isinstance(target, + Type) else None + root_type = target if isinstance(target, Type) else root.type super().__init__( - root=_get_op_result_or_value(target) if target else None, + root=root, loc=loc, ip=ip) - self.regions[0].blocks.append(pdl.OperationType.get()) + self.regions[0].blocks.append(root_type) 
@property def body(self) -> Block: From c1598cfe11ea56a551274c958212c0b40a8f0a43 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 10 Oct 2022 14:33:59 +0000 Subject: [PATCH 365/915] [mlir] add OperationType to the Transform dialect Add a new OperationType handle type to the Transform dialect. This transform type is parameterized by the name of the payload operation it can point to. It is intended as a constraint on transformations that are only applicable to a specific kind of payload operations. If a transformation is applicable to a small set of operation classes, it can be wrapped into a transform op by using a disjunctive constraint, such as `Type<Or<[Transform_ConcreteOperation<"foo">.predicate, Transform_ConcreteOperation<"bar">.predicate]>>` for its operand without modifying this type. Broader sets of accepted operations should be modeled as specific types. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D135586 --- mlir/include/mlir-c/Dialect/Transform.h | 46 +++++++++++++ mlir/lib/Bindings/Python/DialectTransform.cpp | 64 +++++++++++++++++++ mlir/lib/CAPI/Dialect/CMakeLists.txt | 9 +++ mlir/lib/CAPI/Dialect/Transform.cpp | 48 ++++++++++++++ mlir/python/CMakeLists.txt | 14 ++++ .../_mlir/dialects/transform/__init__.pyi | 26 ++++++++ .../mlir/dialects/_transform_ops_ext.py | 10 +++ .../mlir/dialects/transform/__init__.py | 1 + 8 files changed, 218 insertions(+) create mode 100644 mlir/include/mlir-c/Dialect/Transform.h create mode 100644 mlir/lib/Bindings/Python/DialectTransform.cpp create mode 100644 mlir/lib/CAPI/Dialect/Transform.cpp create mode 100644 mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h new file mode 100644 index 000000000..864dffa3f --- /dev/null +++ b/mlir/include/mlir-c/Dialect/Transform.h @@ -0,0 +1,46 @@ +//===-- mlir-c/Dialect/Transform.h - C API for Transform Dialect --*- C -*-===// +// +// Part of the LLVM Project, under the 
Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_TRANSFORM_H +#define MLIR_C_DIALECT_TRANSFORM_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform); + +//===---------------------------------------------------------------------===// +// AnyOpType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type); + +MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx); + +//===---------------------------------------------------------------------===// +// OperationType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsATransformOperationType(MlirType type); + +MLIR_CAPI_EXPORTED MlirType +mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName); + +MLIR_CAPI_EXPORTED MlirStringRef +mlirTransformOperationTypeGetOperationName(MlirType type); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_TRANSFORM_H diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp new file mode 100644 index 000000000..a9db2428c --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectTransform.cpp @@ -0,0 +1,64 @@ +//===- DialectTransform.cpp - 'transform' dialect submodule ---------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Transform.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" + +namespace py = pybind11; +using namespace mlir; +using namespace mlir::python; +using namespace mlir::python::adaptors; + +void populateDialectTransformSubmodule(const pybind11::module &m) { + //===-------------------------------------------------------------------===// + // AnyOpType + //===-------------------------------------------------------------------===// + + auto anyOpType = + mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType); + anyOpType.def_classmethod( + "get", + [](py::object cls, MlirContext ctx) { + return cls(mlirTransformAnyOpTypeGet(ctx)); + }, + "Get an instance of AnyOpType in the given context.", py::arg("cls"), + py::arg("context") = py::none()); + + //===-------------------------------------------------------------------===// + // OperationType + //===-------------------------------------------------------------------===// + + auto operationType = + mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType); + operationType.def_classmethod( + "get", + [](py::object cls, const std::string &operationName, MlirContext ctx) { + MlirStringRef cOperationName = + mlirStringRefCreate(operationName.data(), operationName.size()); + return cls(mlirTransformOperationTypeGet(ctx, cOperationName)); + }, + "Get an instance of OperationType for the given kind in the given " + "context", + py::arg("cls"), py::arg("operation_name"), + py::arg("context") = py::none()); + operationType.def_property_readonly( + "operation_name", + [](MlirType type) { + MlirStringRef operationName = + mlirTransformOperationTypeGetOperationName(type); + return py::str(operationName.data, operationName.length); + }, + "Get the name of the payload operation accepted 
by the handle."); +} + +PYBIND11_MODULE(_mlirDialectsTransform, m) { + m.doc() = "MLIR Transform dialect."; + populateDialectTransformSubmodule(m); +} diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index 2f36040a2..6c8454a79 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -116,6 +116,15 @@ add_mlir_upstream_c_api_library(MLIRCAPITensor MLIRTensorDialect ) +add_mlir_upstream_c_api_library(MLIRCAPITransformDialect + Transform.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRTransformDialect +) + add_mlir_upstream_c_api_library(MLIRCAPIQuant Quant.cpp diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp new file mode 100644 index 000000000..606b301cc --- /dev/null +++ b/mlir/lib/CAPI/Dialect/Transform.cpp @@ -0,0 +1,48 @@ +//===- Transform.cpp - C Interface for Transform dialect ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Transform.h" +#include "mlir-c/Support.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformTypes.h" + +using namespace mlir; + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Transform, transform, + transform::TransformDialect) + +//===---------------------------------------------------------------------===// +// AnyOpType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsATransformAnyOpType(MlirType type) { + return unwrap(type).isa(); +} + +MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) { + return wrap(transform::AnyOpType::get(unwrap(ctx))); +} + +//===---------------------------------------------------------------------===// +// OperationType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsATransformOperationType(MlirType type) { + return unwrap(type).isa(); +} + +MlirType mlirTransformOperationTypeGet(MlirContext ctx, + MlirStringRef operationName) { + return wrap( + transform::OperationType::get(unwrap(ctx), unwrap(operationName))); +} + +MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) { + return wrap(unwrap(type).cast().getOperationName()); +} diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index ecff102fe..0a4c2f803 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -121,6 +121,7 @@ declare_mlir_dialect_python_bindings( SOURCES dialects/_transform_ops_ext.py dialects/transform/__init__.py + _mlir_libs/_mlir/dialects/transform/__init__.pyi DIALECT_NAME transform) declare_mlir_dialect_extension_python_bindings( @@ -353,6 +354,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind MLIRCAPISparseTensor ) 
+declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind + MODULE_NAME _mlirDialectsTransform + ADD_TO_PARENT MLIRPythonSources.Dialects.transform + ROOT_DIR "${PYTHON_SOURCE_DIR}" + SOURCES + DialectTransform.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPITransformDialect +) + declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses MODULE_NAME _mlirAsyncPasses ADD_TO_PARENT MLIRPythonSources.Dialects.async_dialect diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi new file mode 100644 index 000000000..2a2954173 --- /dev/null +++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi @@ -0,0 +1,26 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Optional + +from mlir.ir import Type, Context + + +class AnyOpType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(context: Optional[Context] = None) -> AnyOpType: ... + + +class OperationType(Type): + @staticmethod + def isinstance(type: Type) -> bool: ... + + @staticmethod + def get(operation_name: str, context: Optional[Context] = None) -> OperationType: ... + + @property + def operation_name(self) -> str: ... 
diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py index 18cd3adb0..5cd57b050 100644 --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -18,6 +18,16 @@ def _get_symbol_ref_attr(value: Union[Attribute, str]): return FlatSymbolRefAttr.get(value) +class CastOp: + + def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None): + super().__init__( + result_type, + _get_op_result_or_value(target), + loc=loc, + ip=ip) + + class GetClosestIsolatedParentOp: def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None): diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index d4d71274c..78956c437 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -18,3 +18,4 @@ def _as_int(self): return 2 from .._transform_ops_gen import * +from ..._mlir_libs._mlirDialectsTransform import * From d8e02e72edbad5c4b14c08b99835a4608bca40f2 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 10 Oct 2022 14:38:31 +0000 Subject: [PATCH 366/915] [mlir] switch the transform loop extension to use types Add types to the Loop (SCF) extension of the transform dialect. 
See https://discourse.llvm.org/t/rfc-type-system-for-the-transform-dialect/65702 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D135587 --- .../python/mlir/dialects/_loop_transform_ops_ext.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py index 7452c4243..0dc8fc074 100644 --- a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py @@ -5,7 +5,6 @@ try: from ..ir import * from ._ods_common import get_op_result_or_value as _get_op_result_or_value - from ..dialects import pdl except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e @@ -28,13 +27,14 @@ class GetParentForOp: """Extension for GetParentForOp.""" def __init__(self, + result_type: Type, target: Union[Operation, Value], *, num_loops: int = 1, ip=None, loc=None): super().__init__( - pdl.OperationType.get(), + result_type, _get_op_result_or_value(target), num_loops=_get_int64_attr(num_loops, default_value=1), ip=ip, @@ -45,13 +45,14 @@ class LoopOutlineOp: """Extension for LoopOutlineOp.""" def __init__(self, + result_type: Type, target: Union[Operation, Value], *, func_name: Union[str, StringAttr], ip=None, loc=None): super().__init__( - pdl.OperationType.get(), + result_type, _get_op_result_or_value(target), func_name=(func_name if isinstance(func_name, StringAttr) else StringAttr.get(func_name)), @@ -63,13 +64,14 @@ class LoopPeelOp: """Extension for LoopPeelOp.""" def __init__(self, + result_type: Type, target: Union[Operation, Value], *, fail_if_already_divisible: Union[bool, BoolAttr] = False, ip=None, loc=None): super().__init__( - pdl.OperationType.get(), + result_type, _get_op_result_or_value(target), fail_if_already_divisible=(fail_if_already_divisible if isinstance( fail_if_already_divisible, BoolAttr) else @@ -82,6 +84,7 @@ class 
LoopPipelineOp: """Extension for LoopPipelineOp.""" def __init__(self, + result_type: Type, target: Union[Operation, Value], *, iteration_interval: Optional[Union[int, IntegerAttr]] = None, @@ -89,7 +92,7 @@ def __init__(self, ip=None, loc=None): super().__init__( - pdl.OperationType.get(), + result_type, _get_op_result_or_value(target), iteration_interval=_get_int64_attr(iteration_interval, default_value=1), read_latency=_get_int64_attr(read_latency, default_value=10), From 9b231a28f6fbf6f7c1c41c40fec3489eb4dcb544 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Fri, 14 Oct 2022 19:59:55 +0200 Subject: [PATCH 367/915] [mlir] Simplify DestinationStyleOpInterface. Differential Revision: https://reviews.llvm.org/D135348 --- mlir/lib/CAPI/Dialect/Linalg.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index bfb3313d1..2fb5bc651 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -29,9 +29,9 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) { SmallVector argTypes; SmallVector argLocs; - for (OpOperand *opOperand : linalgOp.getInputAndOutputOperands()) { - argTypes.push_back(getElementTypeOrSelf(opOperand->get().getType())); - argLocs.push_back(opOperand->get().getLoc()); + for (OpOperand &opOperand : linalgOp->getOpOperands()) { + argTypes.push_back(getElementTypeOrSelf(opOperand.get().getType())); + argLocs.push_back(opOperand.get().getLoc()); } ImplicitLocOpBuilder b(op->getLoc(), op->getContext()); From b1b6e9e2a4cafd51c92b49207fe347a7ca675d13 Mon Sep 17 00:00:00 2001 From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Fri, 14 Oct 2022 16:40:28 -0700 Subject: [PATCH 368/915] [mlir][sparse] Use the runtime DimLevelType instead of a separate tablegen enum This differential replaces all uses of SparseTensorEncodingAttr::DimLevelType with DimLevelType. 
The next differential will break out a separate library for the DimLevelType enum, so that the Dialect code doesn't need to depend on the rest of the runtime Depends On D135995 Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D135996 --- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 31 +++++++++----------------- 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index f0348cecb..b667ad3c6 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -22,31 +22,23 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor, // Ensure the C-API enums are int-castable to C++ equivalents. static_assert( static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) == - static_cast(SparseTensorEncodingAttr::DimLevelType::Dense) && + static_cast(DimLevelType::Dense) && static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) == - static_cast( - SparseTensorEncodingAttr::DimLevelType::Compressed) && + static_cast(DimLevelType::Compressed) && static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU) == - static_cast( - SparseTensorEncodingAttr::DimLevelType::CompressedNu) && + static_cast(DimLevelType::CompressedNu) && static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO) == - static_cast( - SparseTensorEncodingAttr::DimLevelType::CompressedNo) && + static_cast(DimLevelType::CompressedNo) && static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO) == - static_cast( - SparseTensorEncodingAttr::DimLevelType::CompressedNuNo) && + static_cast(DimLevelType::CompressedNuNo) && static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) == - static_cast( - SparseTensorEncodingAttr::DimLevelType::Singleton) && + static_cast(DimLevelType::Singleton) && static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU) == - static_cast( - SparseTensorEncodingAttr::DimLevelType::SingletonNu) && + static_cast(DimLevelType::SingletonNu) && 
static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO) == - static_cast( - SparseTensorEncodingAttr::DimLevelType::SingletonNo) && + static_cast(DimLevelType::SingletonNo) && static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO) == - static_cast( - SparseTensorEncodingAttr::DimLevelType::SingletonNuNo), + static_cast(DimLevelType::SingletonNuNo), "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch"); bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { @@ -58,11 +50,10 @@ MlirAttribute mlirSparseTensorEncodingAttrGet( MlirSparseTensorDimLevelType const *dimLevelTypes, MlirAffineMap dimOrdering, MlirAffineMap higherOrdering, int pointerBitWidth, int indexBitWidth) { - SmallVector cppDimLevelTypes; + SmallVector cppDimLevelTypes; cppDimLevelTypes.resize(numDimLevelTypes); for (intptr_t i = 0; i < numDimLevelTypes; ++i) - cppDimLevelTypes[i] = - static_cast(dimLevelTypes[i]); + cppDimLevelTypes[i] = static_cast(dimLevelTypes[i]); return wrap(SparseTensorEncodingAttr::get( unwrap(ctx), cppDimLevelTypes, unwrap(dimOrdering), unwrap(higherOrdering), pointerBitWidth, indexBitWidth)); From ea15232d64ec5f9aeaabadf7a0f81f81849e802f Mon Sep 17 00:00:00 2001 From: Denys Shabalin Date: Thu, 20 Oct 2022 12:58:49 +0200 Subject: [PATCH 369/915] [mlir] Fix and test python bindings for dump_to_object_file Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D136334 --- mlir/include/mlir-c/ExecutionEngine.h | 6 +++--- mlir/lib/Bindings/Python/ExecutionEngineModule.cpp | 9 ++++++--- mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp | 4 +++- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h index adb4e823e..99cddc5c2 100644 --- a/mlir/include/mlir-c/ExecutionEngine.h +++ b/mlir/include/mlir-c/ExecutionEngine.h @@ -42,9 +42,9 @@ DEFINE_C_API_STRUCT(MlirExecutionEngine, void); /// that will be loaded are specified via `numPaths` and 
`sharedLibPaths` /// respectively. /// TODO: figure out other options. -MLIR_CAPI_EXPORTED MlirExecutionEngine -mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, - const MlirStringRef *sharedLibPaths); +MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreate( + MlirModule op, int optLevel, int numPaths, + const MlirStringRef *sharedLibPaths, bool enableObjectDump); /// Destroy an ExecutionEngine instance. MLIR_CAPI_EXPORTED void mlirExecutionEngineDestroy(MlirExecutionEngine jit); diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp index f5179bd7c..3f8342596 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -72,12 +72,14 @@ PYBIND11_MODULE(_mlirExecutionEngine, m) { //---------------------------------------------------------------------------- py::class_(m, "ExecutionEngine", py::module_local()) .def(py::init<>([](MlirModule module, int optLevel, - const std::vector &sharedLibPaths) { + const std::vector &sharedLibPaths, + bool enableObjectDump) { llvm::SmallVector libPaths; for (const std::string &path : sharedLibPaths) libPaths.push_back({path.c_str(), path.length()}); - MlirExecutionEngine executionEngine = mlirExecutionEngineCreate( - module, optLevel, libPaths.size(), libPaths.data()); + MlirExecutionEngine executionEngine = + mlirExecutionEngineCreate(module, optLevel, libPaths.size(), + libPaths.data(), enableObjectDump); if (mlirExecutionEngineIsNull(executionEngine)) throw std::runtime_error( "Failure while creating the ExecutionEngine."); @@ -85,6 +87,7 @@ PYBIND11_MODULE(_mlirExecutionEngine, m) { }), py::arg("module"), py::arg("opt_level") = 2, py::arg("shared_libs") = py::list(), + py::arg("enable_object_dump") = true, "Create a new ExecutionEngine instance for the given Module. The " "module must contain only dialects that can be translated to LLVM. 
" "Perform transformations and code generation at the optimization " diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index 2190566b2..a832119ce 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -19,7 +19,8 @@ using namespace mlir; extern "C" MlirExecutionEngine mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, - const MlirStringRef *sharedLibPaths) { + const MlirStringRef *sharedLibPaths, + bool enableObjectDump) { static bool initOnce = [] { llvm::InitializeNativeTarget(); llvm::InitializeNativeTargetAsmParser(); // needed for inline_asm @@ -54,6 +55,7 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, jitOptions.transformer = transformer; jitOptions.jitCodeGenOptLevel = llvmOptLevel; jitOptions.sharedLibPaths = libPaths; + jitOptions.enableObjectDump = enableObjectDump; auto jitOrError = ExecutionEngine::create(unwrap(op), jitOptions); if (!jitOrError) { consumeError(jitOrError.takeError()); From 3436d5e8a1d42e83aae7cfb5f00e4418fdbb9ddb Mon Sep 17 00:00:00 2001 From: Aliia Khasanova Date: Thu, 20 Oct 2022 12:39:03 +0000 Subject: [PATCH 370/915] [mlir][nfc] Clean-up usage of kDynamicSize. This patch prepares MLIR code base to change the value of kDynamicSize. 
https://discourse.llvm.org/t/rfc-unify-kdynamicsize-and-kdynamicstrideoroffset/64534/4 Differential Revision: https://reviews.llvm.org/D136327 --- mlir/python/mlir/dialects/_tensor_ops_ext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/_tensor_ops_ext.py b/mlir/python/mlir/dialects/_tensor_ops_ext.py index 0f1b26603..51d998b6e 100644 --- a/mlir/python/mlir/dialects/_tensor_ops_ext.py +++ b/mlir/python/mlir/dialects/_tensor_ops_ext.py @@ -30,7 +30,7 @@ def __init__(self, if isinstance(s, int): static_sizes.append(s) else: - static_sizes.append(-1) + static_sizes.append(ShapedType.get_dynamic_size()) dynamic_sizes.append(s) result_type = RankedTensorType.get(static_sizes, element_type) op = self.build_generic( From db3f0f6e71e0a8ef699b3ef3399700d71ac6f1d5 Mon Sep 17 00:00:00 2001 From: rkayaith Date: Thu, 20 Oct 2022 16:40:32 -0400 Subject: [PATCH 371/915] [mlir][python] Include pipeline parse errors in exception message Currently any errors during pipeline parsing are reported to stderr. This adds a new pipeline parsing function to the C api that reports errors through a callback, and updates the python bindings to use it. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D136402 --- mlir/include/mlir-c/Pass.h | 7 +++++++ mlir/lib/Bindings/Python/Pass.cpp | 12 ++++++------ mlir/lib/CAPI/IR/Pass.cpp | 9 +++++++++ 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index b66bdfe02..6f281b6dc 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -105,6 +105,13 @@ MLIR_CAPI_EXPORTED void mlirPassManagerAddOwnedPass(MlirPassManager passManager, MLIR_CAPI_EXPORTED void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager, MlirPass pass); +/// Parse a sequence of textual MLIR pass pipeline elements and add them to the +/// provided OpPassManager. 
If parsing fails an error message is reported using +/// the provided callback. +MLIR_CAPI_EXPORTED MlirLogicalResult mlirOpPassManagerAddPipeline( + MlirOpPassManager passManager, MlirStringRef pipelineElements, + MlirStringCallback callback, void *userData); + /// Print a textual MLIR pass pipeline by sending chunks of the string /// representation and forwarding `userData to `callback`. Note that the /// callback may be called several times with consecutive chunks of the string. diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 3278d3a91..99d67582d 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -82,15 +82,15 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { py::arg("enable"), "Enable / disable verify-each.") .def_static( "parse", - [](const std::string pipeline, DefaultingPyMlirContext context) { + [](const std::string &pipeline, DefaultingPyMlirContext context) { MlirPassManager passManager = mlirPassManagerCreate(context->get()); - MlirLogicalResult status = mlirParsePassPipeline( + PyPrintAccumulator errorMsg; + MlirLogicalResult status = mlirOpPassManagerAddPipeline( mlirPassManagerGetAsOpPassManager(passManager), - mlirStringRefCreate(pipeline.data(), pipeline.size())); + mlirStringRefCreate(pipeline.data(), pipeline.size()), + errorMsg.getCallback(), errorMsg.getUserData()); if (mlirLogicalResultIsFailure(status)) - throw SetPyError(PyExc_ValueError, - llvm::Twine("invalid pass pipeline '") + - pipeline + "'."); + throw SetPyError(PyExc_ValueError, std::string(errorMsg.join())); return new PyPassManager(passManager); }, py::arg("pipeline"), py::arg("context") = py::none(), diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index a2998939a..398abfee2 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -65,6 +65,15 @@ void mlirOpPassManagerAddOwnedPass(MlirOpPassManager passManager, 
unwrap(passManager)->addPass(std::unique_ptr(unwrap(pass))); } +MlirLogicalResult mlirOpPassManagerAddPipeline(MlirOpPassManager passManager, + MlirStringRef pipelineElements, + MlirStringCallback callback, + void *userData) { + detail::CallbackOstream stream(callback, userData); + return wrap(parsePassPipeline(unwrap(pipelineElements), *unwrap(passManager), + stream)); +} + void mlirPrintPassPipeline(MlirOpPassManager passManager, MlirStringCallback callback, void *userData) { detail::CallbackOstream stream(callback, userData); From bf81fb9a7a7e0657520e02ac8cea0b15c872c7cc Mon Sep 17 00:00:00 2001 From: rkayaith Date: Thu, 20 Oct 2022 00:51:06 -0400 Subject: [PATCH 372/915] [mlir][CAPI] Allow specifying pass manager anchor This adds a new function for creating pass managers that takes an argument for the anchor string. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D136404 --- mlir/include/mlir-c/Pass.h | 6 +++++- mlir/lib/CAPI/IR/Pass.cpp | 5 +++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index 6f281b6dc..704121a0c 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -51,9 +51,13 @@ DEFINE_C_API_STRUCT(MlirOpPassManager, void); // PassManager/OpPassManager APIs. //===----------------------------------------------------------------------===// -/// Create a new top-level PassManager. +/// Create a new top-level PassManager with the default anchor. MLIR_CAPI_EXPORTED MlirPassManager mlirPassManagerCreate(MlirContext ctx); +/// Create a new top-level PassManager anchored on `anchorOp`. +MLIR_CAPI_EXPORTED MlirPassManager +mlirPassManagerCreateOnOperation(MlirContext ctx, MlirStringRef anchorOp); + /// Destroy the provided PassManager. 
MLIR_CAPI_EXPORTED void mlirPassManagerDestroy(MlirPassManager passManager); diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index 398abfee2..30f580487 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -24,6 +24,11 @@ MlirPassManager mlirPassManagerCreate(MlirContext ctx) { return wrap(new PassManager(unwrap(ctx))); } +MlirPassManager mlirPassManagerCreateOnOperation(MlirContext ctx, + MlirStringRef anchorOp) { + return wrap(new PassManager(unwrap(ctx), unwrap(anchorOp))); +} + void mlirPassManagerDestroy(MlirPassManager passManager) { delete unwrap(passManager); } From 88382181f5563070b64c9214b2181effab7908a7 Mon Sep 17 00:00:00 2001 From: rkayaith Date: Wed, 19 Oct 2022 22:37:12 -0400 Subject: [PATCH 373/915] [mlir][CAPI] Include anchor op in mlirParsePassPipeline The pipeline string must now include the pass manager's anchor op. This makes the parse API properly roundtrip the printed form of a pass manager. Since this is already an API break, I also added an extra callback argument which is used for reporting errors. The old functionality of appending to an existing pass manager is available through `mlirOpPassManagerAddPipeline`. Reviewed By: mehdi_amini, ftynse Differential Revision: https://reviews.llvm.org/D136403 --- mlir/include/mlir-c/Pass.h | 8 +++++--- mlir/lib/CAPI/IR/Pass.cpp | 12 ++++++++---- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index 704121a0c..721f1f28f 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -123,10 +123,12 @@ MLIR_CAPI_EXPORTED void mlirPrintPassPipeline(MlirOpPassManager passManager, MlirStringCallback callback, void *userData); -/// Parse a textual MLIR pass pipeline and add it to the provided OpPassManager. - +/// Parse a textual MLIR pass pipeline and assign it to the provided +/// OpPassManager. If parsing fails an error message is reported using the +/// provided callback. 
MLIR_CAPI_EXPORTED MlirLogicalResult -mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline); +mlirParsePassPipeline(MlirOpPassManager passManager, MlirStringRef pipeline, + MlirStringCallback callback, void *userData); //===----------------------------------------------------------------------===// // External Pass API. diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index 30f580487..4afc66859 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -86,10 +86,14 @@ void mlirPrintPassPipeline(MlirOpPassManager passManager, } MlirLogicalResult mlirParsePassPipeline(MlirOpPassManager passManager, - MlirStringRef pipeline) { - // TODO: errors are sent to std::errs() at the moment, we should pass in a - // stream and redirect to a diagnostic. - return wrap(mlir::parsePassPipeline(unwrap(pipeline), *unwrap(passManager))); + MlirStringRef pipeline, + MlirStringCallback callback, + void *userData) { + detail::CallbackOstream stream(callback, userData); + FailureOr pm = parsePassPipeline(unwrap(pipeline), stream); + if (succeeded(pm)) + *unwrap(passManager) = std::move(*pm); + return wrap(pm); } //===----------------------------------------------------------------------===// From 075927e5fb726dcfd0149e032e48931d82212aa9 Mon Sep 17 00:00:00 2001 From: rkayaith Date: Wed, 19 Oct 2022 23:36:15 -0400 Subject: [PATCH 374/915] [mlir][python] Include anchor op in PassManager.parse The pipeline string must now include the pass manager's anchor op. This makes the parse API properly roundtrip the printed form of a pass manager. 
Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D136405 --- mlir/lib/Bindings/Python/Pass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 99d67582d..f08a4bd2d 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -85,7 +85,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { [](const std::string &pipeline, DefaultingPyMlirContext context) { MlirPassManager passManager = mlirPassManagerCreate(context->get()); PyPrintAccumulator errorMsg; - MlirLogicalResult status = mlirOpPassManagerAddPipeline( + MlirLogicalResult status = mlirParsePassPipeline( mlirPassManagerGetAsOpPassManager(passManager), mlirStringRefCreate(pipeline.data(), pipeline.size()), errorMsg.getCallback(), errorMsg.getUserData()); From 65ebd33ef2f9d3b340262c2a310e5264884ab958 Mon Sep 17 00:00:00 2001 From: rkayaith Date: Thu, 20 Oct 2022 01:04:34 -0400 Subject: [PATCH 375/915] [mlir][python] Include anchor op in PassManager constructor This adds an extra argument for specifying the pass manager's anchor op, with a default of `any`. Previously the anchor was always defaulted to `builtin.module`. 
Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D136406 --- mlir/lib/Bindings/Python/Pass.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index f08a4bd2d..13f1cfa35 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -56,11 +56,14 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { // Mapping of the top-level PassManager //---------------------------------------------------------------------------- py::class_(m, "PassManager", py::module_local()) - .def(py::init<>([](DefaultingPyMlirContext context) { - MlirPassManager passManager = - mlirPassManagerCreate(context->get()); + .def(py::init<>([](const std::string &anchorOp, + DefaultingPyMlirContext context) { + MlirPassManager passManager = mlirPassManagerCreateOnOperation( + context->get(), + mlirStringRefCreate(anchorOp.data(), anchorOp.size())); return new PyPassManager(passManager); }), + py::arg("anchor_op") = py::str("any"), py::arg("context") = py::none(), "Create a new PassManager for the current (or provided) Context.") .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, From a8168bd610e70d348cf1a1a0103e55e6ba45c387 Mon Sep 17 00:00:00 2001 From: rkayaith Date: Thu, 20 Oct 2022 00:27:09 -0400 Subject: [PATCH 376/915] [mlir][python] Allow adding to existing pass manager This adds a `PassManager.add` method which adds pipeline elements to the pass manager. This allows for progressively building up a pipeline from python without string manipulation. 
Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D137344 --- mlir/lib/Bindings/Python/Pass.cpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 13f1cfa35..cb3c1586e 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -100,6 +100,20 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { "Parse a textual pass-pipeline and return a top-level PassManager " "that can be applied on a Module. Throw a ValueError if the pipeline " "can't be parsed") + .def( + "add", + [](PyPassManager &passManager, const std::string &pipeline) { + PyPrintAccumulator errorMsg; + MlirLogicalResult status = mlirOpPassManagerAddPipeline( + mlirPassManagerGetAsOpPassManager(passManager.get()), + mlirStringRefCreate(pipeline.data(), pipeline.size()), + errorMsg.getCallback(), errorMsg.getUserData()); + if (mlirLogicalResultIsFailure(status)) + throw SetPyError(PyExc_ValueError, std::string(errorMsg.join())); + }, + py::arg("pipeline"), + "Add textual pipeline elements to the pass manager. Throws a " + "ValueError if the pipeline can't be parsed.") .def( "run", [](PyPassManager &passManager, PyModule &module) { From 477de017001c9d8e6f6859fcb645fd42fd49c415 Mon Sep 17 00:00:00 2001 From: bixia1 Date: Wed, 9 Nov 2022 17:33:25 -0800 Subject: [PATCH 377/915] [mlir][sparse] Fix Python interface for bufferization.alloc_tensor. Add size_hint operand to the Python interface. Fix pytaco. 
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D137754 --- mlir/python/mlir/dialects/_bufferization_ops_ext.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/_bufferization_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_ops_ext.py index 23f78fc80..6ed35f444 100644 --- a/mlir/python/mlir/dialects/_bufferization_ops_ext.py +++ b/mlir/python/mlir/dialects/_bufferization_ops_ext.py @@ -19,6 +19,7 @@ def __init__(self, tensor_type: Type, dynamic_sizes: Sequence[Value], copy: Value, + size_hint: Value, escape: BoolAttr, *, loc=None, @@ -30,7 +31,7 @@ def __init__(self, attributes["escape"] = escape op = self.build_generic( results=[tensor_type], - operands=[dynamic_sizes, copy], + operands=[dynamic_sizes, copy, size_hint], attributes=attributes, loc=loc, ip=ip) From 20fe0fcb42daa0ced9d2cacaa07b356bb5d28c74 Mon Sep 17 00:00:00 2001 From: Oleg Shyshkov Date: Thu, 10 Nov 2022 11:42:48 +0100 Subject: [PATCH 378/915] Revert "Revert "[mlir][linalg] Replace "string" iterator_types attr with enums in LinalgInterface."" With python code fixed. This reverts commit 41280908e43d47903960c66237ab49caa5641b4d. --- mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index e7493617a..b63cb4071 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -127,8 +127,10 @@ def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, # TODO: Support emission of pure memref form. 
indexing_maps_attr = ArrayAttr.get( [AffineMapAttr.get(am) for am in indexing_maps]) - iterator_types_attr = ArrayAttr.get( - [StringAttr.get(s) for s in op_config.iterator_types]) + iterator_types_attr = ArrayAttr.get([ + Attribute.parse(f"#linalg.iterator_type<{s}>") + for s in op_config.iterator_types + ]) # Compute the index attributes used when emitting a named structured op. index_attrs = {} # type: Dict[str, DenseElementAttr] @@ -180,7 +182,8 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, # An operation is rank polymorphic if the iteration domain has rank zero. if not iterator_types_attr: rank = ShapedType(outs[0].type).rank - iterator_types_attr = ArrayAttr.get([StringAttr.get("parallel")] * rank) + iterator_types_attr = ArrayAttr.get( + [Attribute.parse("#linalg.iterator_type")] * rank) scalar_map = AffineMap.get(rank, 0, []) tensor_map = AffineMap.get_identity(rank) indexing_maps = [] From bc02a755b6f2384260d28431303000027a77b04d Mon Sep 17 00:00:00 2001 From: Reed Date: Wed, 16 Nov 2022 10:24:24 +0100 Subject: [PATCH 379/915] Add Float8E4M3FN type to MLIR. The paper https://arxiv.org/abs/2209.05433 introduces two new FP8 dtypes: E5M2 (called Float8E5M2 in LLVM) and E4M3 (called Float8E4M3FN in LLVM). Support for Float8E5M2 in APFloat and MLIR was added in https://reviews.llvm.org/D133823. Support for Float8E4M3FN in APFloat was added in https://reviews.llvm.org/D137760. This change adds Float8E4M3FN to MLIR as well. There is an RFC for adding the FP8 dtypes here: https://discourse.llvm.org/t/rfc-add-apfloat-and-mlir-type-support-for-fp8-e5m2/65279. This change is identical to the MLIR changes in the patch that added Float8E5M2, except that Float8E4M3FN is added instead. 
Reviewed By: stellaraccident, bkramer, rriddle Differential Revision: https://reviews.llvm.org/D138075 --- mlir/include/mlir-c/BuiltinTypes.h | 7 +++++++ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index 9bd3d510b..1c4a16382 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -74,6 +74,13 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type); /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx); +/// Checks whether the given type is an f8E4M3FN type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type); + +/// Creates an f8E4M3FN type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx); + /// Checks whether the given type is a bf16 type. MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index ad9a5bc66..596a760b9 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -76,6 +76,14 @@ MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) { return wrap(FloatType::getFloat8E5M2(unwrap(ctx))); } +bool mlirTypeIsAFloat8E4M3FN(MlirType type) { + return unwrap(type).isFloat8E4M3FN(); +} + +MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) { + return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx))); +} + bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); } MlirType mlirBF16TypeGet(MlirContext ctx) { From ae77ff18a2304bb76bcd07e6618849e4cda29c72 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 17 Nov 2022 20:44:27 -0800 Subject: [PATCH 380/915] [mlir][AsmPrinter] Allow explicitly disabling debug info This adds an `enable` flag to OpPrintingFlags::enableDebugInfo that allows for overriding any command line flags for debug printing, and matches the format 
that we use for other `enableBlah` API. --- mlir/include/mlir-c/IR.h | 9 +++++---- mlir/lib/Bindings/Python/IRCore.cpp | 3 ++- mlir/lib/CAPI/IR/IR.cpp | 4 ++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index daf097da2..b4266bd0a 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -376,11 +376,12 @@ MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit); -/// Enable printing of debug information. If 'prettyForm' is set to true, -/// debug information is printed in a more readable 'pretty' form. Note: The -/// IR generated with 'prettyForm' is not parsable. +/// Enable or disable printing of debug information (based on `enable`). If +/// 'prettyForm' is set to true, debug information is printed in a more readable +/// 'pretty' form. Note: The IR generated with 'prettyForm' is not parsable. MLIR_CAPI_EXPORTED void -mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool prettyForm); +mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, + bool prettyForm); /// Always print operations in the generic form. 
MLIR_CAPI_EXPORTED void diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index f706951e6..a183809a3 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1017,7 +1017,8 @@ void PyOperationBase::print(py::object fileObject, bool binary, if (largeElementsLimit) mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); if (enableDebugInfo) - mlirOpPrintingFlagsEnableDebugInfo(flags, /*prettyForm=*/prettyDebugInfo); + mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, + /*prettyForm=*/prettyDebugInfo); if (printGenericOpForm) mlirOpPrintingFlagsPrintGenericOpForm(flags); if (useLocalScope) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 98a3ff348..53b530d3f 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -127,9 +127,9 @@ void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, unwrap(flags)->elideLargeElementsAttrs(largeElementLimit); } -void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, +void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm) { - unwrap(flags)->enableDebugInfo(/*prettyForm=*/prettyForm); + unwrap(flags)->enableDebugInfo(enable, /*prettyForm=*/prettyForm); } void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) { From c2613d00599383a18a8126cf62e1809f4fbeafa1 Mon Sep 17 00:00:00 2001 From: Aliia Khasanova Date: Fri, 18 Nov 2022 18:00:10 +0000 Subject: [PATCH 381/915] Merge kDynamicSize and kDynamicSentinel into one constant. 
resolve conflicts Differential Revision: https://reviews.llvm.org/D138282 --- mlir/lib/CAPI/IR/BuiltinTypes.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 596a760b9..6b6ba6e95 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -165,18 +165,18 @@ int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) { return unwrap(type).cast().getDimSize(static_cast(dim)); } -int64_t mlirShapedTypeGetDynamicSize() { return ShapedType::kDynamicSize; } +int64_t mlirShapedTypeGetDynamicSize() { return ShapedType::kDynamic; } bool mlirShapedTypeIsDynamicSize(int64_t size) { return ShapedType::isDynamic(size); } bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) { - return ShapedType::isDynamicStrideOrOffset(val); + return ShapedType::isDynamic(val); } int64_t mlirShapedTypeGetDynamicStrideOrOffset() { - return ShapedType::kDynamicStrideOrOffset; + return ShapedType::kDynamic; } //===----------------------------------------------------------------------===// From db7d1e4a0b1fb51617bf441dd085700a64eccbd1 Mon Sep 17 00:00:00 2001 From: Lorenzo Chelini Date: Tue, 22 Nov 2022 12:41:44 +0100 Subject: [PATCH 382/915] [MLIR] Adopt `DenseI64ArrayAttr` in tensor, memref and linalg transform This commit is a first step toward removing inconsistencies between dynamic and static attributes (i64 v. index) by dropping `I64ArrayAttr` and using `DenseI64ArrayAttr` in Tensor, Memref and Linalg Transform ops. In Linalg Transform ops only `TileToScfForOp` and `TileOp` have been updated. 
See related discussion: https://discourse.llvm.org/t/rfc-inconsistency-between-dynamic-and-static-attributes-i64-v-index/66612/1 Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D138567 --- .../dialects/_structured_transform_ops_ext.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 527a8656f..5fd5cfe10 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -49,6 +49,15 @@ def _get_int_array_attr( return ArrayAttr.get([_get_int64_attr(v) for v in values]) +def _get_dense_int64_array_attr( + values: Sequence[int]) -> DenseI64ArrayAttr: + """Creates a dense integer array from a sequence of integers. + Expects the thread-local MLIR context to have been set by the context + manager. + """ + if values is None: + return DenseI64ArrayAttr.get([]) + return DenseI64ArrayAttr.get(values) def _get_int_int_array_attr( values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, @@ -250,14 +259,11 @@ def __init__(self, else: for size in sizes: if isinstance(size, int): - static_sizes.append(IntegerAttr.get(i64_type, size)) - elif isinstance(size, IntegerAttr): static_sizes.append(size) else: - static_sizes.append( - IntegerAttr.get(i64_type, ShapedType.get_dynamic_size())) + static_sizes.append(ShapedType.get_dynamic_size()) dynamic_sizes.append(_get_op_result_or_value(size)) - sizes_attr = ArrayAttr.get(static_sizes) + sizes_attr = DenseI64ArrayAttr.get(static_sizes) num_loops = sum( v if v == 0 else 1 for v in self.__extract_values(sizes_attr)) @@ -266,14 +272,14 @@ def __init__(self, _get_op_result_or_value(target), dynamic_sizes=dynamic_sizes, static_sizes=sizes_attr, - interchange=_get_int_array_attr(interchange) if interchange else None, + interchange=_get_dense_int64_array_attr(interchange) if 
interchange else None, loc=loc, ip=ip) - def __extract_values(self, attr: Optional[ArrayAttr]) -> List[int]: + def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]: if not attr: return [] - return [IntegerAttr(element).value for element in attr] + return [element for element in attr] class VectorizeOp: From cd309e1942d82ef515c0601373fa7b0420cd392d Mon Sep 17 00:00:00 2001 From: Prashant Kumar Date: Thu, 1 Dec 2022 16:27:33 +0000 Subject: [PATCH 383/915] Replacing `is` with `==` for the dtype check. >>> a = np.ndarray([1,1]).astype(np.half) >>> a array([[0.007812]], dtype=float16) >>> a.dtype dtype('float16') >>> a.dtype == np.half True >>> a.dtype == np.float16 True >>> a.dtype is np.float16 False Checking with `is` leads to inconsistency in checking. Reviewed By: silvas Differential Revision: https://reviews.llvm.org/D139121 --- mlir/python/mlir/runtime/np_to_memref.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py index 5b3c3c4ae..d70967983 100644 --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -23,13 +23,14 @@ class F16(ctypes.Structure): _fields_ = [("f16", ctypes.c_int16)] +# https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype def as_ctype(dtp): """Converts dtype to ctype.""" - if dtp is np.dtype(np.complex128): + if dtp == np.dtype(np.complex128): return C128 - if dtp is np.dtype(np.complex64): + if dtp == np.dtype(np.complex64): return C64 - if dtp is np.dtype(np.float16): + if dtp == np.dtype(np.float16): return F16 return np.ctypeslib.as_ctypes_type(dtp) From 879edc831fa3e10b4fd171e268e134932fd8dcb8 Mon Sep 17 00:00:00 2001 From: Kazu Hirata Date: Sat, 3 Dec 2022 18:50:27 -0800 Subject: [PATCH 384/915] [mlir] Use std::nullopt instead of None (NFC) This patch mechanically replaces None with std::nullopt where the compiler would warn if None were 
deprecated. The intent is to reduce the amount of manual work required in migrating from Optional to std::optional. This is part of an effort to migrate from llvm::Optional to std::optional: https://discourse.llvm.org/t/deprecating-llvm-optional-x-hasvalue-getvalue-getvalueor/63716 --- mlir/include/mlir/CAPI/Wrap.h | 2 +- mlir/lib/CAPI/IR/Pass.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/CAPI/Wrap.h b/mlir/include/mlir/CAPI/Wrap.h index b8cc745d7..5b68f417a 100644 --- a/mlir/include/mlir/CAPI/Wrap.h +++ b/mlir/include/mlir/CAPI/Wrap.h @@ -44,7 +44,7 @@ static llvm::ArrayRef unwrapList(size_t size, CTy *first, "incompatible C and C++ types"); if (size == 0) - return llvm::None; + return std::nullopt; assert(storage.empty() && "expected to populate storage"); storage.reserve(size); diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index 4afc66859..6f81cd808 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -178,7 +178,7 @@ MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name, void *userData) { return wrap(static_cast(new mlir::ExternalPass( unwrap(passID), unwrap(name), unwrap(argument), unwrap(description), - opName.length > 0 ? Optional(unwrap(opName)) : None, + opName.length > 0 ? Optional(unwrap(opName)) : std::nullopt, {dependentDialects, static_cast(nDependentDialects)}, callbacks, userData))); } From 86a3e8989f9ada7080a40b82e1f08ce1225d5348 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 5 Dec 2022 17:42:38 -0800 Subject: [PATCH 385/915] Don't use root logger at import time At import time, these calls to `logging.debug()` implicitly call `logging.basicConfig` (https://docs.python.org/3/library/logging.html#logging.basicConfig), setting logging config for the whole project which cannot then be overwritten later. 
For instance, consider the following test script: ``` import logging import jax logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) logger.info('info') ``` This should log out `'info'`, but because when `import jax` is called, this `_mlir_lib/__init__.py` file is run and a `logging.debug` is called, calling `logging.basicConfig`, my `logging.basicConfig(level=logging.INFO)` does nothing. Fix: instead of using root logger, use a module level logger. Found in this issue: https://github.com/google/jax/issues/12526 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D134812 --- mlir/python/mlir/_mlir_libs/__init__.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index b140ad64e..9ceeef818 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -54,6 +54,7 @@ def _site_initialize(): import itertools import logging from ._mlir import ir + logger = logging.getLogger(__name__) registry = ir.DialectRegistry() post_init_hooks = [] @@ -66,14 +67,14 @@ def process_initializer_module(module_name): message = (f"Error importing mlir initializer {module_name}. 
This may " "happen in unclean incremental builds but is likely a real bug if " "encountered otherwise and the MLIR Python API may not function.") - logging.warning(message, exc_info=True) + logger.warning(message, exc_info=True) - logging.debug("Initializing MLIR with module: %s", module_name) + logger.debug("Initializing MLIR with module: %s", module_name) if hasattr(m, "register_dialects"): - logging.debug("Registering dialects from initializer %r", m) + logger.debug("Registering dialects from initializer %r", m) m.register_dialects(registry) if hasattr(m, "context_init_hook"): - logging.debug("Adding context init hook from %r", m) + logger.debug("Adding context init hook from %r", m) post_init_hooks.append(m.context_init_hook) return True From 289b4fcaa22138ed4035f4468c4cfc9268dadd29 Mon Sep 17 00:00:00 2001 From: Kazu Hirata Date: Tue, 6 Dec 2022 00:03:44 -0800 Subject: [PATCH 386/915] [mlir] Use std::nullopt instead of None in comments (NFC) This is part of an effort to migrate from llvm::Optional to std::optional: https://discourse.llvm.org/t/deprecating-llvm-optional-x-hasvalue-getvalue-getvalueor/63716 --- mlir/lib/Bindings/Python/IRCore.cpp | 2 +- mlir/lib/Bindings/Python/IRModule.cpp | 10 +++++----- mlir/lib/Bindings/Python/IRTypes.cpp | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index a183809a3..8c25f6e81 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2626,7 +2626,7 @@ void mlir::python::populateIRCore(py::module &m) { "__str__", [](PyOperationBase &self) { return self.getAsm(/*binary=*/false, - /*largeElementsLimit=*/llvm::None, + /*largeElementsLimit=*/std::nullopt, /*enableDebugInfo=*/false, /*prettyDebugInfo=*/false, /*printGenericOpForm=*/false, diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index ba6b2d29f..b6d1df51f 100644 --- 
a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -91,14 +91,14 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { const auto foundIt = dialectClassMap.find(dialectNamespace); if (foundIt != dialectClassMap.end()) { if (foundIt->second.is_none()) - return llvm::None; + return std::nullopt; assert(foundIt->second && "py::object is defined"); return foundIt->second; } // Not found and loading did not yield a registration. Negative cache. dialectClassMap[dialectNamespace] = py::none(); - return llvm::None; + return std::nullopt; } llvm::Optional @@ -107,7 +107,7 @@ PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { auto foundIt = rawOpViewClassMapCache.find(operationName); if (foundIt != rawOpViewClassMapCache.end()) { if (foundIt->second.is_none()) - return llvm::None; + return std::nullopt; assert(foundIt->second && "py::object is defined"); return foundIt->second; } @@ -123,7 +123,7 @@ PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { auto foundIt = rawOpViewClassMap.find(operationName); if (foundIt != rawOpViewClassMap.end()) { if (foundIt->second.is_none()) - return llvm::None; + return std::nullopt; assert(foundIt->second && "py::object is defined"); // Positive cache. rawOpViewClassMapCache[operationName] = foundIt->second; @@ -131,7 +131,7 @@ PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { } // Negative cache. 
rawOpViewClassMap[operationName] = py::none(); - return llvm::None; + return std::nullopt; } } diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 379510ce9..7a41cb1e8 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -390,7 +390,7 @@ class PyRankedTensorType [](PyRankedTensorType &self) -> llvm::Optional { MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get()); if (mlirAttributeIsNull(encoding)) - return llvm::None; + return std::nullopt; return PyAttribute(self.getContext(), encoding); }); } From ba5c9de5eca272baca15ddfb43b84c0895955329 Mon Sep 17 00:00:00 2001 From: Tom Eccles Date: Mon, 5 Dec 2022 12:07:51 +0000 Subject: [PATCH 387/915] [mlir] Fix -Wstrict-prototypes warning These warnings prevent compilation using clang and -DLLVM_ENABLE_WERROR=On. Differential revision: https://reviews.llvm.org/D139322 --- mlir/include/mlir-c/BuiltinAttributes.h | 2 +- mlir/include/mlir-c/BuiltinTypes.h | 4 ++-- mlir/include/mlir-c/Dialect/Quant.h | 2 +- mlir/include/mlir-c/IR.h | 16 +++++++++------- mlir/include/mlir-c/RegisterEverything.h | 2 +- mlir/include/mlir-c/Support.h | 6 +++--- 6 files changed, 17 insertions(+), 15 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 79f22376e..8887897bc 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -23,7 +23,7 @@ extern "C" { #endif /// Returns an empty attribute. -MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull(); +MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull(void); //===----------------------------------------------------------------------===// // Affine map attribute. 
diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index 1c4a16382..000397505 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -166,7 +166,7 @@ MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicSize(int64_t size); /// Returns the value indicating a dynamic size in a shaped type. Prefer /// mlirShapedTypeIsDynamicSize to direct comparisons with this value. -MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicSize(); +MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicSize(void); /// Checks whether the given value is used as a placeholder for dynamic strides /// and offsets in shaped types. @@ -175,7 +175,7 @@ MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val); /// Returns the value indicating a dynamic stride or offset in a shaped type. /// Prefer mlirShapedTypeGetDynamicStrideOrOffset to direct comparisons with /// this value. -MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(); +MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void); //===----------------------------------------------------------------------===// // Vector type. diff --git a/mlir/include/mlir-c/Dialect/Quant.h b/mlir/include/mlir-c/Dialect/Quant.h index 39a17318c..a7d98dc3c 100644 --- a/mlir/include/mlir-c/Dialect/Quant.h +++ b/mlir/include/mlir-c/Dialect/Quant.h @@ -26,7 +26,7 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(quant, quant); MLIR_CAPI_EXPORTED bool mlirTypeIsAQuantizedType(MlirType type); /// Returns the bit flag used to indicate signedness of a quantized type. -MLIR_CAPI_EXPORTED unsigned mlirQuantizedTypeGetSignedFlag(); +MLIR_CAPI_EXPORTED unsigned mlirQuantizedTypeGetSignedFlag(void); /// Returns the minimum possible value stored by a quantized type. 
MLIR_CAPI_EXPORTED int64_t mlirQuantizedTypeGetDefaultMinimumForInteger( diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index b4266bd0a..cd4f4d394 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -82,7 +82,7 @@ typedef struct MlirNamedAttribute MlirNamedAttribute; //===----------------------------------------------------------------------===// /// Creates an MLIR context and transfers its ownership to the caller. -MLIR_CAPI_EXPORTED MlirContext mlirContextCreate(); +MLIR_CAPI_EXPORTED MlirContext mlirContextCreate(void); /// Checks if two contexts are equal. MLIR_CAPI_EXPORTED bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2); @@ -184,7 +184,8 @@ struct MlirDialectHandle { typedef struct MlirDialectHandle MlirDialectHandle; #define MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Name, Namespace) \ - MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##Namespace##__() + MLIR_CAPI_EXPORTED MlirDialectHandle mlirGetDialectHandle__##Namespace##__( \ + void) /// Returns the namespace associated with the provided dialect handle. MLIR_CAPI_EXPORTED @@ -208,7 +209,7 @@ MLIR_CAPI_EXPORTED MlirDialect mlirDialectHandleLoadDialect(MlirDialectHandle, //===----------------------------------------------------------------------===// /// Creates a dialect registry and transfers its ownership to the caller. -MLIR_CAPI_EXPORTED MlirDialectRegistry mlirDialectRegistryCreate(); +MLIR_CAPI_EXPORTED MlirDialectRegistry mlirDialectRegistryCreate(void); /// Checks if the dialect registry is null. static inline bool mlirDialectRegistryIsNull(MlirDialectRegistry registry) { @@ -363,7 +364,7 @@ mlirOperationStateEnableResultTypeInference(MlirOperationState *state); /// Creates new printing flags with defaults, intended for customization. /// Must be freed with a call to mlirOpPrintingFlagsDestroy(). 
-MLIR_CAPI_EXPORTED MlirOpPrintingFlags mlirOpPrintingFlagsCreate(); +MLIR_CAPI_EXPORTED MlirOpPrintingFlags mlirOpPrintingFlagsCreate(void); /// Destroys printing flags created with mlirOpPrintingFlagsCreate. MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsDestroy(MlirOpPrintingFlags flags); @@ -551,7 +552,7 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, //===----------------------------------------------------------------------===// /// Creates a new empty region and transfers ownership to the caller. -MLIR_CAPI_EXPORTED MlirRegion mlirRegionCreate(); +MLIR_CAPI_EXPORTED MlirRegion mlirRegionCreate(void); /// Takes a region owned by the caller and destroys it. MLIR_CAPI_EXPORTED void mlirRegionDestroy(MlirRegion region); @@ -817,10 +818,11 @@ MLIR_CAPI_EXPORTED MlirStringRef mlirIdentifierStr(MlirIdentifier ident); /// Returns the name of the attribute used to store symbol names compatible with /// symbol tables. -MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(); +MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetSymbolAttributeName(void); /// Returns the name of the attribute used to store symbol visibility. -MLIR_CAPI_EXPORTED MlirStringRef mlirSymbolTableGetVisibilityAttributeName(); +MLIR_CAPI_EXPORTED MlirStringRef +mlirSymbolTableGetVisibilityAttributeName(void); /// Creates a symbol table for the given operation. If the operation does not /// have the SymbolTable trait, returns a null symbol table. diff --git a/mlir/include/mlir-c/RegisterEverything.h b/mlir/include/mlir-c/RegisterEverything.h index b98ce154d..ea2ea8644 100644 --- a/mlir/include/mlir-c/RegisterEverything.h +++ b/mlir/include/mlir-c/RegisterEverything.h @@ -29,7 +29,7 @@ MLIR_CAPI_EXPORTED void mlirRegisterAllDialects(MlirDialectRegistry registry); MLIR_CAPI_EXPORTED void mlirRegisterAllLLVMTranslations(MlirContext context); /// Register all compiler passes of MLIR. 
-MLIR_CAPI_EXPORTED void mlirRegisterAllPasses(); +MLIR_CAPI_EXPORTED void mlirRegisterAllPasses(void); #ifdef __cplusplus } diff --git a/mlir/include/mlir-c/Support.h b/mlir/include/mlir-c/Support.h index 5d20fb78d..8d0188e31 100644 --- a/mlir/include/mlir-c/Support.h +++ b/mlir/include/mlir-c/Support.h @@ -127,13 +127,13 @@ inline static bool mlirLogicalResultIsFailure(MlirLogicalResult res) { } /// Creates a logical result representing a success. -inline static MlirLogicalResult mlirLogicalResultSuccess() { +inline static MlirLogicalResult mlirLogicalResultSuccess(void) { MlirLogicalResult res = {1}; return res; } /// Creates a logical result representing a failure. -inline static MlirLogicalResult mlirLogicalResultFailure() { +inline static MlirLogicalResult mlirLogicalResultFailure(void) { MlirLogicalResult res = {0}; return res; } @@ -160,7 +160,7 @@ MLIR_CAPI_EXPORTED size_t mlirTypeIDHashValue(MlirTypeID typeID); //===----------------------------------------------------------------------===// /// Creates a type id allocator for dynamic type id creation -MLIR_CAPI_EXPORTED MlirTypeIDAllocator mlirTypeIDAllocatorCreate(); +MLIR_CAPI_EXPORTED MlirTypeIDAllocator mlirTypeIDAllocatorCreate(void); /// Deallocates the allocator and all allocated type ids MLIR_CAPI_EXPORTED void From a320cd73361b0c393ec4371dbe06e60c1ddf8ec4 Mon Sep 17 00:00:00 2001 From: Mike Urbach Date: Tue, 6 Dec 2022 15:56:08 -0700 Subject: [PATCH 388/915] [mlir][CAPI] Add a simple MlirOpOperand API for MlirValue uses. This allows basic IR traversal via the C API, which is useful for analyses in languages other than C++. This starts by defining an MlirOpOperand struct to encapsulate a pair of an owner operation and an operand number. A new API is added for MlirValue, to return the first use of the Value as an MlirOpOperand, or a "null" MlirOpOperand if there are no uses. A couple APIs are added for MlirOpOperand. 
The first checks if an MlirOpOperand is "null", by checking if its owner's pointer is null. The second supports iteration along the use-def lists by accepting an MlirOpOperand and returning the next use of the Value as another MlirOpOperand, or a "null" MlirOpOperand if there are no more uses. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D139596 --- mlir/include/mlir-c/IR.h | 24 ++++++++++++++++++++++++ mlir/include/mlir/CAPI/IR.h | 1 + mlir/lib/CAPI/IR/IR.cpp | 37 +++++++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index cd4f4d394..7817da2bd 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -52,6 +52,7 @@ DEFINE_C_API_STRUCT(MlirContext, void); DEFINE_C_API_STRUCT(MlirDialect, void); DEFINE_C_API_STRUCT(MlirDialectRegistry, void); DEFINE_C_API_STRUCT(MlirOperation, void); +DEFINE_C_API_STRUCT(MlirOpOperand, void); DEFINE_C_API_STRUCT(MlirOpPrintingFlags, void); DEFINE_C_API_STRUCT(MlirBlock, void); DEFINE_C_API_STRUCT(MlirRegion, void); @@ -728,6 +729,29 @@ MLIR_CAPI_EXPORTED void mlirValueDump(MlirValue value); MLIR_CAPI_EXPORTED void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData); +/// Returns an op operand representing the first use of the value, or a null op +/// operand if there are no uses. +MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value); + +//===----------------------------------------------------------------------===// +// OpOperand API. +//===----------------------------------------------------------------------===// + +/// Returns whether the op operand is null. +MLIR_CAPI_EXPORTED bool mlirOpOperandIsNull(MlirOpOperand opOperand); + +/// Returns the owner operation of an op operand. +MLIR_CAPI_EXPORTED MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand); + +/// Returns the operand number of an op operand. 
+MLIR_CAPI_EXPORTED unsigned +mlirOpOperandGetOperandNumber(MlirOpOperand opOperand); + +/// Returns an op operand representing the next use of the value, or a null op +/// operand if there is no next use. +MLIR_CAPI_EXPORTED MlirOpOperand +mlirOpOperandGetNextUse(MlirOpOperand opOperand); + //===----------------------------------------------------------------------===// // Type API. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h index 899b41167..2f32c76e1 100644 --- a/mlir/include/mlir/CAPI/IR.h +++ b/mlir/include/mlir/CAPI/IR.h @@ -25,6 +25,7 @@ DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect) DEFINE_C_API_PTR_METHODS(MlirDialectRegistry, mlir::DialectRegistry) DEFINE_C_API_PTR_METHODS(MlirOperation, mlir::Operation) DEFINE_C_API_PTR_METHODS(MlirBlock, mlir::Block) +DEFINE_C_API_PTR_METHODS(MlirOpOperand, mlir::OpOperand) DEFINE_C_API_PTR_METHODS(MlirOpPrintingFlags, mlir::OpPrintingFlags) DEFINE_C_API_PTR_METHODS(MlirRegion, mlir::Region) DEFINE_C_API_PTR_METHODS(MlirSymbolTable, mlir::SymbolTable) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 53b530d3f..f5364c864 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -719,6 +719,43 @@ void mlirValuePrint(MlirValue value, MlirStringCallback callback, unwrap(value).print(stream); } +MlirOpOperand mlirValueGetFirstUse(MlirValue value) { + Value cppValue = unwrap(value); + if (cppValue.use_empty()) + return {}; + + OpOperand *opOperand = cppValue.use_begin().getOperand(); + + return wrap(opOperand); +} + +//===----------------------------------------------------------------------===// +// OpOperand API. 
+//===----------------------------------------------------------------------===// + +bool mlirOpOperandIsNull(MlirOpOperand opOperand) { return !opOperand.ptr; } + +MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand) { + return wrap(unwrap(opOperand)->getOwner()); +} + +unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand) { + return unwrap(opOperand)->getOperandNumber(); +} + +MlirOpOperand mlirOpOperandGetNextUse(MlirOpOperand opOperand) { + if (mlirOpOperandIsNull(opOperand)) + return {}; + + OpOperand *nextOpOperand = static_cast( + unwrap(opOperand)->getNextOperandUsingThisValue()); + + if (!nextOpOperand) + return {}; + + return wrap(nextOpOperand); +} + //===----------------------------------------------------------------------===// // Type API. //===----------------------------------------------------------------------===// From 4e0e9dd131c3d1ed87e380c62e01860b9d93bc65 Mon Sep 17 00:00:00 2001 From: Mike Urbach Date: Wed, 7 Dec 2022 13:03:24 -0700 Subject: [PATCH 389/915] [mlir][Python] Add `__hash__` implementation for Block. This allows us to hash Blocks and use them in sets or parts of larger hashable objects. The implementation is the same as other core IR constructs: the C API object's pointer is hashed. 
Differential Revision: https://reviews.llvm.org/D139599 --- mlir/lib/Bindings/Python/IRCore.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 8c25f6e81..0a32ff598 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2892,6 +2892,10 @@ void mlir::python::populateIRCore(py::module &m) { return self.get().ptr == other.get().ptr; }) .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) + .def("__hash__", + [](PyBlock &self) { + return static_cast(llvm::hash_value(self.get().ptr)); + }) .def( "__str__", [](PyBlock &self) { From fef7aa3770e2212ef827ba6628a93984e4b810e4 Mon Sep 17 00:00:00 2001 From: Mike Urbach Date: Tue, 6 Dec 2022 19:30:56 -0700 Subject: [PATCH 390/915] [mlir][Python] Add a simple PyOpOperand iterator for PyValue uses. This adds a simple PyOpOperand based on MlirOpOperand, which has properties for the owner op and operand number. This also adds a PyOpOperandIterator that defines methods for __iter__ and __next__ so PyOpOperands can be iterated over using the MlirOpOperand C API. Finally, a uses pseudo-container is added to PyValue so the uses can generically be iterated.
Depends on D139596 Reviewed By: stellaraccident, jdd Differential Revision: https://reviews.llvm.org/D139597 --- mlir/lib/Bindings/Python/IRCore.cpp | 56 +++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 0a32ff598..b46fe44e9 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -447,6 +447,55 @@ class PyOperationList { MlirBlock block; }; +class PyOpOperand { +public: + PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {} + + py::object getOwner() { + MlirOperation owner = mlirOpOperandGetOwner(opOperand); + PyMlirContextRef context = + PyMlirContext::forContext(mlirOperationGetContext(owner)); + return PyOperation::forOperation(context, owner)->createOpView(); + } + + size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); } + + static void bind(py::module &m) { + py::class_(m, "OpOperand", py::module_local()) + .def_property_readonly("owner", &PyOpOperand::getOwner) + .def_property_readonly("operand_number", + &PyOpOperand::getOperandNumber); + } + +private: + MlirOpOperand opOperand; +}; + +class PyOpOperandIterator { +public: + PyOpOperandIterator(MlirOpOperand opOperand) : opOperand(opOperand) {} + + PyOpOperandIterator &dunderIter() { return *this; } + + PyOpOperand dunderNext() { + if (mlirOpOperandIsNull(opOperand)) + throw py::stop_iteration(); + + PyOpOperand returnOpOperand(opOperand); + opOperand = mlirOpOperandGetNextUse(opOperand); + return returnOpOperand; + } + + static void bind(py::module &m) { + py::class_(m, "OpOperandIterator", py::module_local()) + .def("__iter__", &PyOpOperandIterator::dunderIter) + .def("__next__", &PyOpOperandIterator::dunderNext); + } + +private: + MlirOpOperand opOperand; +}; + } // namespace //------------------------------------------------------------------------------ @@ -3156,6 +3205,11 @@ void mlir::python::populateIRCore(py::module &m) { 
assert(false && "Value must be a block argument or an op result"); return py::none(); }) + .def_property_readonly("uses", + [](PyValue &self) { + return PyOpOperandIterator( + mlirValueGetFirstUse(self.get())); + }) .def("__eq__", [](PyValue &self, PyValue &other) { return self.get().ptr == other.get().ptr; @@ -3182,6 +3236,7 @@ void mlir::python::populateIRCore(py::module &m) { }); PyBlockArgument::bind(m); PyOpResult::bind(m); + PyOpOperand::bind(m); //---------------------------------------------------------------------------- // Mapping of SymbolTable. @@ -3220,6 +3275,7 @@ void mlir::python::populateIRCore(py::module &m) { PyOperationIterator::bind(m); PyOperationList::bind(m); PyOpAttributeMap::bind(m); + PyOpOperandIterator::bind(m); PyOpOperandList::bind(m); PyOpResultList::bind(m); PyRegionIterator::bind(m); From 3931cae1702c82ef267dd572e8b6819fe96a4833 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Fri, 16 Dec 2022 13:36:15 -0800 Subject: [PATCH 391/915] [mlir-c] Add method to create unmanaged dense resource elements attr Following DenseElementsAttr pattern. Differential Revision: https://reviews.llvm.org/D140189 --- mlir/include/mlir-c/BuiltinAttributes.h | 71 ++++++++++++ mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 143 ++++++++++++++++++++++++ 2 files changed, 214 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 8887897bc..3b1ac466d 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -510,6 +510,77 @@ mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED const void * mlirDenseElementsAttrGetRawData(MlirAttribute attr); +//===----------------------------------------------------------------------===// +// Resource blob attributes. 
+//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const int *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseUInt8ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const uint8_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt8ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const int8_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute +mlirUnmanagedDenseUInt16ResourceElementsAttrGet(MlirType shapedType, + MlirStringRef name, + intptr_t numElements, + const uint16_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt16ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const int16_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute +mlirUnmanagedDenseUInt32ResourceElementsAttrGet(MlirType shapedType, + MlirStringRef name, + intptr_t numElements, + const uint32_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt32ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const int32_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute +mlirUnmanagedDenseUInt64ResourceElementsAttrGet(MlirType shapedType, + MlirStringRef name, + intptr_t numElements, + const uint64_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt64ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const int64_t *elements); +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseFloatResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const float *elements); +MLIR_CAPI_EXPORTED MlirAttribute +mlirUnmanagedDenseDoubleResourceElementsAttrGet(MlirType shapedType, + MlirStringRef name, + intptr_t numElements, + const 
double *elements); + +/// Returns the pos-th value (flat contiguous indexing) of a specific type +/// contained by the given dense resource elements attribute. +MLIR_CAPI_EXPORTED bool +mlirDenseBoolResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED int8_t +mlirDenseInt8ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED uint8_t +mlirDenseUInt8ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED int16_t +mlirDenseInt16ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED uint16_t +mlirDenseUInt16ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED int32_t +mlirDenseInt32ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED uint32_t +mlirDenseUInt32ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED int64_t +mlirDenseInt64ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED uint64_t +mlirDenseUInt64ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED float +mlirDenseFloatResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED double +mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos); + //===----------------------------------------------------------------------===// // Sparse elements attribute. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 05ecb0fe8..e392d053b 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -7,10 +7,13 @@ //===----------------------------------------------------------------------===// #include "mlir-c/BuiltinAttributes.h" +#include "mlir-c/Support.h" #include "mlir/CAPI/AffineMap.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" +#include "mlir/IR/AsmState.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" using namespace mlir; @@ -698,6 +701,146 @@ const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) { unwrap(attr).cast().getRawData().data()); } +//===----------------------------------------------------------------------===// +// Resource blob attributes. +//===----------------------------------------------------------------------===// + +template +static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name, + intptr_t numElements, const T *elements) { + return wrap(U::get(unwrap(shapedType).cast(), unwrap(name), + UnmanagedAsmResourceBlob::allocateInferAlign( + llvm::makeArrayRef(elements, numElements)))); +} + +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const int *elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseUInt8ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const uint8_t *elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} +MLIR_CAPI_EXPORTED MlirAttribute +mlirUnmanagedDenseUInt16ResourceElementsAttrGet(MlirType shapedType, + MlirStringRef name, + intptr_t numElements, + const uint16_t 
*elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} +MLIR_CAPI_EXPORTED MlirAttribute +mlirUnmanagedDenseUInt32ResourceElementsAttrGet(MlirType shapedType, + MlirStringRef name, + intptr_t numElements, + const uint32_t *elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} +MLIR_CAPI_EXPORTED MlirAttribute +mlirUnmanagedDenseUInt64ResourceElementsAttrGet(MlirType shapedType, + MlirStringRef name, + intptr_t numElements, + const uint64_t *elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt8ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const int8_t *elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt16ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const int16_t *elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt32ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const int32_t *elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt64ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const int64_t *elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseFloatResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const float *elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} +MLIR_CAPI_EXPORTED MlirAttribute +mlirUnmanagedDenseDoubleResourceElementsAttrGet(MlirType shapedType, + MlirStringRef name, + intptr_t numElements, + const double 
*elements) { + return getDenseResource(shapedType, name, + numElements, elements); +} + +template +static T getDenseResourceVal(MlirAttribute attr, intptr_t pos) { + return (*unwrap(attr).cast().tryGetAsArrayRef())[pos]; +} + +MLIR_CAPI_EXPORTED bool +mlirDenseBoolResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { + return getDenseResourceVal(attr, pos); +} +MLIR_CAPI_EXPORTED uint8_t +mlirDenseUInt8ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { + return getDenseResourceVal(attr, pos); +} +MLIR_CAPI_EXPORTED uint16_t +mlirDenseUInt16ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { + return getDenseResourceVal(attr, + pos); +} +MLIR_CAPI_EXPORTED uint32_t +mlirDenseUInt32ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { + return getDenseResourceVal(attr, + pos); +} +MLIR_CAPI_EXPORTED uint64_t +mlirDenseUInt64ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { + return getDenseResourceVal(attr, + pos); +} +MLIR_CAPI_EXPORTED int8_t +mlirDenseInt8ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { + return getDenseResourceVal(attr, pos); +} +MLIR_CAPI_EXPORTED int16_t +mlirDenseInt16ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { + return getDenseResourceVal(attr, pos); +} +MLIR_CAPI_EXPORTED int32_t +mlirDenseInt32ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { + return getDenseResourceVal(attr, pos); +} +MLIR_CAPI_EXPORTED int64_t +mlirDenseInt64ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { + return getDenseResourceVal(attr, pos); +} +MLIR_CAPI_EXPORTED float +mlirDenseFloatResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { + return getDenseResourceVal(attr, pos); +} +MLIR_CAPI_EXPORTED double +mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { + return getDenseResourceVal(attr, pos); +} + //===----------------------------------------------------------------------===// // Sparse 
elements attribute. //===----------------------------------------------------------------------===// From 800976003846bb583a8409edb81f0c00ffd2d670 Mon Sep 17 00:00:00 2001 From: Murali Vijayaraghavan Date: Fri, 16 Dec 2022 04:49:45 +0000 Subject: [PATCH 392/915] [mlir][linalg] Creating named 1D pooling ops This is mostly going to be used for linalg transformations - to make pooling ops similar to convolution ops. Differential Revision: https://reviews.llvm.org/D140186 --- .../linalg/opdsl/ops/core_named_ops.py | 141 +++++++++++++++++- 1 file changed, 140 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 1aa112dcf..8bab1607b 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -694,7 +694,6 @@ def depthwise_conv_3d_ndhwc_dhwcm(I=TensorDef(T1, D.ow * S.SW + D.kw * S.DW, D.ic]) * TypeFn.cast_signed( U, K[D.kd, D.kh, D.kw, D.ic, D.cm]) - @linalg_structured_op def pooling_nhwc_sum(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), @@ -838,6 +837,146 @@ def pooling_nhwc_min_unsigned(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, D.c] = ReduceFn.min_unsigned[D.kh, D.kw](TypeFn.cast_unsigned( U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) +@linalg_structured_op +def pooling_nwc_sum(I=TensorDef(T1, S.N, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1])): + """Performs sum pooling. + + Layout: + * Input: NWC. + * Kernel: W. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] += TypeFn.cast_signed( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) + + +@linalg_structured_op +def pooling_ncw_sum(I=TensorDef(T1, S.N, S.C, + S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.C, S.OW, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1])): + """Performs sum pooling. + + Layout: + * Input: NCW. + * Kernel: W. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.c, D.ow, D.kw) + O[D.n, D.c, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]) + + +@linalg_structured_op +def pooling_nwc_max(I=TensorDef(T1, S.N, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1])): + """Performs max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] = ReduceFn.max_signed[[D.kw]](TypeFn.cast_signed( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])) + + +@linalg_structured_op +def pooling_nwc_max_unsigned(I=TensorDef(T1, S.N, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, + S.KW, + index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1])): + """Performs unsigned max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, + D.c] = ReduceFn.max_unsigned[[D.kw]](TypeFn.cast_unsigned( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])) + + +@linalg_structured_op +def pooling_ncw_max(I=TensorDef(T1, S.N, S.C, + S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.C, S.OW, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1])): + """Performs max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.c, D.ow, D.kw) + O[D.n, D.c, D.ow] = ReduceFn.max_signed[[D.kw]](TypeFn.cast_signed( + U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW,])) + + +@linalg_structured_op +def pooling_nwc_min(I=TensorDef(T1, S.N, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1])): + """Performs min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] = ReduceFn.min_signed[[D.kw]](TypeFn.cast_signed( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])) + + +@linalg_structured_op +def pooling_nwc_min_unsigned(I=TensorDef(T1, S.N, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, + S.KW, + index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1])): + """Performs unsigned min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, + D.c] = ReduceFn.min_unsigned[[D.kw]](TypeFn.cast_unsigned( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])) + + @linalg_structured_op def pooling_ndhwc_sum(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, From 64b2b21d8ce546debea8478d80a9eea2f1e1025a Mon Sep 17 00:00:00 2001 From: Ramkumar Ramachandra Date: Wed, 14 Dec 2022 11:39:19 +0100 Subject: [PATCH 393/915] mlir/tblgen: use std::optional in generation This is part of an effort to migrate from llvm::Optional to std::optional. This patch changes the way mlir-tblgen generates .inc files, and modifies tests and documentation appropriately. It is a "no compromises" patch, and doesn't leave the user with an unpleasant mix of llvm::Optional and std::optional. A non-trivial change has been made to ControlFlowInterfaces to split one constructor into two, relating to a build failure on Windows. See also: https://discourse.llvm.org/t/deprecating-llvm-optional-x-hasvalue-getvalue-getvalueor/63716 Signed-off-by: Ramkumar Ramachandra Differential Revision: https://reviews.llvm.org/D138934 --- mlir/lib/CAPI/Interfaces/Interfaces.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp index 66d3bb187..ec2cf0175 100644 --- a/mlir/lib/CAPI/Interfaces/Interfaces.cpp +++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp @@ -46,7 +46,7 @@ MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( if (!info) return mlirLogicalResultFailure(); - llvm::Optional maybeLocation; + std::optional maybeLocation; if (!mlirLocationIsNull(location)) maybeLocation = unwrap(location); SmallVector unwrappedOperands; From ec07b8fb81c213ea065f7dd4a8f44d9b70454ce6 Mon Sep 17 00:00:00 2001 From: Fangrui Song Date: Mon, 19 Dec 2022 04:28:55 +0000 Subject: [PATCH 394/915] [mlir][python] llvm::Optional::value => operator* And convert it to std::optional while updating. 
--- mlir/lib/Bindings/Python/IRCore.cpp | 2 +- mlir/lib/Bindings/Python/IRModule.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index b46fe44e9..794be9742 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -841,7 +841,7 @@ py::tuple PyDiagnostic::getNotes() { materializedNotes = py::tuple(numNotes); for (intptr_t i = 0; i < numNotes; ++i) { MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i); - materializedNotes.value()[i] = PyDiagnostic(noteDiag); + (*materializedNotes)[i] = PyDiagnostic(noteDiag); } return *materializedNotes; } diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 4738a6fae..2492ad5d1 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -302,7 +302,7 @@ class PyDiagnostic { /// If notes have been materialized from the diagnostic, then this will /// be populated with the corresponding objects (all castable to /// PyDiagnostic). - llvm::Optional materializedNotes; + std::optional materializedNotes; bool valid = true; }; From fb2b958ba131691c86e979c34e8e030689d78bcc Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Mon, 19 Dec 2022 23:38:55 +0100 Subject: [PATCH 395/915] [mlir] Add operator!= to WalkResult, for completeness. 
--- mlir/lib/Bindings/Python/IRAffine.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index fc7133b43..09f36b07c 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -15,6 +15,7 @@ #include "mlir-c/AffineMap.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/IntegerSet.h" +#include "llvm/ADT/Hashing.h" namespace py = pybind11; using namespace mlir; From df25ab81300df67e71d3f4f1ed6a34f27686ac5a Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Tue, 20 Dec 2022 00:17:25 +0100 Subject: [PATCH 396/915] [ADT] Alias llvm::Optional to std::optional This avoids the continuous API churn when upgrading things to use std::optional and makes trivial string replace upgrades possible. I tested this with GCC 7.5, the oldest supported GCC I had around. Differential Revision: https://reviews.llvm.org/D140332 --- mlir/include/mlir/Bindings/Python/PybindAdaptors.h | 3 --- mlir/lib/Bindings/Python/PybindUtils.h | 3 --- 2 files changed, 6 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 564425b9b..38c547078 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -35,9 +35,6 @@ namespace py = pybind11; namespace pybind11 { namespace detail { -template -struct type_caster> : optional_caster> {}; - /// Helper to convert a presumed MLIR API object to a capsule, accepting either /// an explicit Capsule (which can happen when two C APIs are communicating /// directly via Python) or indirectly by querying the MLIR_PYTHON_CAPI_PTR_ATTR diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h index 5356cbd54..d039a8acd 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -94,9 +94,6 @@ struct MlirDefaultingCaster { 
return pybind11::cast(src, policy); } }; - -template -struct type_caster> : optional_caster> {}; } // namespace detail } // namespace pybind11 From 3e2857ac74683fbd0264314ffa4297c4d614796b Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 21 Dec 2022 10:10:31 -0800 Subject: [PATCH 397/915] [mlir][py] Enable building ops with raw inputs For cases where we can automatically construct the Attribute allow for more user-friendly input. This is consistent with C++ builder generation as well choice of which single builder to generate here (most specialized/user-friendly). Registration of attribute builders from more pythonic input is all Python side. The downside is that * extra checking to see if user provided a custom builder in op builders, * the ODS attribute name is load bearing upside is that * easily change these/register dialect specific ones in downstream projects, * adding support/changing to different convenience builders are all along with the rest of the convenience functions in Python (and no additional changes to tablegen file or recompilation needed); Allow for both building with Attributes as well as raw inputs. This change should therefore be backwards compatible as well as allow for avoiding recreating Attribute where already available. Differential Revision: https://reviews.llvm.org/D139568 --- mlir/lib/Bindings/Python/Globals.h | 12 ++++++++ mlir/lib/Bindings/Python/IRCore.cpp | 26 +++++++++++++++++ mlir/lib/Bindings/Python/IRModule.cpp | 27 ++++++++++++++++++ mlir/python/mlir/ir.py | 41 +++++++++++++++++++++++++++ 4 files changed, 106 insertions(+) diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 6613d2b69..ba6cfb545 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -58,6 +58,12 @@ class PyGlobals { /// have a DIALECT_NAMESPACE attribute. pybind11::object registerDialectDecorator(pybind11::object pyClass); + /// Adds a user-friendly Attribute builder. 
+ /// Raises an exception if the mapping already exists. + /// This is intended to be called by implementation code. + void registerAttributeBuilder(const std::string &attributeKind, + pybind11::function pyFunc); + /// Adds a concrete implementation dialect class. /// Raises an exception if the mapping already exists. /// This is intended to be called by implementation code. @@ -71,6 +77,10 @@ class PyGlobals { pybind11::object pyClass, pybind11::object rawOpViewClass); + /// Returns the custom Attribute builder for Attribute kind. + std::optional + lookupAttributeBuilder(const std::string &attributeKind); + /// Looks up a registered dialect class by namespace. Note that this may /// trigger loading of the defining module and can arbitrarily re-enter. llvm::Optional @@ -92,6 +102,8 @@ class PyGlobals { /// Map of operation name to custom subclass that directly initializes /// the OpView base class (bypassing the user class constructor). llvm::StringMap rawOpViewClassMap; + /// Map of attribute ODS name to custom builder. + llvm::StringMap attributeBuilderMap; /// Set of dialect namespaces that we have attempted to import implementation /// modules for. 
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 794be9742..f2aa8da5b 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -194,6 +194,29 @@ struct PyGlobalDebugFlag { } }; +struct PyAttrBuilderMap { + static bool dunderContains(const std::string &attributeKind) { + return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value(); + } + static py::function dundeGetItemNamed(const std::string &attributeKind) { + auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind); + if (!builder) + throw py::key_error(); + return *builder; + } + static void dundeSetItemNamed(const std::string &attributeKind, + py::function func) { + PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func)); + } + + static void bind(py::module &m) { + py::class_(m, "AttrBuilder", py::module_local()) + .def_static("contains", &PyAttrBuilderMap::dunderContains) + .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed) + .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed); + } +}; + //------------------------------------------------------------------------------ // Collections. //------------------------------------------------------------------------------ @@ -3283,4 +3306,7 @@ void mlir::python::populateIRCore(py::module &m) { // Debug bindings. PyGlobalDebugFlag::bind(m); + + // Attribute builder getter. 
+ PyAttrBuilderMap::bind(m); } diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index b6d1df51f..be6de5fd2 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -60,6 +60,17 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { loadedDialectModulesCache.insert(dialectNamespace); } +void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, + py::function pyFunc) { + py::function &found = attributeBuilderMap[attributeKind]; + if (found) { + throw std::runtime_error((llvm::Twine("Attribute builder for '") + + attributeKind + "' is already registered") + .str()); + } + found = std::move(pyFunc); +} + void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, py::object pyClass) { py::object &found = dialectClassMap[dialectNamespace]; @@ -84,6 +95,22 @@ void PyGlobals::registerOperationImpl(const std::string &operationName, rawOpViewClassMap[operationName] = std::move(rawOpViewClass); } +std::optional +PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { + // Fast match against the class map first (common case). + const auto foundIt = attributeBuilderMap.find(attributeKind); + if (foundIt != attributeBuilderMap.end()) { + if (foundIt->second.is_none()) + return std::nullopt; + assert(foundIt->second && "py::function is defined"); + return foundIt->second; + } + + // Not found and loading did not yield a registration. Negative cache. 
+ attributeBuilderMap[attributeKind] = py::none(); + return std::nullopt; +} + llvm::Optional PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { loadDialectModule(dialectNamespace); diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 99e88ff74..19986917d 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -4,3 +4,44 @@ from ._mlir_libs._mlir.ir import * from ._mlir_libs._mlir.ir import _GlobalDebug + + +# Convenience decorator for registering user-friendly Attribute builders. +def register_attribute_builder(kind): + def decorator_builder(func): + AttrBuilder.insert(kind, func) + return func + return decorator_builder + + +@register_attribute_builder("BoolAttr") +def _boolAttr(x: bool, context: Context): + return BoolAttr.get(x, context=context) + +@register_attribute_builder("IndexAttr") +def _indexAttr(x: int, context: Context): + return IntegerAttr.get(IndexType.get(context=context), x) + +@register_attribute_builder("I32Attr") +def _i32Attr(x: int, context: Context): + return IntegerAttr.get( + IntegerType.get_signless(32, context=context), x) + +@register_attribute_builder("I64Attr") +def _i64Attr(x: int, context: Context): + return IntegerAttr.get( + IntegerType.get_signless(64, context=context), x) + +@register_attribute_builder("SymbolNameAttr") +def _symbolNameAttr(x: str, context: Context): + return StringAttr.get(x, context=context) + +try: + import numpy as np + @register_attribute_builder("IndexElementsAttr") + def _indexElementsAttr(x: list[int], context: Context): + return DenseElementsAttr.get( + np.array(x, dtype=np.int64), type=IndexType.get(context=context), + context=context) +except ImportError: + pass From 2c72bdde56329eb66b5b187bd422d54d34555b59 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 21 Dec 2022 14:53:12 -0800 Subject: [PATCH 398/915] Revert "[mlir][py] Enable building ops with raw inputs" Reverting to fix build bot. 
This reverts commit 3e2857ac74683fbd0264314ffa4297c4d614796b. --- mlir/lib/Bindings/Python/Globals.h | 12 -------- mlir/lib/Bindings/Python/IRCore.cpp | 26 ----------------- mlir/lib/Bindings/Python/IRModule.cpp | 27 ------------------ mlir/python/mlir/ir.py | 41 --------------------------- 4 files changed, 106 deletions(-) diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index ba6cfb545..6613d2b69 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -58,12 +58,6 @@ class PyGlobals { /// have a DIALECT_NAMESPACE attribute. pybind11::object registerDialectDecorator(pybind11::object pyClass); - /// Adds a user-friendly Attribute builder. - /// Raises an exception if the mapping already exists. - /// This is intended to be called by implementation code. - void registerAttributeBuilder(const std::string &attributeKind, - pybind11::function pyFunc); - /// Adds a concrete implementation dialect class. /// Raises an exception if the mapping already exists. /// This is intended to be called by implementation code. @@ -77,10 +71,6 @@ class PyGlobals { pybind11::object pyClass, pybind11::object rawOpViewClass); - /// Returns the custom Attribute builder for Attribute kind. - std::optional - lookupAttributeBuilder(const std::string &attributeKind); - /// Looks up a registered dialect class by namespace. Note that this may /// trigger loading of the defining module and can arbitrarily re-enter. llvm::Optional @@ -102,8 +92,6 @@ class PyGlobals { /// Map of operation name to custom subclass that directly initializes /// the OpView base class (bypassing the user class constructor). llvm::StringMap rawOpViewClassMap; - /// Map of attribute ODS name to custom builder. - llvm::StringMap attributeBuilderMap; /// Set of dialect namespaces that we have attempted to import implementation /// modules for. 
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index f2aa8da5b..794be9742 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -194,29 +194,6 @@ struct PyGlobalDebugFlag { } }; -struct PyAttrBuilderMap { - static bool dunderContains(const std::string &attributeKind) { - return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value(); - } - static py::function dundeGetItemNamed(const std::string &attributeKind) { - auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind); - if (!builder) - throw py::key_error(); - return *builder; - } - static void dundeSetItemNamed(const std::string &attributeKind, - py::function func) { - PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func)); - } - - static void bind(py::module &m) { - py::class_(m, "AttrBuilder", py::module_local()) - .def_static("contains", &PyAttrBuilderMap::dunderContains) - .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed) - .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed); - } -}; - //------------------------------------------------------------------------------ // Collections. //------------------------------------------------------------------------------ @@ -3306,7 +3283,4 @@ void mlir::python::populateIRCore(py::module &m) { // Debug bindings. PyGlobalDebugFlag::bind(m); - - // Attribute builder getter. 
- PyAttrBuilderMap::bind(m); } diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index be6de5fd2..b6d1df51f 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -60,17 +60,6 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { loadedDialectModulesCache.insert(dialectNamespace); } -void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, - py::function pyFunc) { - py::function &found = attributeBuilderMap[attributeKind]; - if (found) { - throw std::runtime_error((llvm::Twine("Attribute builder for '") + - attributeKind + "' is already registered") - .str()); - } - found = std::move(pyFunc); -} - void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, py::object pyClass) { py::object &found = dialectClassMap[dialectNamespace]; @@ -95,22 +84,6 @@ void PyGlobals::registerOperationImpl(const std::string &operationName, rawOpViewClassMap[operationName] = std::move(rawOpViewClass); } -std::optional -PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { - // Fast match against the class map first (common case). - const auto foundIt = attributeBuilderMap.find(attributeKind); - if (foundIt != attributeBuilderMap.end()) { - if (foundIt->second.is_none()) - return std::nullopt; - assert(foundIt->second && "py::function is defined"); - return foundIt->second; - } - - // Not found and loading did not yield a registration. Negative cache. 
- attributeBuilderMap[attributeKind] = py::none(); - return std::nullopt; -} - llvm::Optional PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { loadDialectModule(dialectNamespace); diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 19986917d..99e88ff74 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -4,44 +4,3 @@ from ._mlir_libs._mlir.ir import * from ._mlir_libs._mlir.ir import _GlobalDebug - - -# Convenience decorator for registering user-friendly Attribute builders. -def register_attribute_builder(kind): - def decorator_builder(func): - AttrBuilder.insert(kind, func) - return func - return decorator_builder - - -@register_attribute_builder("BoolAttr") -def _boolAttr(x: bool, context: Context): - return BoolAttr.get(x, context=context) - -@register_attribute_builder("IndexAttr") -def _indexAttr(x: int, context: Context): - return IntegerAttr.get(IndexType.get(context=context), x) - -@register_attribute_builder("I32Attr") -def _i32Attr(x: int, context: Context): - return IntegerAttr.get( - IntegerType.get_signless(32, context=context), x) - -@register_attribute_builder("I64Attr") -def _i64Attr(x: int, context: Context): - return IntegerAttr.get( - IntegerType.get_signless(64, context=context), x) - -@register_attribute_builder("SymbolNameAttr") -def _symbolNameAttr(x: str, context: Context): - return StringAttr.get(x, context=context) - -try: - import numpy as np - @register_attribute_builder("IndexElementsAttr") - def _indexElementsAttr(x: list[int], context: Context): - return DenseElementsAttr.get( - np.array(x, dtype=np.int64), type=IndexType.get(context=context), - context=context) -except ImportError: - pass From 5622d407f042eac6ecf8f8589f38c18eb81acbf2 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 21 Dec 2022 16:22:39 -0800 Subject: [PATCH 399/915] Revert "Revert "[mlir][py] Enable building ops with raw inputs"" Fix Python 3.6.9 issue encountered due to type checking here. 
Will add back in follow up. This reverts commit 2c72bdde56329eb66b5b187bd422d54d34555b59. --- mlir/lib/Bindings/Python/Globals.h | 12 ++++++++ mlir/lib/Bindings/Python/IRCore.cpp | 26 +++++++++++++++++ mlir/lib/Bindings/Python/IRModule.cpp | 27 ++++++++++++++++++ mlir/python/mlir/ir.py | 41 +++++++++++++++++++++++++++ 4 files changed, 106 insertions(+) diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 6613d2b69..ba6cfb545 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -58,6 +58,12 @@ class PyGlobals { /// have a DIALECT_NAMESPACE attribute. pybind11::object registerDialectDecorator(pybind11::object pyClass); + /// Adds a user-friendly Attribute builder. + /// Raises an exception if the mapping already exists. + /// This is intended to be called by implementation code. + void registerAttributeBuilder(const std::string &attributeKind, + pybind11::function pyFunc); + /// Adds a concrete implementation dialect class. /// Raises an exception if the mapping already exists. /// This is intended to be called by implementation code. @@ -71,6 +77,10 @@ class PyGlobals { pybind11::object pyClass, pybind11::object rawOpViewClass); + /// Returns the custom Attribute builder for Attribute kind. + std::optional + lookupAttributeBuilder(const std::string &attributeKind); + /// Looks up a registered dialect class by namespace. Note that this may /// trigger loading of the defining module and can arbitrarily re-enter. llvm::Optional @@ -92,6 +102,8 @@ class PyGlobals { /// Map of operation name to custom subclass that directly initializes /// the OpView base class (bypassing the user class constructor). llvm::StringMap rawOpViewClassMap; + /// Map of attribute ODS name to custom builder. + llvm::StringMap attributeBuilderMap; /// Set of dialect namespaces that we have attempted to import implementation /// modules for. 
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 794be9742..f2aa8da5b 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -194,6 +194,29 @@ struct PyGlobalDebugFlag { } }; +struct PyAttrBuilderMap { + static bool dunderContains(const std::string &attributeKind) { + return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value(); + } + static py::function dundeGetItemNamed(const std::string &attributeKind) { + auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind); + if (!builder) + throw py::key_error(); + return *builder; + } + static void dundeSetItemNamed(const std::string &attributeKind, + py::function func) { + PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func)); + } + + static void bind(py::module &m) { + py::class_(m, "AttrBuilder", py::module_local()) + .def_static("contains", &PyAttrBuilderMap::dunderContains) + .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed) + .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed); + } +}; + //------------------------------------------------------------------------------ // Collections. //------------------------------------------------------------------------------ @@ -3283,4 +3306,7 @@ void mlir::python::populateIRCore(py::module &m) { // Debug bindings. PyGlobalDebugFlag::bind(m); + + // Attribute builder getter. 
+ PyAttrBuilderMap::bind(m); } diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index b6d1df51f..be6de5fd2 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -60,6 +60,17 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { loadedDialectModulesCache.insert(dialectNamespace); } +void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, + py::function pyFunc) { + py::function &found = attributeBuilderMap[attributeKind]; + if (found) { + throw std::runtime_error((llvm::Twine("Attribute builder for '") + + attributeKind + "' is already registered") + .str()); + } + found = std::move(pyFunc); +} + void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, py::object pyClass) { py::object &found = dialectClassMap[dialectNamespace]; @@ -84,6 +95,22 @@ void PyGlobals::registerOperationImpl(const std::string &operationName, rawOpViewClassMap[operationName] = std::move(rawOpViewClass); } +std::optional +PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { + // Fast match against the class map first (common case). + const auto foundIt = attributeBuilderMap.find(attributeKind); + if (foundIt != attributeBuilderMap.end()) { + if (foundIt->second.is_none()) + return std::nullopt; + assert(foundIt->second && "py::function is defined"); + return foundIt->second; + } + + // Not found and loading did not yield a registration. Negative cache. 
+ attributeBuilderMap[attributeKind] = py::none(); + return std::nullopt; +} + llvm::Optional PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { loadDialectModule(dialectNamespace); diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 99e88ff74..82468e8b7 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -4,3 +4,44 @@ from ._mlir_libs._mlir.ir import * from ._mlir_libs._mlir.ir import _GlobalDebug + + +# Convenience decorator for registering user-friendly Attribute builders. +def register_attribute_builder(kind): + def decorator_builder(func): + AttrBuilder.insert(kind, func) + return func + return decorator_builder + + +@register_attribute_builder("BoolAttr") +def _boolAttr(x, context): + return BoolAttr.get(x, context=context) + +@register_attribute_builder("IndexAttr") +def _indexAttr(x, context): + return IntegerAttr.get(IndexType.get(context=context), x) + +@register_attribute_builder("I32Attr") +def _i32Attr(x, context): + return IntegerAttr.get( + IntegerType.get_signless(32, context=context), x) + +@register_attribute_builder("I64Attr") +def _i64Attr(x, context): + return IntegerAttr.get( + IntegerType.get_signless(64, context=context), x) + +@register_attribute_builder("SymbolNameAttr") +def _symbolNameAttr(x, context): + return StringAttr.get(x, context=context) + +try: + import numpy as np + @register_attribute_builder("IndexElementsAttr") + def _indexElementsAttr(x, context): + return DenseElementsAttr.get( + np.array(x, dtype=np.int64), type=IndexType.get(context=context), + context=context) +except ImportError: + pass From 529fb1722ef1c5f717c7649fb900ccb81232db35 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 27 Dec 2022 21:56:58 -0800 Subject: [PATCH 400/915] [mlir][py] Fix negative cached value in attribute builder Previously this was incorrectly assigning py::none to where function was expected which resulted in failure if one used a non-attribute for attribute without registered builder. 
--- mlir/lib/Bindings/Python/Globals.h | 2 +- mlir/lib/Bindings/Python/IRModule.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index ba6cfb545..f0bf3f556 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -103,7 +103,7 @@ class PyGlobals { /// the OpView base class (bypassing the user class constructor). llvm::StringMap rawOpViewClassMap; /// Map of attribute ODS name to custom builder. - llvm::StringMap attributeBuilderMap; + llvm::StringMap attributeBuilderMap; /// Set of dialect namespaces that we have attempted to import implementation /// modules for. diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index be6de5fd2..1cdd7e441 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -62,7 +62,7 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, py::function pyFunc) { - py::function &found = attributeBuilderMap[attributeKind]; + py::object &found = attributeBuilderMap[attributeKind]; if (found) { throw std::runtime_error((llvm::Twine("Attribute builder for '") + attributeKind + "' is already registered") From 724b7c050ad8991da35090bab183a10f98edb701 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 28 Dec 2022 16:02:08 -0800 Subject: [PATCH 401/915] [mlir][py] Add StrAttr convenience builder. 
--- mlir/python/mlir/ir.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 82468e8b7..1e24fcbf9 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -32,6 +32,10 @@ def _i64Attr(x, context): return IntegerAttr.get( IntegerType.get_signless(64, context=context), x) +@register_attribute_builder("StrAttr") +def _stringAttr(x, context): + return StringAttr.get(x, context=context) + @register_attribute_builder("SymbolNameAttr") def _symbolNameAttr(x, context): return StringAttr.get(x, context=context) From 863c730f692f3ec8cf74d32e4a3bbb7f69642db4 Mon Sep 17 00:00:00 2001 From: Qiao Zhang Date: Tue, 3 Jan 2023 19:06:30 +0000 Subject: [PATCH 402/915] [mlir][python] Expose fp8 types with pybind. Expose fp8 types with pybind. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D140746 --- mlir/lib/Bindings/Python/IRTypes.cpp | 38 ++++++++++++++++++++++++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 16 ++++++++++ 2 files changed, 54 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 7a41cb1e8..10527af6c 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -102,6 +102,42 @@ class PyIndexType : public PyConcreteType { } }; +/// Floating Point Type subclass - Float8E4M3FNType. +class PyFloat8E4M3FNType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; + static constexpr const char *pyClassName = "Float8E4M3FNType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3FNTypeGet(context->get()); + return PyFloat8E4M3FNType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e4m3fn type."); + } +}; + +/// Floating Point Type subclass - Float8M5E2Type. 
+class PyFloat8E5M2Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; + static constexpr const char *pyClassName = "Float8E5M2Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E5M2TypeGet(context->get()); + return PyFloat8E5M2Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e5m2 type."); + } +}; + /// Floating Point Type subclass - BF16Type. class PyBF16Type : public PyConcreteType { public: @@ -663,6 +699,8 @@ class PyOpaqueType : public PyConcreteType { void mlir::python::populateIRTypes(py::module &m) { PyIntegerType::bind(m); PyIndexType::bind(m); + PyFloat8E4M3FNType::bind(m); + PyFloat8E5M2Type::bind(m); PyBF16Type::bind(m); PyF16Type::bind(m); PyF32Type::bind(m); diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 60bc3676f..505946ca1 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -50,6 +50,8 @@ __all__ = [ "DiagnosticHandler", "DiagnosticSeverity", "DictAttr", + "Float8E4M3FNType", + "Float8E5M2Type", "F16Type", "F32Type", "F64Type", @@ -577,6 +579,20 @@ class DictAttr(Attribute): @property def type(self) -> Type: ... +class Float8E4M3FNType(Type): + def __init__(self, cast_from_type: Type) -> None: ... + @staticmethod + def get(*args, **kwargs) -> Float8E4M3FNType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... + +class Float8E5M2Type(Type): + def __init__(self, cast_from_type: Type) -> None: ... + @staticmethod + def get(*args, **kwargs) -> Float8E5M2Type: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... + # TODO: Auto-generated. Audit and fix. class F16Type(Type): def __init__(self, cast_from_type: Type) -> None: ... 
From 442ab0cc694ef79acd60b38ef9da4480b763f890 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 6 Jan 2023 10:13:32 -0800 Subject: [PATCH 403/915] [mlir][tosa] Add tosa.conv3d lowering to Linalg Conv3D has an existing linalg operation for floating point. Adding a quantized variant and corresponding lowering from TOSA. Numerical correctness was validated using the TOSA conformance tests. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D140919 --- .../linalg/opdsl/ops/core_named_ops.py | 85 +++++++++++++++---- 1 file changed, 70 insertions(+), 15 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 8bab1607b..4402624c1 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -150,6 +150,7 @@ def quantized_batch_matmul(A=TensorDef(T1, Batch, S.M, S.K), TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed( U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp)) + @linalg_structured_op def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K), B=TensorDef(T2, Batch, S.K, S.N), @@ -162,8 +163,9 @@ def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K), """ domain(D.b, D.m, D.n, D.k) implements(ContractionOpInterface) - C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k] * TypeFn.cast_signed( - U, B[D.b, D.k, D.n])) + C[D.m, D.n] += TypeFn.cast_signed( + U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n])) + @linalg_structured_op def matvec(A=TensorDef(T1, S.M, S.N), @@ -283,6 +285,7 @@ def conv_1d_nwc_wcf(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed( U, K[D.kw, D.c, D.f]) + @linalg_structured_op def conv_1d_ncw_fcw(I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW), K=TensorDef(T2, S.F, S.C, S.KW), @@ -304,6 +307,7 @@ def conv_1d_ncw_fcw(I=TensorDef(T1, S.N, S.C, S.OW * S.SW + 
S.KW * S.DW), U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed( U, K[D.f, D.c, D.kw]) + @linalg_structured_op def conv_2d_nhwc_hwcf(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), @@ -400,13 +404,15 @@ def conv_2d_nchw_fchw(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) + @linalg_structured_op -def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW), - O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): +def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW), + O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): """Performs 2-D grouped convolution. 
Layout: @@ -420,7 +426,8 @@ def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw) O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed( U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + - D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw]) + D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw]) + @linalg_structured_op def conv_3d_ndhwc_dhwcf(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, @@ -449,6 +456,43 @@ def conv_3d_ndhwc_dhwcf(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, U, K[D.kd, D.kh, D.kw, D.c, D.f]) +@linalg_structured_op +def conv_3d_ndhwc_dhwcf_q(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, + S.N, + S.OD, + S.OH, + S.OW, + S.F, + output=True), + strides=IndexAttrDef(S.SD, + S.SH, + S.SW, + default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, + S.DH, + S.DW, + default=[1, 1, 1])): + """Performs 3-D convolution with zero point offsets. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. This includes the zero + point offsets common to quantized operations. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) + O[D.n, D.od, D.oh, D.ow, D.f] += (TypeFn.cast_signed( + U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, D.c]) - TypeFn.cast_signed(U, IZp)) * ( + TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f]) - + TypeFn.cast_signed(U, KZp)) + + @linalg_structured_op def depthwise_conv_1d_nwc_wc(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC), @@ -517,7 +561,8 @@ def depthwise_conv_2d_nhwc_hwc(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, @linalg_structured_op -def depthwise_conv_2d_nchw_chw(I=TensorDef(T1, S.N, S.IC, S.OH * S.SH + S.KH * S.DH, +def depthwise_conv_2d_nchw_chw(I=TensorDef(T1, S.N, S.IC, + S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), K=TensorDef(T2, S.IC, S.KH, S.KW), O=TensorDef(U, @@ -539,7 +584,8 @@ def depthwise_conv_2d_nchw_chw(I=TensorDef(T1, S.N, S.IC, S.OH * S.SH + S.KH * S implements(ConvolutionOpInterface) domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) O[D.n, D.ic, D.oh, D.ow] += TypeFn.cast_signed( - U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw]) + U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + + D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw]) @linalg_structured_op @@ -642,7 +688,11 @@ def depthwise_conv_3d_ndhwc_dhwc(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC), - O=TensorDef(U, S.N, S.OD, S.OH, S.OW, + O=TensorDef(U, + S.N, + S.OD, + S.OH, + S.OW, output=True), strides=IndexAttrDef(S.SD, S.SH, @@ -667,12 +717,17 @@ def depthwise_conv_3d_ndhwc_dhwc(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, @linalg_structured_op -def depthwise_conv_3d_ndhwc_dhwcm(I=TensorDef(T1, - S.N, S.OD * S.SD + S.KD * S.DD, +def depthwise_conv_3d_ndhwc_dhwcm(I=TensorDef(T1, S.N, + S.OD * S.SD + S.KD * S.DD, S.OH * S.SH + 
S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC, S.CM), - O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.CM, + O=TensorDef(U, + S.N, + S.OD, + S.OH, + S.OW, + S.CM, output=True), strides=IndexAttrDef(S.SD, S.SH, From a3a386bb2afa0b52fb1211c8220c5fc13238ab45 Mon Sep 17 00:00:00 2001 From: Ashay Rane Date: Fri, 6 Jan 2023 21:29:04 +0100 Subject: [PATCH 404/915] [mlir] Add header file for ssize_t ssize_t is part of POSIX and not standard C/C++, so using ssize_t without the necessary header files causes the build to fail on Windows with the following error: 'ssize_t': undeclared identifier. This patch includes llvm/Support/DataTypes.h to resolve the problem. Differential Revision: https://reviews.llvm.org/D141149 --- mlir/lib/Bindings/Python/PybindUtils.h | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h index d039a8acd..245dc4621 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -12,6 +12,7 @@ #include "mlir-c/Support.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/Twine.h" +#include "llvm/Support/DataTypes.h" #include #include From 01a90c727c16d9b475c9bc2fc421fe27b4247b02 Mon Sep 17 00:00:00 2001 From: serge-sans-paille Date: Mon, 9 Jan 2023 18:11:07 +0100 Subject: [PATCH 405/915] Move from llvm::makeArrayRef to ArrayRef deduction guides - last part This is a follow-up to https://reviews.llvm.org/D140896, split into several parts as it touches a lot of files. 
Differential Revision: https://reviews.llvm.org/D141298 --- mlir/lib/Bindings/Python/IRCore.cpp | 2 +- mlir/lib/CAPI/Dialect/Quant.cpp | 2 +- mlir/lib/CAPI/IR/AffineMap.cpp | 2 +- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 11 +++++------ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 27 +++++++++++++------------- mlir/lib/CAPI/IR/IntegerSet.cpp | 2 +- 6 files changed, 22 insertions(+), 24 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index f2aa8da5b..f1c7be5e4 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2496,7 +2496,7 @@ void mlir::python::populateIRCore(py::module &m) { throw py::value_error("No caller frames provided"); MlirLocation caller = frames.back().get(); for (const PyLocation &frame : - llvm::reverse(llvm::makeArrayRef(frames).drop_back())) + llvm::reverse(llvm::ArrayRef(frames).drop_back())) caller = mlirLocationCallSiteGet(frame.get(), caller); return PyLocation(context->getRef(), mlirLocationCallSiteGet(callee.get(), caller)); diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp index 483536508..065ab3e36 100644 --- a/mlir/lib/CAPI/Dialect/Quant.cpp +++ b/mlir/lib/CAPI/Dialect/Quant.cpp @@ -167,7 +167,7 @@ MlirType mlirUniformQuantizedPerAxisTypeGet( int64_t storageTypeMax) { return wrap(quant::UniformQuantizedPerAxisType::get( flags, unwrap(storageType), unwrap(expressedType), - llvm::makeArrayRef(scales, nDims), llvm::makeArrayRef(zeroPoints, nDims), + llvm::ArrayRef(scales, nDims), llvm::ArrayRef(zeroPoints, nDims), quantizedDimension, storageTypeMin, storageTypeMax)); } diff --git a/mlir/lib/CAPI/IR/AffineMap.cpp b/mlir/lib/CAPI/IR/AffineMap.cpp index 85557bc57..1889765ef 100644 --- a/mlir/lib/CAPI/IR/AffineMap.cpp +++ b/mlir/lib/CAPI/IR/AffineMap.cpp @@ -68,7 +68,7 @@ MlirAffineMap mlirAffineMapMinorIdentityGet(MlirContext ctx, intptr_t dims, MlirAffineMap mlirAffineMapPermutationGet(MlirContext ctx, intptr_t size, unsigned 
*permutation) { return wrap(AffineMap::getPermutationMap( - llvm::makeArrayRef(permutation, static_cast(size)), unwrap(ctx))); + llvm::ArrayRef(permutation, static_cast(size)), unwrap(ctx))); } bool mlirAffineMapIsIdentity(MlirAffineMap affineMap) { diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index e392d053b..b6ee4af79 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -301,13 +301,13 @@ MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank, uint64_t *idxs) { return wrap(unwrap(attr) .cast() - .getValues()[llvm::makeArrayRef(idxs, rank)]); + .getValues()[llvm::ArrayRef(idxs, rank)]); } bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, uint64_t *idxs) { return unwrap(attr).cast().isValidIndex( - llvm::makeArrayRef(idxs, rank)); + llvm::ArrayRef(idxs, rank)); } int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) { @@ -520,9 +520,8 @@ template static MlirAttribute getDenseAttribute(MlirType shapedType, intptr_t numElements, const T *elements) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), - llvm::makeArrayRef(elements, numElements))); + return wrap(DenseElementsAttr::get(unwrap(shapedType).cast(), + llvm::ArrayRef(elements, numElements))); } MlirAttribute mlirDenseElementsAttrUInt8Get(MlirType shapedType, @@ -710,7 +709,7 @@ static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name, intptr_t numElements, const T *elements) { return wrap(U::get(unwrap(shapedType).cast(), unwrap(name), UnmanagedAsmResourceBlob::allocateInferAlign( - llvm::makeArrayRef(elements, numElements)))); + llvm::ArrayRef(elements, numElements)))); } MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet( diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 6b6ba6e95..73a3ec414 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ 
b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -187,15 +187,14 @@ bool mlirTypeIsAVector(MlirType type) { return unwrap(type).isa(); } MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape, MlirType elementType) { - return wrap( - VectorType::get(llvm::makeArrayRef(shape, static_cast(rank)), - unwrap(elementType))); + return wrap(VectorType::get(llvm::ArrayRef(shape, static_cast(rank)), + unwrap(elementType))); } MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank, const int64_t *shape, MlirType elementType) { return wrap(VectorType::getChecked( - unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), + unwrap(loc), llvm::ArrayRef(shape, static_cast(rank)), unwrap(elementType))); } @@ -215,9 +214,9 @@ bool mlirTypeIsAUnrankedTensor(MlirType type) { MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape, MlirType elementType, MlirAttribute encoding) { - return wrap(RankedTensorType::get( - llvm::makeArrayRef(shape, static_cast(rank)), unwrap(elementType), - unwrap(encoding))); + return wrap( + RankedTensorType::get(llvm::ArrayRef(shape, static_cast(rank)), + unwrap(elementType), unwrap(encoding))); } MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank, @@ -225,7 +224,7 @@ MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank, MlirType elementType, MlirAttribute encoding) { return wrap(RankedTensorType::getChecked( - unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), + unwrap(loc), llvm::ArrayRef(shape, static_cast(rank)), unwrap(elementType), unwrap(encoding))); } @@ -252,7 +251,7 @@ MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute layout, MlirAttribute memorySpace) { return wrap(MemRefType::get( - llvm::makeArrayRef(shape, static_cast(rank)), unwrap(elementType), + llvm::ArrayRef(shape, static_cast(rank)), unwrap(elementType), mlirAttributeIsNull(layout) ? 
MemRefLayoutAttrInterface() : unwrap(layout).cast(), @@ -264,7 +263,7 @@ MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, MlirAttribute layout, MlirAttribute memorySpace) { return wrap(MemRefType::getChecked( - unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), + unwrap(loc), llvm::ArrayRef(shape, static_cast(rank)), unwrap(elementType), mlirAttributeIsNull(layout) ? MemRefLayoutAttrInterface() @@ -275,9 +274,9 @@ MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, MlirType mlirMemRefTypeContiguousGet(MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute memorySpace) { - return wrap(MemRefType::get( - llvm::makeArrayRef(shape, static_cast(rank)), unwrap(elementType), - MemRefLayoutAttrInterface(), unwrap(memorySpace))); + return wrap(MemRefType::get(llvm::ArrayRef(shape, static_cast(rank)), + unwrap(elementType), MemRefLayoutAttrInterface(), + unwrap(memorySpace))); } MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc, @@ -285,7 +284,7 @@ MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc, const int64_t *shape, MlirAttribute memorySpace) { return wrap(MemRefType::getChecked( - unwrap(loc), llvm::makeArrayRef(shape, static_cast(rank)), + unwrap(loc), llvm::ArrayRef(shape, static_cast(rank)), unwrap(elementType), MemRefLayoutAttrInterface(), unwrap(memorySpace))); } diff --git a/mlir/lib/CAPI/IR/IntegerSet.cpp b/mlir/lib/CAPI/IR/IntegerSet.cpp index 701d70353..43d48e415 100644 --- a/mlir/lib/CAPI/IR/IntegerSet.cpp +++ b/mlir/lib/CAPI/IR/IntegerSet.cpp @@ -49,7 +49,7 @@ MlirIntegerSet mlirIntegerSetGet(MlirContext context, intptr_t numDims, return wrap(IntegerSet::get( static_cast(numDims), static_cast(numSymbols), mlirConstraints, - llvm::makeArrayRef(eqFlags, static_cast(numConstraints)))); + llvm::ArrayRef(eqFlags, static_cast(numConstraints)))); } MlirIntegerSet From b7eed789f010848077ea7140b769dc766ca542ea Mon Sep 17 00:00:00 2001 From: Jeff Niu Date: Thu, 12 Jan 2023 
14:17:32 -0800 Subject: [PATCH 406/915] [mlir][python] fix python build --- mlir/python/mlir/dialects/_memref_ops_ext.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py index 9cc22a21c..a00a087be 100644 --- a/mlir/python/mlir/dialects/_memref_ops_ext.py +++ b/mlir/python/mlir/dialects/_memref_ops_ext.py @@ -4,7 +4,8 @@ try: from ..ir import * - from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values + from ._ods_common import get_op_result_or_value as _get_op_result_or_value + from ._ods_common import get_op_results_or_values as _get_op_results_or_values except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e @@ -30,8 +31,6 @@ def __init__(self, loc: user-visible location of the operation. ip: insertion point. """ - memref_resolved = _get_op_result_or_value(memref) indices_resolved = [] if indices is None else _get_op_results_or_values( indices) - return_type = MemRefType(memref_resolved.type).element_type - super().__init__(return_type, memref, indices_resolved, loc=loc, ip=ip) + super().__init__(memref, indices_resolved, loc=loc, ip=ip) From a4054d3b29abbc68e10e419d0af72acda7e8348d Mon Sep 17 00:00:00 2001 From: Kazu Hirata Date: Fri, 13 Jan 2023 21:05:06 -0800 Subject: [PATCH 407/915] [mlir] Add #include (NFC) This patch adds #include to those files containing llvm::Optional<...> or Optional<...>. I'll post a separate patch to actually replace llvm::Optional with std::optional. 
This is part of an effort to migrate from llvm::Optional to std::optional: https://discourse.llvm.org/t/deprecating-llvm-optional-x-hasvalue-getvalue-getvalueor/63716 --- mlir/lib/Bindings/Python/DialectSparseTensor.cpp | 1 + mlir/lib/Bindings/Python/Globals.h | 1 + mlir/lib/Bindings/Python/IRAttributes.cpp | 1 + mlir/lib/Bindings/Python/IRCore.cpp | 1 + mlir/lib/Bindings/Python/IRInterfaces.cpp | 1 + mlir/lib/Bindings/Python/IRModule.cpp | 1 + mlir/lib/Bindings/Python/IRModule.h | 1 + mlir/lib/Bindings/Python/IRTypes.cpp | 1 + mlir/lib/CAPI/IR/IR.cpp | 1 + mlir/lib/CAPI/IR/Pass.cpp | 1 + mlir/lib/CAPI/Interfaces/Interfaces.cpp | 1 + 11 files changed, 11 insertions(+) diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index af47ac8df..6885ef9cf 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -9,6 +9,7 @@ #include "mlir-c/Dialect/SparseTensor.h" #include "mlir-c/IR.h" #include "mlir/Bindings/Python/PybindAdaptors.h" +#include namespace py = pybind11; using namespace llvm; diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index f0bf3f556..fbe51a753 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -11,6 +11,7 @@ #include #include +#include #include "PybindUtils.h" diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 0c8c9b8ba..c598aee7b 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include +#include #include "IRModule.h" diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index f1c7be5e4..7a9e8eb3a 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -21,6 +21,7 @@ #include 
"llvm/ADT/SmallVector.h" #include +#include namespace py = pybind11; using namespace mlir; diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index 3f50d3bc0..f3e4e73c1 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include +#include #include "IRModule.h" #include "mlir-c/BuiltinAttributes.h" diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index 1cdd7e441..3614360fc 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -11,6 +11,7 @@ #include "PybindUtils.h" #include +#include #include "mlir-c/Bindings/Python/Interop.h" diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 2492ad5d1..b198b4b65 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -11,6 +11,7 @@ #include #include +#include #include "PybindUtils.h" diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 10527af6c..ad46e52b3 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -12,6 +12,7 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" +#include namespace py = pybind11; using namespace mlir; diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index f5364c864..13e8d2034 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -25,6 +25,7 @@ #include "mlir/Parser/Parser.h" #include +#include using namespace mlir; diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index 6f81cd808..9ae425fe5 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -13,6 +13,7 @@ #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Utils.h" #include "mlir/Pass/PassManager.h" +#include using namespace 
mlir; diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp index ec2cf0175..61f1c9542 100644 --- a/mlir/lib/CAPI/Interfaces/Interfaces.cpp +++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp @@ -13,6 +13,7 @@ #include "mlir/CAPI/Wrap.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "llvm/ADT/ScopeExit.h" +#include using namespace mlir; From 7183b1008b9b42482288770b74fa9604fa283722 Mon Sep 17 00:00:00 2001 From: Kazu Hirata Date: Sat, 14 Jan 2023 01:25:58 -0800 Subject: [PATCH 408/915] [mlir] Use std::optional instead of llvm::Optional (NFC) This patch replaces (llvm::|)Optional< with std::optional<. I'll post a separate patch to remove #include "llvm/ADT/Optional.h". This is part of an effort to migrate from llvm::Optional to std::optional: https://discourse.llvm.org/t/deprecating-llvm-optional-x-hasvalue-getvalue-getvalueor/63716 --- .../Bindings/Python/DialectSparseTensor.cpp | 8 ++-- mlir/lib/Bindings/Python/Globals.h | 4 +- mlir/lib/Bindings/Python/IRAttributes.cpp | 7 ++-- mlir/lib/Bindings/Python/IRCore.cpp | 38 ++++++++++--------- mlir/lib/Bindings/Python/IRInterfaces.cpp | 6 +-- mlir/lib/Bindings/Python/IRModule.cpp | 4 +- mlir/lib/Bindings/Python/IRModule.h | 26 ++++++------- mlir/lib/Bindings/Python/IRTypes.cpp | 6 +-- mlir/lib/CAPI/IR/IR.cpp | 2 +- mlir/lib/CAPI/IR/Pass.cpp | 7 ++-- mlir/lib/CAPI/Interfaces/Interfaces.cpp | 6 +-- 11 files changed, 58 insertions(+), 56 deletions(-) diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 6885ef9cf..da44141e2 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -34,8 +34,8 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { "get", [](py::object cls, std::vector dimLevelTypes, - llvm::Optional dimOrdering, - llvm::Optional higherOrdering, int pointerBitWidth, + std::optional dimOrdering, + std::optional 
higherOrdering, int pointerBitWidth, int indexBitWidth, MlirContext context) { return cls(mlirSparseTensorEncodingAttrGet( context, dimLevelTypes.size(), dimLevelTypes.data(), @@ -60,7 +60,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { }) .def_property_readonly( "dim_ordering", - [](MlirAttribute self) -> llvm::Optional { + [](MlirAttribute self) -> std::optional { MlirAffineMap ret = mlirSparseTensorEncodingAttrGetDimOrdering(self); if (mlirAffineMapIsNull(ret)) @@ -69,7 +69,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { }) .def_property_readonly( "higher_ordering", - [](MlirAttribute self) -> llvm::Optional { + [](MlirAttribute self) -> std::optional { MlirAffineMap ret = mlirSparseTensorEncodingAttrGetHigherOrdering(self); if (mlirAffineMapIsNull(ret)) diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index fbe51a753..f3370a4f5 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -84,12 +84,12 @@ class PyGlobals { /// Looks up a registered dialect class by namespace. Note that this may /// trigger loading of the defining module and can arbitrarily re-enter. - llvm::Optional + std::optional lookupDialectClass(const std::string &dialectNamespace); /// Looks up a registered raw OpView class by operation name. Note that this /// may trigger a load of the dialect, which can arbitrarily re-enter. 
- llvm::Optional + std::optional lookupRawOpViewClass(llvm::StringRef operationName); private: diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index c598aee7b..a29f16397 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -546,8 +546,9 @@ class PyDenseElementsAttribute using PyConcreteAttribute::PyConcreteAttribute; static PyDenseElementsAttribute - getFromBuffer(py::buffer array, bool signless, Optional explicitType, - Optional> explicitShape, + getFromBuffer(py::buffer array, bool signless, + std::optional explicitType, + std::optional> explicitShape, DefaultingPyMlirContext contextWrapper) { // Request a contiguous view. In exotic cases, this will cause a copy. int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; @@ -573,7 +574,7 @@ class PyDenseElementsAttribute // Notably, this excludes, bool (which needs to be bit-packed) and // other exotics which do not have a direct representation in the buffer // protocol (i.e. complex, etc). 
- Optional bulkLoadElementType; + std::optional bulkLoadElementType; if (explicitType) { bulkLoadElementType = *explicitType; } else if (arrayInfo.format == "f") { diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 7a9e8eb3a..eb7d18f98 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1066,7 +1066,7 @@ void PyOperation::checkValid() const { } void PyOperationBase::print(py::object fileObject, bool binary, - llvm::Optional largeElementsLimit, + std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified) { @@ -1112,7 +1112,7 @@ void PyOperationBase::writeBytecode(const py::object &fileObject) { } py::object PyOperationBase::getAsm(bool binary, - llvm::Optional largeElementsLimit, + std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified) { @@ -1151,7 +1151,7 @@ void PyOperationBase::moveBefore(PyOperationBase &other) { operation.parentKeepAlive = otherOp.parentKeepAlive; } -llvm::Optional PyOperation::getParentOperation() { +std::optional PyOperation::getParentOperation() { checkValid(); if (!isAttached()) throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); @@ -1163,7 +1163,7 @@ llvm::Optional PyOperation::getParentOperation() { PyBlock PyOperation::getBlock() { checkValid(); - llvm::Optional parentOperation = getParentOperation(); + std::optional parentOperation = getParentOperation(); MlirBlock block = mlirOperationGetBlock(get()); assert(!mlirBlockIsNull(block) && "Attached operation has null parent"); assert(parentOperation && "Operation has no parent"); @@ -1199,12 +1199,13 @@ static void maybeInsertOperation(PyOperationRef &op, } } -py::object PyOperation::create( - const std::string &name, llvm::Optional> results, - llvm::Optional> operands, - llvm::Optional attributes, - 
llvm::Optional> successors, int regions, - DefaultingPyLocation location, const py::object &maybeIp) { +py::object PyOperation::create(const std::string &name, + std::optional> results, + std::optional> operands, + std::optional attributes, + std::optional> successors, + int regions, DefaultingPyLocation location, + const py::object &maybeIp) { llvm::SmallVector mlirOperands; llvm::SmallVector mlirResults; llvm::SmallVector mlirSuccessors; @@ -1357,12 +1358,13 @@ void PyOperation::erase() { // PyOpView //------------------------------------------------------------------------------ -py::object PyOpView::buildGeneric( - const py::object &cls, py::list resultTypeList, py::list operandList, - llvm::Optional attributes, - llvm::Optional> successors, - llvm::Optional regions, DefaultingPyLocation location, - const py::object &maybeIp) { +py::object +PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList, + py::list operandList, std::optional attributes, + std::optional> successors, + std::optional regions, + DefaultingPyLocation location, + const py::object &maybeIp) { PyMlirContextRef context = location->getContext(); // Class level operation construction metadata. 
std::string name = py::cast(cls.attr("OPERATION_NAME")); @@ -2518,7 +2520,7 @@ void mlir::python::populateIRCore(py::module &m) { .def_static( "fused", [](const std::vector &pyLocations, - llvm::Optional metadata, + std::optional metadata, DefaultingPyMlirContext context) { llvm::SmallVector locations; locations.reserve(pyLocations.size()); @@ -2533,7 +2535,7 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("context") = py::none(), kContextGetFusedLocationDocstring) .def_static( "name", - [](std::string name, llvm::Optional childLoc, + [](std::string name, std::optional childLoc, DefaultingPyMlirContext context) { return PyLocation( context->getRef(), diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index f3e4e73c1..fed8a5066 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -185,9 +185,9 @@ class PyInferTypeOpInterface /// Given the arguments required to build an operation, attempts to infer its /// return types. Throws value_error on faliure. std::vector - inferReturnTypes(llvm::Optional> operands, - llvm::Optional attributes, - llvm::Optional> regions, + inferReturnTypes(std::optional> operands, + std::optional attributes, + std::optional> regions, DefaultingPyMlirContext context, DefaultingPyLocation location) { llvm::SmallVector mlirOperands; diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index 3614360fc..e3b8ef189 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -112,7 +112,7 @@ PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { return std::nullopt; } -llvm::Optional +std::optional PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { loadDialectModule(dialectNamespace); // Fast match against the class map first (common case). 
@@ -129,7 +129,7 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { return std::nullopt; } -llvm::Optional +std::optional PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { { auto foundIt = rawOpViewClassMapCache.find(operationName); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index b198b4b65..d26fa2077 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -347,7 +347,7 @@ class PyDiagnosticHandler { private: MlirContext context; pybind11::object callback; - llvm::Optional registeredID; + std::optional registeredID; bool hadError = false; friend class PyMlirContext; }; @@ -504,11 +504,11 @@ class PyOperationBase { virtual ~PyOperationBase() = default; /// Implements the bound 'print' method and helps with others. void print(pybind11::object fileObject, bool binary, - llvm::Optional largeElementsLimit, bool enableDebugInfo, + std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified); pybind11::object getAsm(bool binary, - llvm::Optional largeElementsLimit, + std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified); @@ -586,7 +586,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// Gets the parent operation or raises an exception if the operation has /// no parent. - llvm::Optional getParentOperation(); + std::optional getParentOperation(); /// Gets a capsule wrapping the void* within the MlirOperation. pybind11::object getCapsule(); @@ -598,10 +598,10 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// Creates an operation. See corresponding python docstring. 
static pybind11::object - create(const std::string &name, llvm::Optional> results, - llvm::Optional> operands, - llvm::Optional attributes, - llvm::Optional> successors, int regions, + create(const std::string &name, std::optional> results, + std::optional> operands, + std::optional attributes, + std::optional> successors, int regions, DefaultingPyLocation location, const pybind11::object &ip); /// Creates an OpView suitable for this operation. @@ -656,9 +656,9 @@ class PyOpView : public PyOperationBase { static pybind11::object buildGeneric(const pybind11::object &cls, pybind11::list resultTypeList, pybind11::list operandList, - llvm::Optional attributes, - llvm::Optional> successors, - llvm::Optional regions, DefaultingPyLocation location, + std::optional attributes, + std::optional> successors, + std::optional regions, DefaultingPyLocation location, const pybind11::object &maybeIp); private: @@ -738,10 +738,10 @@ class PyInsertionPoint { private: // Trampoline constructor that avoids null initializing members while // looking up parents. - PyInsertionPoint(PyBlock block, llvm::Optional refOperation) + PyInsertionPoint(PyBlock block, std::optional refOperation) : refOperation(std::move(refOperation)), block(std::move(block)) {} - llvm::Optional refOperation; + std::optional refOperation; PyBlock block; }; /// Wrapper around the generic MlirType. diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index ad46e52b3..3cc226d7a 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -401,8 +401,7 @@ class PyRankedTensorType c.def_static( "get", [](std::vector shape, PyType &elementType, - llvm::Optional &encodingAttr, - DefaultingPyLocation loc) { + std::optional &encodingAttr, DefaultingPyLocation loc) { MlirType t = mlirRankedTensorTypeGetChecked( loc, shape.size(), shape.data(), elementType, encodingAttr ? 
encodingAttr->get() : mlirAttributeGetNull()); @@ -423,8 +422,7 @@ class PyRankedTensorType py::arg("encoding") = py::none(), py::arg("loc") = py::none(), "Create a ranked tensor type"); c.def_property_readonly( - "encoding", - [](PyRankedTensorType &self) -> llvm::Optional { + "encoding", [](PyRankedTensorType &self) -> std::optional { MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get()); if (mlirAttributeIsNull(encoding)) return std::nullopt; diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 13e8d2034..68563a69c 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -289,7 +289,7 @@ void mlirOperationStateEnableResultTypeInference(MlirOperationState *state) { static LogicalResult inferOperationTypes(OperationState &state) { MLIRContext *context = state.getContext(); - Optional info = state.name.getRegisteredInfo(); + std::optional info = state.name.getRegisteredInfo(); if (!info) { emitError(state.location) << "type inference was requested for the operation " << state.name diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index 9ae425fe5..b92115411 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -112,7 +112,7 @@ namespace mlir { class ExternalPass : public Pass { public: ExternalPass(TypeID passID, StringRef name, StringRef argument, - StringRef description, Optional opName, + StringRef description, std::optional opName, ArrayRef dependentDialects, MlirExternalPassCallbacks callbacks, void *userData) : Pass(passID, opName), id(passID), name(name), argument(argument), @@ -143,7 +143,7 @@ class ExternalPass : public Pass { } bool canScheduleOn(RegisteredOperationName opName) const override { - if (Optional specifiedOpName = getOpName()) + if (std::optional specifiedOpName = getOpName()) return opName.getStringRef() == specifiedOpName; return true; } @@ -179,7 +179,8 @@ MlirPass mlirCreateExternalPass(MlirTypeID passID, MlirStringRef name, void *userData) { return 
wrap(static_cast(new mlir::ExternalPass( unwrap(passID), unwrap(name), unwrap(argument), unwrap(description), - opName.length > 0 ? Optional(unwrap(opName)) : std::nullopt, + opName.length > 0 ? std::optional(unwrap(opName)) + : std::nullopt, {dependentDialects, static_cast(nDependentDialects)}, callbacks, userData))); } diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp index 61f1c9542..5adccbdaf 100644 --- a/mlir/lib/CAPI/Interfaces/Interfaces.cpp +++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp @@ -19,7 +19,7 @@ using namespace mlir; bool mlirOperationImplementsInterface(MlirOperation operation, MlirTypeID interfaceTypeID) { - Optional info = + std::optional info = unwrap(operation)->getRegisteredInfo(); return info && info->hasInterface(unwrap(interfaceTypeID)); } @@ -27,7 +27,7 @@ bool mlirOperationImplementsInterface(MlirOperation operation, bool mlirOperationImplementsInterfaceStatic(MlirStringRef operationName, MlirContext context, MlirTypeID interfaceTypeID) { - Optional info = RegisteredOperationName::lookup( + std::optional info = RegisteredOperationName::lookup( StringRef(operationName.data, operationName.length), unwrap(context)); return info && info->hasInterface(unwrap(interfaceTypeID)); } @@ -42,7 +42,7 @@ MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, void *userData) { StringRef name(opName.data, opName.length); - Optional info = + std::optional info = RegisteredOperationName::lookup(name, unwrap(context)); if (!info) return mlirLogicalResultFailure(); From eacf33e7370cfda57279da1b38675b57fc6da556 Mon Sep 17 00:00:00 2001 From: Kazu Hirata Date: Sat, 14 Jan 2023 01:34:49 -0800 Subject: [PATCH 409/915] [mlir] Remove remaining uses of llvm::Optional (NFC) This patch removes one "using" declaration and #include "llvm/ADT/Optional.h". It keeps several "using" declarations in headers for downstream users. 
This is part of an effort to migrate from llvm::Optional to std::optional: https://discourse.llvm.org/t/deprecating-llvm-optional-x-hasvalue-getvalue-getvalueor/63716 --- mlir/include/mlir/Bindings/Python/PybindAdaptors.h | 1 - mlir/lib/Bindings/Python/Globals.h | 1 - mlir/lib/Bindings/Python/IRAttributes.cpp | 1 - mlir/lib/Bindings/Python/IRModule.h | 1 - mlir/lib/Bindings/Python/PybindUtils.h | 1 - 5 files changed, 5 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 38c547078..98d80f010 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -25,7 +25,6 @@ #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/IR.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/Twine.h" namespace py = pybind11; diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index f3370a4f5..8caa5a094 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -15,7 +15,6 @@ #include "PybindUtils.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index a29f16397..c8ede8b06 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -20,7 +20,6 @@ namespace py = pybind11; using namespace mlir; using namespace mlir::python; -using llvm::Optional; using llvm::SmallVector; using llvm::Twine; diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index d26fa2077..37115acbe 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -21,7 +21,6 @@ #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/Optional.h" namespace mlir { namespace python { diff --git 
a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h index 245dc4621..2d8bbc14c 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -10,7 +10,6 @@ #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H #include "mlir-c/Support.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/DataTypes.h" From c637ae2d15cebece7ee8251a77eaf699074f7c76 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 19 Jan 2023 08:58:34 +0000 Subject: [PATCH 410/915] [mlir] fix python test It was using an incorrect attribute type, but the test was still passing because of the value being present in the output. --- mlir/python/mlir/dialects/_structured_transform_ops_ext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 5fd5cfe10..2525ea34c 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -110,7 +110,7 @@ def __init__(self, loc=None, ip=None): pdl_operation_type = pdl.OperationType.get() - interchange_attr = _get_int_array_attr(iterator_interchange) + interchange_attr = _get_dense_int64_array_attr(iterator_interchange) super().__init__( pdl_operation_type, _get_op_result_or_value(target), From f99c0cdfca91411a3706d6c81018455cb8a6d366 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Wed, 4 Jan 2023 14:04:53 +0000 Subject: [PATCH 411/915] [mlir] make multi-size tiling use transform parameters Use the recently introduced transform dialect parameter mechanism to perform controllable multi-size tiling with sizes computed at the transformation time rather than at runtime. This requires to generalize tile and split structured transform operations to work with any transform dialect handle types, which is desirable in itself to avoid unchecked overuse of PDL OperationType. 
Reviewed By: shabalin Differential Revision: https://reviews.llvm.org/D140980 --- .../dialects/_structured_transform_ops_ext.py | 73 ++++++++++++++----- 1 file changed, 55 insertions(+), 18 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 2525ea34c..f045e5c13 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -5,11 +5,11 @@ try: from ..ir import * from ._ods_common import get_op_result_or_value as _get_op_result_or_value - from ..dialects import pdl + from ..dialects import pdl, transform except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e -from typing import List, Optional, Sequence, Union +from typing import List, Optional, Sequence, Union, overload IntOrAttrList = Sequence[Union[IntegerAttr, int]] OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]] @@ -51,13 +51,13 @@ def _get_int_array_attr( def _get_dense_int64_array_attr( values: Sequence[int]) -> DenseI64ArrayAttr: - """Creates a dense integer array from a sequence of integers. + """Creates a dense integer array from a sequence of integers. Expects the thread-local MLIR context to have been set by the context manager. 
""" - if values is None: - return DenseI64ArrayAttr.get([]) - return DenseI64ArrayAttr.get(values) + if values is None: + return DenseI64ArrayAttr.get([]) + return DenseI64ArrayAttr.get(values) def _get_int_int_array_attr( values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, @@ -141,6 +141,7 @@ class MultiTileSizesOp: """Specialization for MultitileSizesOp class.""" def __init__(self, + result_type: Type, target: Union[Operation, Value], *, dimension: Union[int, IntegerAttr], @@ -149,9 +150,9 @@ def __init__(self, loc=None, ip=None): super().__init__( - pdl.OperationType.get(), - pdl.OperationType.get(), - pdl.OperationType.get(), + result_type, + result_type, + result_type, _get_op_result_or_value(target), dimension=_get_int64_attr(dimension), target_size=_get_int64_attr(target_size), @@ -223,11 +224,12 @@ def __init__(self, static_split_point = _get_int64_attr(ShapedType.get_dynamic_size()) dynamic_split_point = _get_op_result_or_value(split_point) - pdl_operation_type = pdl.OperationType.get() + target = _get_op_result_or_value(target) + super().__init__( - pdl_operation_type, - pdl_operation_type, - _get_op_result_or_value(target), + target.type, + target.type, + target, dimension=dimension, static_split_point=static_split_point, dynamic_split_point=dynamic_split_point, @@ -238,7 +240,9 @@ def __init__(self, class TileOp: """Specialization for TileOp class.""" + @overload def __init__(self, + loop_types: Union[Type, List[Type]], target: Union[Operation, Value], *, sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, @@ -246,9 +250,28 @@ def __init__(self, interchange: OptionalIntList = None, loc=None, ip=None): - pdl_operation_type = pdl.OperationType.get() - i64_type = IntegerType.get_signless(64) + ... + @overload + def __init__(self, + target: Union[Operation, Value], + *, + sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, + Value]], ArrayAttr]] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None): + ... 
+ + def __init__(self, + loop_types_or_target: Union[Type, List[Type], Operation, Value], + target_or_none: Optional[Union[Operation, Value]] = None, + *, + sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, + Value]], ArrayAttr]] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None): if sizes is None: sizes = [] @@ -267,12 +290,26 @@ def __init__(self, num_loops = sum( v if v == 0 else 1 for v in self.__extract_values(sizes_attr)) + + if isinstance(loop_types_or_target, (Operation, Value)): + loop_types = [transform.AnyOpType.get()] * num_loops + target = loop_types_or_target + assert target_or_none is None, "Cannot construct TileOp with two targets." + else: + loop_types = ([loop_types_or_target] * num_loops) if isinstance( + loop_types_or_target, Type) else loop_types_or_target + target = target_or_none + + target = _get_op_result_or_value(target) + super().__init__( - pdl_operation_type, [pdl_operation_type] * num_loops, - _get_op_result_or_value(target), + target.type, + loop_types, + target, dynamic_sizes=dynamic_sizes, static_sizes=sizes_attr, - interchange=_get_dense_int64_array_attr(interchange) if interchange else None, + interchange=_get_dense_int64_array_attr(interchange) + if interchange else None, loc=loc, ip=ip) From abfb9a0553b075b82107c01f7517ca1c167233d8 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 20 Jan 2023 16:02:31 +0000 Subject: [PATCH 412/915] [mlir] fix python types --- mlir/python/mlir/dialects/_structured_transform_ops_ext.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index f045e5c13..97705e2ad 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -254,7 +254,7 @@ def __init__(self, @overload def __init__(self, - target: Union[Operation, Value], + target: 
Union[Operation, Value, OpView], *, sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]] = None, @@ -265,7 +265,7 @@ def __init__(self, def __init__(self, loop_types_or_target: Union[Type, List[Type], Operation, Value], - target_or_none: Optional[Union[Operation, Value]] = None, + target_or_none: Optional[Union[Operation, Value, OpView]] = None, *, sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr]] = None, @@ -291,7 +291,7 @@ def __init__(self, num_loops = sum( v if v == 0 else 1 for v in self.__extract_values(sizes_attr)) - if isinstance(loop_types_or_target, (Operation, Value)): + if isinstance(loop_types_or_target, (Operation, Value, OpView)): loop_types = [transform.AnyOpType.get()] * num_loops target = loop_types_or_target assert target_or_none is None, "Cannot construct TileOp with two targets." From d983e89f2f6e632753edbe172614c673ec264644 Mon Sep 17 00:00:00 2001 From: Andrew Young Date: Tue, 24 Jan 2023 23:13:20 -0800 Subject: [PATCH 413/915] [MLIR] Expose LocationAttrs in the C API This patch adds three functions to the C API: - mlirAttributeIsALocation: returns true if the attribute is a LocationAttr, false otherwise. - mlirLocationGetAttribute: returns the underlying LocationAttr of a Location. - mlirLocationFromAttribute: gets a Location from a LocationAttr. Reviewed By: mikeurbach, Mogball Differential Revision: https://reviews.llvm.org/D142182 --- mlir/include/mlir-c/BuiltinAttributes.h | 6 ++++++ mlir/include/mlir-c/IR.h | 8 ++++++++ mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 8 ++++++++ mlir/lib/CAPI/IR/IR.cpp | 8 ++++++++ 4 files changed, 30 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 3b1ac466d..2e6287939 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -25,6 +25,12 @@ extern "C" { /// Returns an empty attribute. 
MLIR_CAPI_EXPORTED MlirAttribute mlirAttributeGetNull(void); +//===----------------------------------------------------------------------===// +// Location attribute. +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirAttributeIsALocation(MlirAttribute attr); + //===----------------------------------------------------------------------===// // Affine map attribute. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 7817da2bd..a349eb9f3 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -225,6 +225,14 @@ mlirDialectRegistryDestroy(MlirDialectRegistry registry); // Location API. //===----------------------------------------------------------------------===// +/// Returns the underlying location attribute of this location. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLocationGetAttribute(MlirLocation location); + +/// Creates a location from a location attribute. +MLIR_CAPI_EXPORTED MlirLocation +mlirLocationFromAttribute(MlirAttribute attribute); + /// Creates an File/Line/Column location owned by the given context. MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet( MlirContext context, MlirStringRef filename, unsigned line, unsigned col); diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index b6ee4af79..66d291edd 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -20,6 +20,14 @@ using namespace mlir; MlirAttribute mlirAttributeGetNull() { return {nullptr}; } +//===----------------------------------------------------------------------===// +// Location attribute. 
+//===----------------------------------------------------------------------===// + +bool mlirAttributeIsALocation(MlirAttribute attr) { + return unwrap(attr).isa(); +} + //===----------------------------------------------------------------------===// // Affine map attribute. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 68563a69c..7d3479736 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -145,6 +145,14 @@ void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) { // Location API. //===----------------------------------------------------------------------===// +MlirAttribute mlirLocationGetAttribute(MlirLocation location) { + return wrap(LocationAttr(unwrap(location))); +} + +MlirLocation mlirLocationFromAttribute(MlirAttribute attribute) { + return wrap(Location(unwrap(attribute).cast())); +} + MlirLocation mlirLocationFileLineColGet(MlirContext context, MlirStringRef filename, unsigned line, unsigned col) { From f9e76b2f13593204c0e412f5d7564852a8a2674f Mon Sep 17 00:00:00 2001 From: Andrew Young Date: Thu, 19 Jan 2023 15:01:46 -0800 Subject: [PATCH 414/915] [MLIR] Add LocationAttr to the Python API This is a follow up to D142182, to expose LocationAttrs through Python. 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D142522 --- mlir/lib/Bindings/Python/IRCore.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index eb7d18f98..2ecfc36d4 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2546,10 +2546,25 @@ void mlir::python::populateIRCore(py::module &m) { }, py::arg("name"), py::arg("childLoc") = py::none(), py::arg("context") = py::none(), kContextGetNameLocationDocString) + .def_static( + "from_attr", + [](PyAttribute &attribute, DefaultingPyMlirContext context) { + return PyLocation(context->getRef(), + mlirLocationFromAttribute(attribute)); + }, + py::arg("attribute"), py::arg("context") = py::none(), + "Gets a Location from a LocationAttr") .def_property_readonly( "context", [](PyLocation &self) { return self.getContext().getObject(); }, "Context that owns the Location") + .def_property_readonly( + "attr", + [](PyLocation &self) { + return PyAttribute(self.getContext(), + mlirLocationGetAttribute(self)); + }, + "Get the underlying LocationAttr") .def( "emit_error", [](PyLocation &self, std::string message) { From 08be48a2b0b3867cbbc3021aa29e98fab23b33e4 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Wed, 25 Jan 2023 16:53:25 +0000 Subject: [PATCH 415/915] [mlir] multi-argument binding for top-level transform ops `applyTransforms` now takes an optional mapping to be associated with trailing block arguments of the top-level transform op, in addition to the payload root. This allows for more advanced forms of communication between C++ code and the transform dialect interpreter, in particular supplying operations without having to re-match them during interpretation. 
Reviewed By: shabalin Differential Revision: https://reviews.llvm.org/D142559 --- .../mlir/dialects/_transform_ops_ext.py | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py index 5cd57b050..593b8855c 100644 --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -89,7 +89,9 @@ def __init__(self, class SequenceOp: def __init__(self, failure_propagation_mode, results: Sequence[Type], - target: Union[Operation, Value, Type]): + target: Union[Operation, Value, Type], + extra_bindings: Optional[Union[Sequence[Value], Sequence[Type], + Operation, OpView]] = None): root = _get_op_result_or_value(target) if isinstance( target, (Operation, Value)) else None root_type = root.type if not isinstance(target, Type) else target @@ -98,10 +100,25 @@ def __init__(self, failure_propagation_mode, results: Sequence[Type], IntegerType.get_signless(32), failure_propagation_mode._as_int()) else: failure_propagation_mode = failure_propagation_mode + + if extra_bindings is None: + extra_bindings = [] + if isinstance(extra_bindings, (Operation, OpView)): + extra_bindings = _get_op_results_or_values(extra_bindings) + + extra_binding_types = [] + if len(extra_bindings) != 0: + if isinstance(extra_bindings[0], Type): + extra_binding_types = extra_bindings + extra_bindings = [] + else: + extra_binding_types = [v.type for v in extra_bindings] + super().__init__(results_=results, failure_propagation_mode=failure_propagation_mode_attr, - root=root) - self.regions[0].blocks.append(root_type) + root=root, + extra_bindings=extra_bindings) + self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types)) @property def body(self) -> Block: @@ -111,6 +128,10 @@ def body(self) -> Block: def bodyTarget(self) -> Value: return self.body.arguments[0] + @property + def bodyExtraArgs(self) -> BlockArgumentList: + 
return self.body.arguments[1:] + class WithPDLPatternsOp: From 001eb3089fcd9be9c1f504b0e6ffd6fa72fdf488 Mon Sep 17 00:00:00 2001 From: Stella Stamenova Date: Wed, 1 Feb 2023 10:25:20 -0800 Subject: [PATCH 416/915] [mlir] Pin for the PyPi requirements for mlir This change is pinning the requirements to a specific version (or a range) depending on the requirement. A couple of considerations: * numpy 1.24 deprecates np.object, np.bool, np.float, np.complex, np.str, and np.int which are used heavily in onnx-mlir * not all versions of each package are available on every platform - to the best of my knowledge, these ranges should work on Ubuntu, CentOS and Windows Adding a minimum and maximum version, or pinning to a specific versions where possible, helps with two major goals - security and maintainability. It gives us an opportunity to make sure that the packages being used are not part of a security attack as well as guaranteeing that they support the features that mlir depends on (see note about numpy deprecation). Let me know if you are aware of better versions or ranges to pin to. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D142563 --- mlir/python/requirements.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index 991e8eb24..aaf480f0b 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,4 +1,4 @@ -numpy -pybind11>=2.8.0 -PyYAML -dataclasses +numpy>=1.19.5, <=1.23.5 +pybind11>=2.8.0, <=2.10.3 +PyYAML==6.0 +dataclasses>=0.6, <=0.8 From 8f0912c366e3fd53a13d355d6de326e050619066 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Thu, 2 Feb 2023 12:23:46 -0800 Subject: [PATCH 417/915] [mlir][py] Fix infer return type invocation for variadics Previously we only allowed the flattened list passed in, but the same input provided here as to buildGeneric so flatten accordingly. 
We have less info here than in buildGeneric so the error is more generic if unpacking fails. Differential Revision: https://reviews.llvm.org/D143240 --- mlir/lib/Bindings/Python/IRInterfaces.cpp | 48 ++++++++++++++++++++--- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 2 +- 2 files changed, 43 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index fed8a5066..b917bf0c1 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -12,6 +12,7 @@ #include "IRModule.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/Interfaces.h" +#include "llvm/ADT/STLExtras.h" namespace py = pybind11; @@ -183,9 +184,9 @@ class PyInferTypeOpInterface } /// Given the arguments required to build an operation, attempts to infer its - /// return types. Throws value_error on faliure. + /// return types. Throws value_error on failure. std::vector - inferReturnTypes(std::optional> operands, + inferReturnTypes(std::optional operandList, std::optional attributes, std::optional> regions, DefaultingPyMlirContext context, @@ -193,10 +194,45 @@ class PyInferTypeOpInterface llvm::SmallVector mlirOperands; llvm::SmallVector mlirRegions; - if (operands) { - mlirOperands.reserve(operands->size()); - for (PyValue &value : *operands) { - mlirOperands.push_back(value); + if (operandList && !operandList->empty()) { + // Note: as the list may contain other lists this may not be final size. 
+ mlirOperands.reserve(operandList->size()); + for (const auto& it : llvm::enumerate(*operandList)) { + PyValue* val; + try { + val = py::cast(it.value()); + if (!val) + throw py::cast_error(); + mlirOperands.push_back(val->get()); + continue; + } catch (py::cast_error &err) { + } + + try { + auto vals = py::cast(it.value()); + for (py::object v : vals) { + try { + val = py::cast(v); + if (!val) + throw py::cast_error(); + mlirOperands.push_back(val->get()); + } catch (py::cast_error &err) { + throw py::value_error( + (llvm::Twine("Operand ") + llvm::Twine(it.index()) + + " must be a Value or Sequence of Values (" + err.what() + + ")") + .str()); + } + } + continue; + } catch (py::cast_error &err) { + throw py::value_error( + (llvm::Twine("Operand ") + llvm::Twine(it.index()) + + " must be a Value or Sequence of Values (" + err.what() + ")") + .str()); + } + + throw py::cast_error(); } } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 505946ca1..63a3125ec 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -667,7 +667,7 @@ class IndexType(Type): class InferTypeOpInterface: def __init__(self, object: object, context: Optional[Context] = None) -> None: ... - def inferReturnTypes(self, operands: Optional[List[Value]] = None, attributes: Optional[Attribute] = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ... + def inferReturnTypes(self, operands: Optional[List] = None, attributes: Optional[Attribute] = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ... @property def operation(self) -> Operation: ... 
@property From b452541dd3951a39ce9eb93300cb9f5551019c79 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Mon, 6 Feb 2023 17:44:47 -0800 Subject: [PATCH 418/915] [mlir][py] Fix unused var --- mlir/lib/Bindings/Python/IRInterfaces.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index b917bf0c1..c8371dcc7 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -206,6 +206,8 @@ class PyInferTypeOpInterface mlirOperands.push_back(val->get()); continue; } catch (py::cast_error &err) { + // Intentionally unhandled to try sequence below first. + (void)err; } try { From ca23f58b247fc7f54526bf73158f0db75efcbbcc Mon Sep 17 00:00:00 2001 From: Stella Stamenova Date: Tue, 7 Feb 2023 14:24:55 -0800 Subject: [PATCH 419/915] [mlir] Relax version requirement for PyYAML in mlir Some Ubuntu 20.04 images come with PyYAML 5.3.1 pre-installed through distutils. This makes pip very angry. See https://github.com/yaml/pyyaml/issues/349. Since older versions of PyYAML should work for mlir, relax the version requirement to ease developer setup. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D143523 --- mlir/python/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index aaf480f0b..16a3dd39c 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,4 +1,4 @@ numpy>=1.19.5, <=1.23.5 pybind11>=2.8.0, <=2.10.3 -PyYAML==6.0 +PyYAML>= 5.3.1, <=6.0 dataclasses>=0.6, <=0.8 From f02e498b44689aa109a2201f04b9d0b991e130dd Mon Sep 17 00:00:00 2001 From: Jake Hall Date: Mon, 13 Feb 2023 14:10:20 +0000 Subject: [PATCH 420/915] [mlir] Add Float8E5M2FNUZ and Float8E4M3FNUZ types to MLIR Float8E5M2FNUZ and Float8E4M3FNUZ have been added to APFloat in D141863. 
This change adds these types as MLIR builtin types alongside Float8E5M2 and Float8E4M3FN (added in D133823 and D138075). Reviewed By: krzysz00 Differential Revision: https://reviews.llvm.org/D143744 --- mlir/include/mlir-c/BuiltinTypes.h | 14 +++++++++ mlir/lib/Bindings/Python/IRTypes.cpp | 38 ++++++++++++++++++++++++ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 16 ++++++++++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 16 ++++++++++ 4 files changed, 84 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index 000397505..8b855d8c3 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -81,6 +81,20 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type); /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx); +/// Checks whether the given type is an f8E5M2FNUZ type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type); + +/// Creates an f8E5M2FNUZ type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx); + +/// Checks whether the given type is an f8E4M3FNUZ type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type); + +/// Creates an f8E4M3FNUZ type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx); + /// Checks whether the given type is a bf16 type. MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 3cc226d7a..87ffe5936 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -139,6 +139,42 @@ class PyFloat8E5M2Type : public PyConcreteType { } }; +/// Floating Point Type subclass - Float8E4M3FNUZ. 
+class PyFloat8E4M3FNUZType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; + static constexpr const char *pyClassName = "Float8E4M3FNUZType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); + return PyFloat8E4M3FNUZType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e4m3fnuz type."); + } +}; + +/// Floating Point Type subclass - Float8E5M2FNUZ. +class PyFloat8E5M2FNUZType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; + static constexpr const char *pyClassName = "Float8E5M2FNUZType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); + return PyFloat8E5M2FNUZType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e5m2fnuz type."); + } +}; + /// Floating Point Type subclass - BF16Type. 
class PyBF16Type : public PyConcreteType { public: @@ -700,6 +736,8 @@ void mlir::python::populateIRTypes(py::module &m) { PyIndexType::bind(m); PyFloat8E4M3FNType::bind(m); PyFloat8E5M2Type::bind(m); + PyFloat8E4M3FNUZType::bind(m); + PyFloat8E5M2FNUZType::bind(m); PyBF16Type::bind(m); PyF16Type::bind(m); PyF32Type::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 73a3ec414..aea122120 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -84,6 +84,22 @@ MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) { return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx))); } +bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) { + return unwrap(type).isFloat8E5M2FNUZ(); +} + +MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) { + return wrap(FloatType::getFloat8E5M2FNUZ(unwrap(ctx))); +} + +bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) { + return unwrap(type).isFloat8E4M3FNUZ(); +} + +MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) { + return wrap(FloatType::getFloat8E4M3FNUZ(unwrap(ctx))); +} + bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); } MlirType mlirBF16TypeGet(MlirContext ctx) { diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 63a3125ec..7d5ff23f6 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -52,6 +52,8 @@ __all__ = [ "DictAttr", "Float8E4M3FNType", "Float8E5M2Type", + "Float8E4M3FNUZType", + "Float8E5M2FNUZType", "F16Type", "F32Type", "F64Type", @@ -593,6 +595,20 @@ class Float8E5M2Type(Type): @staticmethod def isinstance(arg: Any) -> bool: ... +class Float8E4M3FNUZType(Type): + def __init__(self, cast_from_type: Type) -> None: ... + @staticmethod + def get(*args, **kwargs) -> Float8E4M3FNUZType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... + +class Float8E5M2FNUZType(Type): + def __init__(self, cast_from_type: Type) -> None: ... 
+ @staticmethod + def get(*args, **kwargs) -> Float8E5M2FNUZType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... + # TODO: Auto-generated. Audit and fix. class F16Type(Type): def __init__(self, cast_from_type: Type) -> None: ... From ae32e12fe2a26e00951df45fa3cf3b29cddb48f9 Mon Sep 17 00:00:00 2001 From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Tue, 14 Feb 2023 18:20:45 -0800 Subject: [PATCH 421/915] [mlir][sparse] Factoring out SparseTensorType class This change adds a new `SparseTensorType` class for making the "dim" vs "lvl" distinction more overt, and for abstracting over the differences between sparse-tensors and dense-tensors. In addition, this change also adds new type aliases `Dimension`, `Level`, and `FieldIndex` to make code more self-documenting. Although the diff is very large, the majority of the changes are mechanical in nature (e.g., changing types to use the new aliases, updating variable names to match, etc). Along the way I also made many variables `const` when they could be; the majority of which required only adding the keyword. A few places had conditional definitions of these variables, requiring actual code changes; however, that was only done when the overall change was extremely local and easy to extract. All these changes are included in the current patch only because it would be too onerous to split them off into a separate patch. 
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D143800 --- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index b667ad3c6..831cdd8a4 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -70,13 +70,13 @@ mlirSparseTensorEncodingAttrGetHigherOrdering(MlirAttribute attr) { } intptr_t mlirSparseTensorEncodingGetNumDimLevelTypes(MlirAttribute attr) { - return unwrap(attr).cast().getDimLevelType().size(); + return unwrap(attr).cast().getLvlRank(); } MlirSparseTensorDimLevelType -mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t pos) { +mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t lvl) { return static_cast( - unwrap(attr).cast().getDimLevelType()[pos]); + unwrap(attr).cast().getLvlType(lvl)); } int mlirSparseTensorEncodingAttrGetPointerBitWidth(MlirAttribute attr) { From 17a6436d840ea12276b701a4cc2919097b346d1a Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Sat, 25 Feb 2023 03:51:31 -0500 Subject: [PATCH 422/915] [mlir][python] Don't emit diagnostics when printing invalid ops The asm printer grew the ability to automatically fall back to the generic format for invalid ops, so this logic doesn't need to be in the bindings anymore. The printer already handles supressing diagnostics that get emitted while checking if the op is valid. 
Reviewed By: mehdi_amini, stellaraccident Differential Revision: https://reviews.llvm.org/D144805 --- mlir/include/mlir-c/IR.h | 4 ++++ mlir/lib/Bindings/Python/IRCore.cpp | 13 ++----------- mlir/lib/CAPI/IR/IR.cpp | 4 ++++ 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index a349eb9f3..023b99f42 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -404,6 +404,10 @@ mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags); MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags); +/// Do not verify the operation when using custom operation printers. +MLIR_CAPI_EXPORTED void +mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags); + //===----------------------------------------------------------------------===// // Operation API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 2ecfc36d4..e09f0fdee 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1075,17 +1075,6 @@ void PyOperationBase::print(py::object fileObject, bool binary, if (fileObject.is_none()) fileObject = py::module::import("sys").attr("stdout"); - if (!assumeVerified && !printGenericOpForm && - !mlirOperationVerify(operation)) { - std::string message("// Verification failed, printing generic form\n"); - if (binary) { - fileObject.attr("write")(py::bytes(message)); - } else { - fileObject.attr("write")(py::str(message)); - } - printGenericOpForm = true; - } - MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); if (largeElementsLimit) mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); @@ -1096,6 +1085,8 @@ void PyOperationBase::print(py::object fileObject, bool binary, mlirOpPrintingFlagsPrintGenericOpForm(flags); if (useLocalScope) mlirOpPrintingFlagsUseLocalScope(flags); + if 
(assumeVerified) + mlirOpPrintingFlagsAssumeVerified(flags); PyFileAccumulator accum(fileObject, binary); mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 7d3479736..e83f0f824 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -141,6 +141,10 @@ void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) { unwrap(flags)->useLocalScope(); } +void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) { + unwrap(flags)->assumeVerified(); +} + //===----------------------------------------------------------------------===// // Location API. //===----------------------------------------------------------------------===// From c83e5a9c51648a2383e1c548db7f84624083a914 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 24 Feb 2023 03:59:08 -0800 Subject: [PATCH 423/915] [mlir][Linalg] Refactor transform.structured.pad to separate out hoisting Depends on: D144717 Differential Revision: https://reviews.llvm.org/D144856 --- mlir/python/mlir/dialects/_structured_transform_ops_ext.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 97705e2ad..e2c262ca5 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -171,7 +171,6 @@ def __init__(self, Sequence[Attribute]]] = None, padding_dimensions: OptionalIntList = None, pack_paddings: OptionalIntList = None, - hoist_paddings: OptionalIntList = None, transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[ ArrayAttr, IntOrAttrList]]]] = None, loc=None, @@ -180,7 +179,6 @@ def __init__(self, padding_values_attr = _get_array_attr(padding_values) padding_dimensions_attr = _get_int_array_attr(padding_dimensions) pack_paddings_attr = _get_int_array_attr(pack_paddings) - hoist_paddings_attr = 
_get_int_array_attr(hoist_paddings) transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings) super().__init__( pdl_operation_type, @@ -188,7 +186,6 @@ def __init__(self, padding_values=padding_values_attr, padding_dimensions=padding_dimensions_attr, pack_paddings=pack_paddings_attr, - hoist_paddings=hoist_paddings_attr, transpose_paddings=transpose_paddings_attr, loc=loc, ip=ip) From 4540a9a609530479044125ccdcecffd33badb52b Mon Sep 17 00:00:00 2001 From: rkayaith Date: Tue, 8 Nov 2022 16:55:06 -0500 Subject: [PATCH 424/915] [mlir][python] Add generic operation parse APIs Currently the bindings only allow for parsing IR with a top-level `builtin.module` op, since the parse APIs insert an implicit module op. This change adds `Operation.parse`, which returns whatever top-level op is actually in the source. To simplify parsing of specific operations, `OpView.parse` is also added, which handles the error checking for `OpView` subclasses. Reviewed By: ftynse, stellaraccident Differential Revision: https://reviews.llvm.org/D143352 --- mlir/include/mlir-c/IR.h | 10 ++++++ mlir/lib/Bindings/Python/IRCore.cpp | 50 ++++++++++++++++++++++++++++- mlir/lib/Bindings/Python/IRModule.h | 8 ++++- mlir/lib/CAPI/IR/IR.cpp | 9 ++++++ 4 files changed, 75 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 023b99f42..84d226b40 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -422,6 +422,16 @@ mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags); /// - Result type inference is enabled and cannot be performed. MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreate(MlirOperationState *state); +/// Parses an operation, giving ownership to the caller. If parsing fails a null +/// operation will be returned, and an error diagnostic emitted. +/// +/// `sourceStr` may be either the text assembly format, or binary bytecode +/// format. 
`sourceName` is used as the file name of the source; any IR without +/// locations will get a `FileLineColLoc` location with `sourceName` as the file +/// name. +MLIR_CAPI_EXPORTED MlirOperation mlirOperationCreateParse( + MlirContext context, MlirStringRef sourceStr, MlirStringRef sourceName); + /// Creates a deep copy of an operation. The operation is not inserted and /// ownership is transferred to the caller. MLIR_CAPI_EXPORTED MlirOperation mlirOperationClone(MlirOperation op); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index e09f0fdee..12d37da5b 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -20,8 +20,8 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include #include +#include namespace py = pybind11; using namespace mlir; @@ -1059,6 +1059,19 @@ PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, return created; } +PyOperationRef PyOperation::parse(PyMlirContextRef contextRef, + const std::string &sourceStr, + const std::string &sourceName) { + MlirOperation op = + mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr), + toMlirStringRef(sourceName)); + // TODO: Include error diagnostic messages in the exception message + if (mlirOperationIsNull(op)) + throw py::value_error( + "Unable to parse operation assembly (see diagnostics)"); + return PyOperation::createDetached(std::move(contextRef), op); +} + void PyOperation::checkValid() const { if (!valid) { throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); @@ -2769,6 +2782,17 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("successors") = py::none(), py::arg("regions") = 0, py::arg("loc") = py::none(), py::arg("ip") = py::none(), kOperationCreateDocstring) + .def_static( + "parse", + [](const std::string &sourceStr, const std::string &sourceName, + DefaultingPyMlirContext context) { + return 
PyOperation::parse(context->getRef(), sourceStr, sourceName) + ->createOpView(); + }, + py::arg("source"), py::kw_only(), py::arg("source_name") = "", + py::arg("context") = py::none(), + "Parses an operation. Supports both text assembly format and binary " + "bytecode format.") .def_property_readonly("parent", [](PyOperation &self) -> py::object { auto parent = self.getParentOperation(); @@ -2820,6 +2844,30 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("successors") = py::none(), py::arg("regions") = py::none(), py::arg("loc") = py::none(), py::arg("ip") = py::none(), "Builds a specific, generated OpView based on class level attributes."); + opViewClass.attr("parse") = classmethod( + [](const py::object &cls, const std::string &sourceStr, + const std::string &sourceName, DefaultingPyMlirContext context) { + PyOperationRef parsed = + PyOperation::parse(context->getRef(), sourceStr, sourceName); + + // Check if the expected operation was parsed, and cast to to the + // appropriate `OpView` subclass if successful. + // NOTE: This accesses attributes that have been automatically added to + // `OpView` subclasses, and is not intended to be used on `OpView` + // directly. + std::string clsOpName = + py::cast(cls.attr("OPERATION_NAME")); + MlirStringRef parsedOpName = + mlirIdentifierStr(mlirOperationGetName(*parsed.get())); + if (!mlirStringRefEqual(parsedOpName, toMlirStringRef(clsOpName))) + throw py::value_error( + "Expected a '" + clsOpName + "' op, got: '" + + std::string(parsedOpName.data, parsedOpName.length) + "'"); + return cls.attr("_Raw")(parsed.getObject()); + }, + py::arg("cls"), py::arg("source"), py::kw_only(), + py::arg("source_name") = "", py::arg("context") = py::none(), + "Parses a specific, generated OpView based on class level attributes"); //---------------------------------------------------------------------------- // Mapping of PyRegion. 
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 37115acbe..fa4bc1c3d 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -9,9 +9,9 @@ #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H #define MLIR_BINDINGS_PYTHON_IRMODULES_H +#include #include #include -#include #include "PybindUtils.h" @@ -548,6 +548,12 @@ class PyOperation : public PyOperationBase, public BaseContextObject { createDetached(PyMlirContextRef contextRef, MlirOperation operation, pybind11::object parentKeepAlive = pybind11::object()); + /// Parses a source string (either text assembly or bytecode), creating a + /// detached operation. + static PyOperationRef parse(PyMlirContextRef contextRef, + const std::string &sourceStr, + const std::string &sourceName); + /// Detaches the operation from its parent block and updates its state /// accordingly. void detachFromParent() { diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index e83f0f824..051559acd 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -368,6 +368,15 @@ MlirOperation mlirOperationCreate(MlirOperationState *state) { return result; } +MlirOperation mlirOperationCreateParse(MlirContext context, + MlirStringRef sourceStr, + MlirStringRef sourceName) { + + return wrap( + parseSourceString(unwrap(sourceStr), unwrap(context), unwrap(sourceName)) + .release()); +} + MlirOperation mlirOperationClone(MlirOperation op) { return wrap(unwrap(op)->clone()); } From 1b164aaf58d7a87d95f736d5bc209c244f689e2a Mon Sep 17 00:00:00 2001 From: rkayaith Date: Tue, 8 Nov 2022 22:39:18 -0500 Subject: [PATCH 425/915] [mlir][CAPI] Allow running pass manager on any operation `mlirPassManagerRun` is currently restricted to running on `builtin.module` ops, but this restriction doesn't exist on the C++ side. This renames it to `mlirPassManagerRunOnOp` and updates it to take `MlirOperation` instead of `MlirModule`. 
Depends on D143352 Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D143354 --- mlir/include/mlir-c/Pass.h | 4 ++-- mlir/lib/Bindings/Python/Pass.cpp | 4 ++-- mlir/lib/CAPI/IR/Pass.cpp | 6 +++--- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index 721f1f28f..35db13830 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -70,9 +70,9 @@ static inline bool mlirPassManagerIsNull(MlirPassManager passManager) { MLIR_CAPI_EXPORTED MlirOpPassManager mlirPassManagerGetAsOpPassManager(MlirPassManager passManager); -/// Run the provided `passManager` on the given `module`. +/// Run the provided `passManager` on the given `op`. MLIR_CAPI_EXPORTED MlirLogicalResult -mlirPassManagerRun(MlirPassManager passManager, MlirModule module); +mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op); /// Enable mlir-print-ir-after-all. MLIR_CAPI_EXPORTED void diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index cb3c1586e..99f17a18b 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -117,8 +117,8 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { .def( "run", [](PyPassManager &passManager, PyModule &module) { - MlirLogicalResult status = - mlirPassManagerRun(passManager.get(), module.get()); + MlirLogicalResult status = mlirPassManagerRunOnOp( + passManager.get(), mlirModuleGetOperation(module.get())); if (mlirLogicalResultIsFailure(status)) throw SetPyError(PyExc_RuntimeError, "Failure while executing pass pipeline."); diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index b92115411..d242baae9 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -39,9 +39,9 @@ mlirPassManagerGetAsOpPassManager(MlirPassManager passManager) { return wrap(static_cast(unwrap(passManager))); } -MlirLogicalResult mlirPassManagerRun(MlirPassManager 
passManager, - MlirModule module) { - return wrap(unwrap(passManager)->run(unwrap(module))); +MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, + MlirOperation op) { + return wrap(unwrap(passManager)->run(unwrap(op))); } void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) { From 760ca2ae4f0cdd8d4a6ec4fd2573fd853b7d3e82 Mon Sep 17 00:00:00 2001 From: rkayaith Date: Tue, 8 Nov 2022 22:48:26 -0500 Subject: [PATCH 426/915] [mlir][python] Allow running pass manager on any operation `PassManager.run` is currently restricted to running on `builtin.module` ops, but this restriction doesn't exist on the C++ side. This updates it to take `ir.Operation/OpView` instead of `ir.Module`. Depends on D143354 Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D143356 --- mlir/lib/Bindings/Python/Pass.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 99f17a18b..7e90d8be6 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -116,16 +116,16 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { "ValueError if the pipeline can't be parsed.") .def( "run", - [](PyPassManager &passManager, PyModule &module) { + [](PyPassManager &passManager, PyOperationBase &op) { MlirLogicalResult status = mlirPassManagerRunOnOp( - passManager.get(), mlirModuleGetOperation(module.get())); + passManager.get(), op.getOperation().get()); if (mlirLogicalResultIsFailure(status)) throw SetPyError(PyExc_RuntimeError, "Failure while executing pass pipeline."); }, - py::arg("module"), - "Run the pass manager on the provided module, throw a RuntimeError " - "on failure.") + py::arg("operation"), + "Run the pass manager on the provided operation, throw a " + "RuntimeError on failure.") .def( "__str__", [](PyPassManager &self) { From 994ef3af12a6dc5d929e1bcc65984a1c31f5885b Mon Sep 17 00:00:00 2001 From: 
Rahul Kayaith Date: Sun, 22 Jan 2023 23:31:18 -0500 Subject: [PATCH 427/915] [mlir][python] Remove "Raw" OpView classes The raw `OpView` classes are used to bypass the constructors of `OpView` subclasses, but having a separate class can create some confusing behaviour, e.g.: ``` op = MyOp(...) # fails, lhs is 'MyOp', rhs is '_MyOp' assert type(op) == type(op.operation.opview) ``` Instead we can use `__new__` to achieve the same thing without a separate class: ``` my_op = MyOp.__new__(MyOp) OpView.__init__(my_op, op) ``` Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D143830 --- mlir/lib/Bindings/Python/Globals.h | 23 ++++---- mlir/lib/Bindings/Python/IRCore.cpp | 54 ++++++------------- mlir/lib/Bindings/Python/IRModule.cpp | 20 ++++--- mlir/lib/Bindings/Python/IRModule.h | 12 ++++- mlir/lib/Bindings/Python/MainModule.cpp | 10 +--- .../python/mlir/_mlir_libs/_mlir/__init__.pyi | 2 +- 6 files changed, 45 insertions(+), 76 deletions(-) diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 8caa5a094..45d036896 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -74,8 +74,7 @@ class PyGlobals { /// Raises an exception if the mapping already exists. /// This is intended to be called by implementation code. void registerOperationImpl(const std::string &operationName, - pybind11::object pyClass, - pybind11::object rawOpViewClass); + pybind11::object pyClass); /// Returns the custom Attribute builder for Attribute kind. std::optional @@ -86,10 +85,11 @@ class PyGlobals { std::optional lookupDialectClass(const std::string &dialectNamespace); - /// Looks up a registered raw OpView class by operation name. Note that this - /// may trigger a load of the dialect, which can arbitrarily re-enter. + /// Looks up a registered operation class (deriving from OpView) by operation + /// name. Note that this may trigger a load of the dialect, which can + /// arbitrarily re-enter. 
std::optional - lookupRawOpViewClass(llvm::StringRef operationName); + lookupOperationClass(llvm::StringRef operationName); private: static PyGlobals *instance; @@ -99,21 +99,16 @@ class PyGlobals { llvm::StringMap dialectClassMap; /// Map of full operation name to external operation class object. llvm::StringMap operationClassMap; - /// Map of operation name to custom subclass that directly initializes - /// the OpView base class (bypassing the user class constructor). - llvm::StringMap rawOpViewClassMap; /// Map of attribute ODS name to custom builder. llvm::StringMap attributeBuilderMap; /// Set of dialect namespaces that we have attempted to import implementation /// modules for. llvm::StringSet<> loadedDialectModulesCache; - /// Cache of operation name to custom OpView subclass that directly - /// initializes the OpView base class (or an undefined object for negative - /// lookup). This is maintained on loopup as a shadow of rawOpViewClassMap - /// in order for repeat lookups of the OpView classes to only incur the cost - /// of one hashtable lookup. - llvm::StringMap rawOpViewClassMapCache; + /// Cache of operation name to external operation class object. This is + /// maintained on lookup as a shadow of operationClassMap in order for repeat + /// lookups of the classes to only incur the cost of one hashtable lookup. 
+ llvm::StringMap operationClassMapCache; }; } // namespace python diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 12d37da5b..e03b6470c 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1339,10 +1339,10 @@ py::object PyOperation::createOpView() { checkValid(); MlirIdentifier ident = mlirOperationGetName(get()); MlirStringRef identStr = mlirIdentifierStr(ident); - auto opViewClass = PyGlobals::get().lookupRawOpViewClass( + auto operationCls = PyGlobals::get().lookupOperationClass( StringRef(identStr.data, identStr.length)); - if (opViewClass) - return (*opViewClass)(getRef().getObject()); + if (operationCls) + return PyOpView::constructDerived(*operationCls, *getRef().get()); return py::cast(PyOpView(getRef().getObject())); } @@ -1618,47 +1618,23 @@ PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList, /*regions=*/*regions, location, maybeIp); } +pybind11::object PyOpView::constructDerived(const pybind11::object &cls, + const PyOperation &operation) { + // TODO: pybind11 2.6 supports a more direct form. + // Upgrade many years from now. + // auto opViewType = py::type::of(); + py::handle opViewType = py::detail::get_type_handle(typeid(PyOpView), true); + py::object instance = cls.attr("__new__")(cls); + opViewType.attr("__init__")(instance, operation); + return instance; +} + PyOpView::PyOpView(const py::object &operationObject) // Casting through the PyOperationBase base-class and then back to the // Operation lets us accept any PyOperationBase subclass. : operation(py::cast(operationObject).getOperation()), operationObject(operation.getRef().getObject()) {} -py::object PyOpView::createRawSubclass(const py::object &userClass) { - // This is... a little gross. 
The typical pattern is to have a pure python - // class that extends OpView like: - // class AddFOp(_cext.ir.OpView): - // def __init__(self, loc, lhs, rhs): - // operation = loc.context.create_operation( - // "addf", lhs, rhs, results=[lhs.type]) - // super().__init__(operation) - // - // I.e. The goal of the user facing type is to provide a nice constructor - // that has complete freedom for the op under construction. This is at odds - // with our other desire to sometimes create this object by just passing an - // operation (to initialize the base class). We could do *arg and **kwargs - // munging to try to make it work, but instead, we synthesize a new class - // on the fly which extends this user class (AddFOp in this example) and - // *give it* the base class's __init__ method, thus bypassing the - // intermediate subclass's __init__ method entirely. While slightly, - // underhanded, this is safe/legal because the type hierarchy has not changed - // (we just added a new leaf) and we aren't mucking around with __new__. - // Typically, this new class will be stored on the original as "_Raw" and will - // be used for casts and other things that need a variant of the class that - // is initialized purely from an operation. - py::object parentMetaclass = - py::reinterpret_borrow((PyObject *)&PyType_Type); - py::dict attributes; - // TODO: pybind11 2.6 supports a more direct form. Upgrade many years from - // now. - // auto opViewType = py::type::of(); - auto opViewType = py::detail::get_type_handle(typeid(PyOpView), true); - attributes["__init__"] = opViewType.attr("__init__"); - py::str origName = userClass.attr("__name__"); - py::str newName = py::str("_") + origName; - return parentMetaclass(newName, py::make_tuple(userClass), attributes); -} - //------------------------------------------------------------------------------ // PyInsertionPoint. 
//------------------------------------------------------------------------------ @@ -2863,7 +2839,7 @@ void mlir::python::populateIRCore(py::module &m) { throw py::value_error( "Expected a '" + clsOpName + "' op, got: '" + std::string(parsedOpName.data, parsedOpName.length) + "'"); - return cls.attr("_Raw")(parsed.getObject()); + return PyOpView::constructDerived(cls, *parsed.get()); }, py::arg("cls"), py::arg("source"), py::kw_only(), py::arg("source_name") = "", py::arg("context") = py::none(), diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index e3b8ef189..7221442e4 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -84,8 +84,7 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, } void PyGlobals::registerOperationImpl(const std::string &operationName, - py::object pyClass, - py::object rawOpViewClass) { + py::object pyClass) { py::object &found = operationClassMap[operationName]; if (found) { throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + @@ -93,7 +92,6 @@ void PyGlobals::registerOperationImpl(const std::string &operationName, "' is already registered."); } found = std::move(pyClass); - rawOpViewClassMap[operationName] = std::move(rawOpViewClass); } std::optional @@ -130,10 +128,10 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { } std::optional -PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { +PyGlobals::lookupOperationClass(llvm::StringRef operationName) { { - auto foundIt = rawOpViewClassMapCache.find(operationName); - if (foundIt != rawOpViewClassMapCache.end()) { + auto foundIt = operationClassMapCache.find(operationName); + if (foundIt != operationClassMapCache.end()) { if (foundIt->second.is_none()) return std::nullopt; assert(foundIt->second && "py::object is defined"); @@ -148,22 +146,22 @@ PyGlobals::lookupRawOpViewClass(llvm::StringRef operationName) { // Attempt to find from the 
canonical map and cache. { - auto foundIt = rawOpViewClassMap.find(operationName); - if (foundIt != rawOpViewClassMap.end()) { + auto foundIt = operationClassMap.find(operationName); + if (foundIt != operationClassMap.end()) { if (foundIt->second.is_none()) return std::nullopt; assert(foundIt->second && "py::object is defined"); // Positive cache. - rawOpViewClassMapCache[operationName] = foundIt->second; + operationClassMapCache[operationName] = foundIt->second; return foundIt->second; } // Negative cache. - rawOpViewClassMap[operationName] = py::none(); + operationClassMap[operationName] = py::none(); return std::nullopt; } } void PyGlobals::clearImportCache() { loadedDialectModulesCache.clear(); - rawOpViewClassMapCache.clear(); + operationClassMapCache.clear(); } diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index fa4bc1c3d..4aced3639 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -654,8 +654,6 @@ class PyOpView : public PyOperationBase { PyOpView(const pybind11::object &operationObject); PyOperation &getOperation() override { return operation; } - static pybind11::object createRawSubclass(const pybind11::object &userClass); - pybind11::object getOperationObject() { return operationObject; } static pybind11::object @@ -666,6 +664,16 @@ class PyOpView : public PyOperationBase { std::optional regions, DefaultingPyLocation location, const pybind11::object &maybeIp); + /// Construct an instance of a class deriving from OpView, bypassing its + /// `__init__` method. The derived class will typically define a constructor + /// that provides a convenient builder, but we need to side-step this when + /// constructing an `OpView` for an already-built operation. + /// + /// The caller is responsible for verifying that `operation` is a valid + /// operation to construct `cls` with. 
+ static pybind11::object constructDerived(const pybind11::object &cls, + const PyOperation &operation); + private: PyOperation &operation; // For efficient, cast-free access from C++ pybind11::object operationObject; // Holds the reference. diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 1d6d8fa01..b32b4186f 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -41,7 +41,6 @@ PYBIND11_MODULE(_mlir, m) { "Testing hook for directly registering a dialect") .def("_register_operation_impl", &PyGlobals::registerOperationImpl, py::arg("operation_name"), py::arg("operation_class"), - py::arg("raw_opview_class"), "Testing hook for directly registering an operation"); // Aside from making the globals accessible to python, having python manage @@ -68,18 +67,11 @@ PYBIND11_MODULE(_mlir, m) { [dialectClass](py::object opClass) -> py::object { std::string operationName = opClass.attr("OPERATION_NAME").cast(); - auto rawSubclass = PyOpView::createRawSubclass(opClass); - PyGlobals::get().registerOperationImpl(operationName, opClass, - rawSubclass); + PyGlobals::get().registerOperationImpl(operationName, opClass); // Dict-stuff the new opClass by name onto the dialect class. py::object opClassName = opClass.attr("__name__"); dialectClass.attr(opClassName) = opClass; - - // Now create a special "Raw" subclass that passes through - // construction to the OpView parent (bypasses the intermediate - // child's __init__). 
- opClass.attr("_Raw") = rawSubclass; return opClass; }); }, diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi index c8734cfde..93b98c4aa 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi @@ -5,7 +5,7 @@ globals: "_Globals" class _Globals: dialect_search_modules: List[str] def _register_dialect_impl(self, dialect_namespace: str, dialect_class: type) -> None: ... - def _register_operation_impl(self, operation_name: str, operation_class: type, raw_opview_class: type) -> None: ... + def _register_operation_impl(self, operation_name: str, operation_class: type) -> None: ... def append_dialect_search_prefix(self, module_name: str) -> None: ... def register_dialect(dialect_class: type) -> object: ... From 3846b1b20082c5b465add4163cdbdbef138cae70 Mon Sep 17 00:00:00 2001 From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Mon, 6 Mar 2023 12:19:41 -0800 Subject: [PATCH 428/915] [mlir][sparse] Renaming "pointer/index" to "position/coordinate" The old "pointer/index" names often cause confusion since these names clash with names of unrelated things in MLIR; so this change rectifies this by changing everything to use "position/coordinate" terminology instead. In addition to the basic terminology, there have also been various conventions for making certain distinctions like: (1) the overall storage for coordinates in the sparse-tensor, vs the particular collection of coordinates of a given element; and (2) particular coordinates given as a `Value` or `TypedValue`, vs particular coordinates given as `ValueRange` or similar. I have striven to maintain these distinctions as follows: * "p/c" are used for individual position/coordinate values, when there is no risk of confusion. (Just like we use "d/l" to abbreviate "dim/lvl".) 
* "pos/crd" are used for individual position/coordinate values, when
a longer name is helpful to avoid ambiguity or to form compound names
(e.g., "parentPos"). (Just like we use "dim/lvl" when we need a
longer form of "d/l".) I have also used these forms for a handful of
compound names where the old name had been using a three-letter form
previously, even though a longer form would be more appropriate. I've
avoided renaming these to use a longer form purely for expediency
sake, since changing them would require a cascade of other renamings.
They should be updated to follow the new naming scheme, but that can
be done in future patches.

* "coords" is used for the complete collection of crd values
associated with a single element. In the runtime library this
includes both `std::vector` and raw pointer representations. In the
compiler, this is used specifically for buffer variables with C++
type `Value`, `TypedValue`, etc. The bare form "coords" is
discouraged, since it fails to make the dim/lvl distinction; so the
compound names "dimCoords/lvlCoords" should be used instead. (Though
there may exist a rare few cases where it is appropriate to be
intentionally ambiguous about what coordinate-space the coords live
in; in which case the bare "coords" is appropriate.) There is seldom
the need for the pos variant of this notion. In most circumstances we
use the term "cursor", since the same buffer is reused for a 'moving'
pos-collection.

* "dcvs/lcvs" is used in the compiler as the `ValueRange` analogue of
"dimCoords/lvlCoords". (The "vs" stands for "`Value`s".) I haven't
found the need for it, but "pvs" would be the obvious name for a
pos-`ValueRange`.

The old "ind"-vs-"ivs" naming scheme does not seem to have been
sustained in more recent code, which instead prefers other mnemonics
(e.g., adding "Buf" to the end of the names for `TypedValue`). I have
cleaned up a lot of these to follow the "coords"-vs-"cvs" naming
scheme, though haven't done an exhaustive cleanup.
* "positions/coordinates" are used for larger collections of pos/crd values; in particular, these are used when referring to the complete sparse-tensor storage components. I also prefer to use these unabbreviated names in the documentation, unless there is some specific reason why using the abbreviated forms helps resolve ambiguity. In addition to making this terminology change, this change also does some cleanup along the way: * correcting the dim/lvl terminology in certain places. * adding `const` when it requires no other code changes. * miscellaneous cleanup that was entailed in order to make the proper distinctions. Most of these are in CodegenUtils.{h,cpp} Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D144773 --- mlir/include/mlir-c/Dialect/SparseTensor.h | 30 +++++++++---------- .../Bindings/Python/DialectSparseTensor.cpp | 30 ++++++++----------- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 24 +++++++-------- 3 files changed, 40 insertions(+), 44 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index 8027f319b..7d560dd80 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -41,40 +41,40 @@ enum MlirSparseTensorDimLevelType { // SparseTensorEncodingAttr //===----------------------------------------------------------------------===// -/// Checks whether the given attribute is a sparse_tensor.encoding attribute. +/// Checks whether the given attribute is a `sparse_tensor.encoding` attribute. MLIR_CAPI_EXPORTED bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr); -/// Creates a sparse_tensor.encoding attribute with the given parameters. +/// Creates a `sparse_tensor.encoding` attribute with the given parameters. 
MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet( - MlirContext ctx, intptr_t numDimLevelTypes, + MlirContext ctx, intptr_t lvlRank, enum MlirSparseTensorDimLevelType const *dimLevelTypes, - MlirAffineMap dimOrdering, MlirAffineMap higherOrdering, - int pointerBitWidth, int indexBitWidth); + MlirAffineMap dimOrdering, MlirAffineMap higherOrdering, int posWidth, + int crdWidth); -/// Returns the number of dim level types in a sparse_tensor.encoding attribute. +/// Returns the level-rank of the `sparse_tensor.encoding` attribute. MLIR_CAPI_EXPORTED intptr_t -mlirSparseTensorEncodingGetNumDimLevelTypes(MlirAttribute attr); +mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr); -/// Returns a specified dim level type in a sparse_tensor.encoding attribute. +/// Returns a specified level-type of the `sparse_tensor.encoding` attribute. MLIR_CAPI_EXPORTED enum MlirSparseTensorDimLevelType -mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t pos); +mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t lvl); -/// Returns the dimension ordering in a sparse_tensor.encoding attribute. +/// Returns the dimension-ordering of the `sparse_tensor.encoding` attribute. MLIR_CAPI_EXPORTED MlirAffineMap mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr); -/// Returns the higher ordering in a sparse_tensor.encoding attribute. +/// Returns the higher-ordering of the `sparse_tensor.encoding` attribute. MLIR_CAPI_EXPORTED MlirAffineMap mlirSparseTensorEncodingAttrGetHigherOrdering(MlirAttribute attr); -/// Returns the pointer bit width in a sparse_tensor.encoding attribute. +/// Returns the position bitwidth of the `sparse_tensor.encoding` attribute. MLIR_CAPI_EXPORTED int -mlirSparseTensorEncodingAttrGetPointerBitWidth(MlirAttribute attr); +mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr); -/// Returns the index bit width in a sparse_tensor.encoding attribute. 
+/// Returns the coordinate bitwidth of the `sparse_tensor.encoding` attribute. MLIR_CAPI_EXPORTED int -mlirSparseTensorEncodingAttrGetIndexBitWidth(MlirAttribute attr); +mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr); #ifdef __cplusplus } diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index da44141e2..e84937df8 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -35,27 +35,27 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { [](py::object cls, std::vector dimLevelTypes, std::optional dimOrdering, - std::optional higherOrdering, int pointerBitWidth, - int indexBitWidth, MlirContext context) { + std::optional higherOrdering, int posWidth, + int crdWidth, MlirContext context) { return cls(mlirSparseTensorEncodingAttrGet( context, dimLevelTypes.size(), dimLevelTypes.data(), dimOrdering ? *dimOrdering : MlirAffineMap{nullptr}, higherOrdering ? 
*higherOrdering : MlirAffineMap{nullptr}, - pointerBitWidth, indexBitWidth)); + posWidth, crdWidth)); }, py::arg("cls"), py::arg("dim_level_types"), py::arg("dim_ordering"), - py::arg("higher_ordering"), py::arg("pointer_bit_width"), - py::arg("index_bit_width"), py::arg("context") = py::none(), + py::arg("higher_ordering"), py::arg("pos_width"), + py::arg("crd_width"), py::arg("context") = py::none(), "Gets a sparse_tensor.encoding from parameters.") .def_property_readonly( "dim_level_types", [](MlirAttribute self) { + const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self); std::vector ret; - for (int i = 0, - e = mlirSparseTensorEncodingGetNumDimLevelTypes(self); - i < e; ++i) + ret.reserve(lvlRank); + for (int l = 0; l < lvlRank; ++l) ret.push_back( - mlirSparseTensorEncodingAttrGetDimLevelType(self, i)); + mlirSparseTensorEncodingAttrGetDimLevelType(self, l)); return ret; }) .def_property_readonly( @@ -76,14 +76,10 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { return {}; return ret; }) - .def_property_readonly( - "pointer_bit_width", - [](MlirAttribute self) { - return mlirSparseTensorEncodingAttrGetPointerBitWidth(self); - }) - .def_property_readonly("index_bit_width", [](MlirAttribute self) { - return mlirSparseTensorEncodingAttrGetIndexBitWidth(self); - }); + .def_property_readonly("pos_width", + mlirSparseTensorEncodingAttrGetPosWidth) + .def_property_readonly("crd_width", + mlirSparseTensorEncodingAttrGetCrdWidth); } PYBIND11_MODULE(_mlirDialectsSparseTensor, m) { diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index 831cdd8a4..1aa6d329d 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -46,17 +46,17 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { } MlirAttribute mlirSparseTensorEncodingAttrGet( - MlirContext ctx, intptr_t numDimLevelTypes, + MlirContext ctx, intptr_t lvlRank, MlirSparseTensorDimLevelType 
const *dimLevelTypes, - MlirAffineMap dimOrdering, MlirAffineMap higherOrdering, - int pointerBitWidth, int indexBitWidth) { + MlirAffineMap dimOrdering, MlirAffineMap higherOrdering, int posWidth, + int crdWidth) { SmallVector cppDimLevelTypes; - cppDimLevelTypes.resize(numDimLevelTypes); - for (intptr_t i = 0; i < numDimLevelTypes; ++i) - cppDimLevelTypes[i] = static_cast(dimLevelTypes[i]); + cppDimLevelTypes.reserve(lvlRank); + for (intptr_t l = 0; l < lvlRank; ++l) + cppDimLevelTypes.push_back(static_cast(dimLevelTypes[l])); return wrap(SparseTensorEncodingAttr::get( unwrap(ctx), cppDimLevelTypes, unwrap(dimOrdering), - unwrap(higherOrdering), pointerBitWidth, indexBitWidth)); + unwrap(higherOrdering), posWidth, crdWidth)); } MlirAffineMap mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr) { @@ -69,7 +69,7 @@ mlirSparseTensorEncodingAttrGetHigherOrdering(MlirAttribute attr) { unwrap(attr).cast().getHigherOrdering()); } -intptr_t mlirSparseTensorEncodingGetNumDimLevelTypes(MlirAttribute attr) { +intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) { return unwrap(attr).cast().getLvlRank(); } @@ -79,10 +79,10 @@ mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t lvl) { unwrap(attr).cast().getLvlType(lvl)); } -int mlirSparseTensorEncodingAttrGetPointerBitWidth(MlirAttribute attr) { - return unwrap(attr).cast().getPointerBitWidth(); +int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) { + return unwrap(attr).cast().getPosWidth(); } -int mlirSparseTensorEncodingAttrGetIndexBitWidth(MlirAttribute attr) { - return unwrap(attr).cast().getIndexBitWidth(); +int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) { + return unwrap(attr).cast().getCrdWidth(); } From efb1efd6ee18725dd81072f78b64808fb93f39f1 Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Tue, 7 Feb 2023 16:07:50 -0500 Subject: [PATCH 429/915] [mlir][python] Capture error diagnostics in exceptions This updates most (all?) 
error-diagnostic-emitting python APIs to capture error diagnostics
and include them in the raised exception's message:
```
>>> Operation.parse('"arith.addi"() : () -> ()')
Traceback (most recent call last):
  File "", line 1, in 
mlir._mlir_libs.MLIRError: Unable to parse operation assembly:
error: "-":1:1: 'arith.addi' op requires one result
 note: "-":1:1: see current operation: "arith.addi"() : () -> ()
```

The diagnostic information is available on the exception for users
who may want to customize the error message:
```
>>> try:
...   Operation.parse('"arith.addi"() : () -> ()')
... except MLIRError as e:
...   print(e.message)
...   print(e.error_diagnostics)
...   print(e.error_diagnostics[0].message)
...
Unable to parse operation assembly
[]
'arith.addi' op requires one result
```

Error diagnostics captured in exceptions aren't propagated to
diagnostic handlers, to avoid double-reporting of errors. The
context-level `emit_error_diagnostics` option can be used to revert
to the old behaviour, causing error diagnostics to be reported to
handlers instead of as part of exceptions.
API changes: - `Operation.verify` now raises an exception on verification failure, instead of returning `false` - The exception raised by the following methods has been changed to `MLIRError`: - `PassManager.run` - `{Module,Operation,Type,Attribute}.parse` - `{RankedTensorType,UnrankedTensorType}.get` - `{MemRefType,UnrankedMemRefType}.get` - `VectorType.get` - `FloatAttr.get` closes #60595 depends on D144804, D143830 Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D143869 --- mlir/lib/Bindings/Python/IRAttributes.cpp | 11 +- mlir/lib/Bindings/Python/IRCore.cpp | 122 +++++++++++++++------- mlir/lib/Bindings/Python/IRModule.h | 111 +++++++++++++++----- mlir/lib/Bindings/Python/IRTypes.cpp | 68 +++--------- mlir/lib/Bindings/Python/Pass.cpp | 9 +- mlir/python/mlir/_mlir_libs/__init__.py | 23 +++- 6 files changed, 210 insertions(+), 134 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index c8ede8b06..b0c35ffb8 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -344,15 +344,10 @@ class PyFloatAttribute : public PyConcreteAttribute { c.def_static( "get", [](PyType &type, double value, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); MlirAttribute attr = mlirFloatAttrDoubleGetChecked(loc, type, value); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. 
- if (mlirAttributeIsNull(attr)) { - throw SetPyError(PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(type)).cast() + - "' and expected floating point type."); - } + if (mlirAttributeIsNull(attr)) + throw MLIRError("Invalid attribute", errors.take()); return PyFloatAttribute(type.getContext(), attr); }, py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index e03b6470c..8d637ea2b 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -15,6 +15,7 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Debug.h" +#include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" //#include "mlir-c/Registration.h" #include "llvm/ADT/ArrayRef.h" @@ -38,7 +39,7 @@ using llvm::Twine; static const char kContextParseTypeDocstring[] = R"(Parses the assembly form of a type. -Returns a Type object or raises a ValueError if the type cannot be parsed. +Returns a Type object or raises an MLIRError if the type cannot be parsed. See also: https://mlir.llvm.org/docs/LangRef/#type-system )"; @@ -58,7 +59,7 @@ static const char kContextGetNameLocationDocString[] = static const char kModuleParseDocstring[] = R"(Parses a module's assembly format from a string. -Returns a new MlirModule or raises a ValueError if the parsing fails. +Returns a new MlirModule or raises an MLIRError if the parsing fails. See also: https://mlir.llvm.org/docs/LangRef/ )"; @@ -654,6 +655,20 @@ py::object PyMlirContext::attachDiagnosticHandler(py::object callback) { return pyHandlerObject; } +MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag, + void *userData) { + auto *self = static_cast(userData); + // Check if the context requested we emit errors instead of capturing them. 
+ if (self->ctx->emitErrorDiagnostics) + return mlirLogicalResultFailure(); + + if (mlirDiagnosticGetSeverity(diag) != MlirDiagnosticError) + return mlirLogicalResultFailure(); + + self->errors.emplace_back(PyDiagnostic(diag).getInfo()); + return mlirLogicalResultSuccess(); +} + PyMlirContext &DefaultingPyMlirContext::resolve() { PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); if (!context) { @@ -870,6 +885,13 @@ py::tuple PyDiagnostic::getNotes() { return *materializedNotes; } +PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() { + std::vector notes; + for (py::handle n : getNotes()) + notes.emplace_back(n.cast().getInfo()); + return {getSeverity(), getLocation(), getMessage(), std::move(notes)}; +} + //------------------------------------------------------------------------------ // PyDialect, PyDialectDescriptor, PyDialects, PyDialectRegistry //------------------------------------------------------------------------------ @@ -1062,13 +1084,12 @@ PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, PyOperationRef PyOperation::parse(PyMlirContextRef contextRef, const std::string &sourceStr, const std::string &sourceName) { + PyMlirContext::ErrorCapture errors(contextRef); MlirOperation op = mlirOperationCreateParse(contextRef->get(), toMlirStringRef(sourceStr), toMlirStringRef(sourceName)); - // TODO: Include error diagnostic messages in the exception message if (mlirOperationIsNull(op)) - throw py::value_error( - "Unable to parse operation assembly (see diagnostics)"); + throw MLIRError("Unable to parse operation assembly", errors.take()); return PyOperation::createDetached(std::move(contextRef), op); } @@ -1155,6 +1176,14 @@ void PyOperationBase::moveBefore(PyOperationBase &other) { operation.parentKeepAlive = otherOp.parentKeepAlive; } +bool PyOperationBase::verify() { + PyOperation &op = getOperation(); + PyMlirContext::ErrorCapture errors(op.getContext()); + if (!mlirOperationVerify(op.get())) + throw 
MLIRError("Verification failed", errors.take()); + return true; +} + std::optional PyOperation::getParentOperation() { checkValid(); if (!isAttached()) @@ -2287,6 +2316,16 @@ void mlir::python::populateIRCore(py::module &m) { return self.getMessage(); }); + py::class_(m, "DiagnosticInfo", + py::module_local()) + .def(py::init<>([](PyDiagnostic diag) { return diag.getInfo(); })) + .def_readonly("severity", &PyDiagnostic::DiagnosticInfo::severity) + .def_readonly("location", &PyDiagnostic::DiagnosticInfo::location) + .def_readonly("message", &PyDiagnostic::DiagnosticInfo::message) + .def_readonly("notes", &PyDiagnostic::DiagnosticInfo::notes) + .def("__str__", + [](PyDiagnostic::DiagnosticInfo &self) { return self.message; }); + py::class_(m, "DiagnosticHandler", py::module_local()) .def("detach", &PyDiagnosticHandler::detach) .def_property_readonly("attached", &PyDiagnosticHandler::isAttached) @@ -2375,6 +2414,11 @@ void mlir::python::populateIRCore(py::module &m) { mlirContextAppendDialectRegistry(self.get(), registry); }, py::arg("registry")) + .def_property("emit_error_diagnostics", nullptr, + &PyMlirContext::setEmitErrorDiagnostics, + "Emit error diagnostics to diagnostic handlers. By default " + "error diagnostics are captured and reported through " + "MLIRError exceptions.") .def("load_all_available_dialects", [](PyMlirContext &self) { mlirContextLoadAllAvailableDialects(self.get()); }); @@ -2566,16 +2610,12 @@ void mlir::python::populateIRCore(py::module &m) { .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) .def_static( "parse", - [](const std::string moduleAsm, DefaultingPyMlirContext context) { + [](const std::string &moduleAsm, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); MlirModule module = mlirModuleCreateParse( context->get(), toMlirStringRef(moduleAsm)); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. 
- if (mlirModuleIsNull(module)) { - throw SetPyError( - PyExc_ValueError, - "Unable to parse module assembly (see diagnostics)"); - } + if (mlirModuleIsNull(module)) + throw MLIRError("Unable to parse module assembly", errors.take()); return PyModule::forModule(module).releaseObject(); }, py::arg("asm"), py::arg("context") = py::none(), @@ -2724,13 +2764,9 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("print_generic_op_form") = false, py::arg("use_local_scope") = false, py::arg("assume_verified") = false, kOperationGetAsmDocstring) - .def( - "verify", - [](PyOperationBase &self) { - return mlirOperationVerify(self.getOperation()); - }, - "Verify the operation and return true if it passes, false if it " - "fails.") + .def("verify", &PyOperationBase::verify, + "Verify the operation. Raises MLIRError if verification fails, and " + "returns true otherwise.") .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), "Puts self immediately after the other operation in its parent " "block.") @@ -2833,12 +2869,12 @@ void mlir::python::populateIRCore(py::module &m) { // directly. 
std::string clsOpName = py::cast(cls.attr("OPERATION_NAME")); - MlirStringRef parsedOpName = + MlirStringRef identifier = mlirIdentifierStr(mlirOperationGetName(*parsed.get())); - if (!mlirStringRefEqual(parsedOpName, toMlirStringRef(clsOpName))) - throw py::value_error( - "Expected a '" + clsOpName + "' op, got: '" + - std::string(parsedOpName.data, parsedOpName.length) + "'"); + std::string_view parsedOpName(identifier.data, identifier.length); + if (clsOpName != parsedOpName) + throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" + + parsedOpName + "'"); return PyOpView::constructDerived(cls, *parsed.get()); }, py::arg("cls"), py::arg("source"), py::kw_only(), @@ -3071,19 +3107,16 @@ void mlir::python::populateIRCore(py::module &m) { .def_static( "parse", [](std::string attrSpec, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); MlirAttribute type = mlirAttributeParseGet( context->get(), toMlirStringRef(attrSpec)); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirAttributeIsNull(type)) { - throw SetPyError(PyExc_ValueError, - Twine("Unable to parse attribute: '") + - attrSpec + "'"); - } + if (mlirAttributeIsNull(type)) + throw MLIRError("Unable to parse attribute", errors.take()); return PyAttribute(context->getRef(), type); }, py::arg("asm"), py::arg("context") = py::none(), - "Parses an attribute from an assembly form") + "Parses an attribute from an assembly form. 
Raises an MLIRError on " + "failure.") .def_property_readonly( "context", [](PyAttribute &self) { return self.getContext().getObject(); }, @@ -3182,15 +3215,11 @@ void mlir::python::populateIRCore(py::module &m) { .def_static( "parse", [](std::string typeSpec, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); MlirType type = mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(type)) { - throw SetPyError(PyExc_ValueError, - Twine("Unable to parse type: '") + typeSpec + - "'"); - } + if (mlirTypeIsNull(type)) + throw MLIRError("Unable to parse type", errors.take()); return PyType(context->getRef(), type); }, py::arg("asm"), py::arg("context") = py::none(), @@ -3342,4 +3371,17 @@ void mlir::python::populateIRCore(py::module &m) { // Attribute builder getter. PyAttrBuilderMap::bind(m); + + py::register_local_exception_translator([](std::exception_ptr p) { + // We can't define exceptions with custom fields through pybind, so instead + // the exception class is defined in python and imported here. + try { + if (p) + std::rethrow_exception(p); + } catch (const MLIRError &e) { + py::object obj = py::module_::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("MLIRError")(e.message, e.errorDiagnostics); + PyErr_SetObject(PyExc_Exception, obj.ptr()); + } + }); } diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 4aced3639..fc236b1c6 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -221,6 +221,11 @@ class PyMlirContext { /// registration object (internally a PyDiagnosticHandler). pybind11::object attachDiagnosticHandler(pybind11::object callback); + /// Controls whether error diagnostics should be propagated to diagnostic + /// handlers, instead of being captured by `ErrorCapture`. 
+ void setEmitErrorDiagnostics(bool value) { emitErrorDiagnostics = value; } + struct ErrorCapture; + private: PyMlirContext(MlirContext context); // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, @@ -248,6 +253,8 @@ class PyMlirContext { llvm::DenseMap>; LiveOperationMap liveOperations; + bool emitErrorDiagnostics = false; + MlirContext context; friend class PyModule; friend class PyOperation; @@ -281,6 +288,34 @@ class BaseContextObject { PyMlirContextRef contextRef; }; +/// Wrapper around an MlirLocation. +class PyLocation : public BaseContextObject { +public: + PyLocation(PyMlirContextRef contextRef, MlirLocation loc) + : BaseContextObject(std::move(contextRef)), loc(loc) {} + + operator MlirLocation() const { return loc; } + MlirLocation get() const { return loc; } + + /// Enter and exit the context manager. + pybind11::object contextEnter(); + void contextExit(const pybind11::object &excType, + const pybind11::object &excVal, + const pybind11::object &excTb); + + /// Gets a capsule wrapping the void* within the MlirLocation. + pybind11::object getCapsule(); + + /// Creates a PyLocation from the MlirLocation wrapped by a capsule. + /// Note that PyLocation instances are uniqued, so the returned object + /// may be a pre-existing object. Ownership of the underlying MlirLocation + /// is taken by calling this function. + static PyLocation createFromCapsule(pybind11::object capsule); + +private: + MlirLocation loc; +}; + /// Python class mirroring the C MlirDiagnostic struct. Note that these structs /// are only valid for the duration of a diagnostic callback and attempting /// to access them outside of that will raise an exception. This applies to @@ -295,6 +330,16 @@ class PyDiagnostic { pybind11::str getMessage(); pybind11::tuple getNotes(); + /// Materialized diagnostic information. This is safe to access outside the + /// diagnostic callback. 
+ struct DiagnosticInfo { + MlirDiagnosticSeverity severity; + PyLocation location; + std::string message; + std::vector notes; + }; + DiagnosticInfo getInfo(); + private: MlirDiagnostic diagnostic; @@ -351,6 +396,30 @@ class PyDiagnosticHandler { friend class PyMlirContext; }; +/// RAII object that captures any error diagnostics emitted to the provided +/// context. +struct PyMlirContext::ErrorCapture { + ErrorCapture(PyMlirContextRef ctx) + : ctx(ctx), handlerID(mlirContextAttachDiagnosticHandler( + ctx->get(), handler, /*userData=*/this, + /*deleteUserData=*/nullptr)) {} + ~ErrorCapture() { + mlirContextDetachDiagnosticHandler(ctx->get(), handlerID); + assert(errors.empty() && "unhandled captured errors"); + } + + std::vector take() { + return std::move(errors); + }; + +private: + PyMlirContextRef ctx; + MlirDiagnosticHandlerID handlerID; + std::vector errors; + + static MlirLogicalResult handler(MlirDiagnostic diag, void *userData); +}; + /// Wrapper around an MlirDialect. This is exported as `DialectDescriptor` in /// order to differentiate it from the `Dialect` base class which is extended by /// plugins which extend dialect functionality through extension python code. @@ -416,34 +485,6 @@ class PyDialectRegistry { MlirDialectRegistry registry; }; -/// Wrapper around an MlirLocation. -class PyLocation : public BaseContextObject { -public: - PyLocation(PyMlirContextRef contextRef, MlirLocation loc) - : BaseContextObject(std::move(contextRef)), loc(loc) {} - - operator MlirLocation() const { return loc; } - MlirLocation get() const { return loc; } - - /// Enter and exit the context manager. - pybind11::object contextEnter(); - void contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb); - - /// Gets a capsule wrapping the void* within the MlirLocation. - pybind11::object getCapsule(); - - /// Creates a PyLocation from the MlirLocation wrapped by a capsule. 
- /// Note that PyLocation instances are uniqued, so the returned object - /// may be a pre-existing object. Ownership of the underlying MlirLocation - /// is taken by calling this function. - static PyLocation createFromCapsule(pybind11::object capsule); - -private: - MlirLocation loc; -}; - /// Used in function arguments when None should resolve to the current context /// manager set instance. class DefaultingPyLocation @@ -519,6 +560,10 @@ class PyOperationBase { void moveAfter(PyOperationBase &other); void moveBefore(PyOperationBase &other); + /// Verify the operation. Throws `MLIRError` if verification fails, and + /// returns `true` otherwise. + bool verify(); + /// Each must provide access to the raw Operation. virtual PyOperation &getOperation() = 0; }; @@ -1073,6 +1118,16 @@ class PySymbolTable { MlirSymbolTable symbolTable; }; +/// Custom exception that allows access to error diagnostic information. This is +/// converted to the `ir.MLIRError` python exception when thrown. +struct MLIRError { + MLIRError(llvm::Twine message, + std::vector &&errorDiagnostics = {}) + : message(message.str()), errorDiagnostics(std::move(errorDiagnostics)) {} + std::string message; + std::vector errorDiagnostics; +}; + void populateIRAffine(pybind11::module &m); void populateIRAttributes(pybind11::module &m); void populateIRCore(pybind11::module &m); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 87ffe5936..2166bab90 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -407,17 +407,11 @@ class PyVectorType : public PyConcreteType { "get", [](std::vector shape, PyType &elementType, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), elementType); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. 
- if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point or integer type."); - } + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); return PyVectorType(elementType.getContext(), t); }, py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(), @@ -438,20 +432,12 @@ class PyRankedTensorType "get", [](std::vector shape, PyType &elementType, std::optional &encodingAttr, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); MlirType t = mlirRankedTensorTypeGetChecked( loc, shape.size(), shape.data(), elementType, encodingAttr ? encodingAttr->get() : mlirAttributeGetNull()); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); return PyRankedTensorType(elementType.getContext(), t); }, py::arg("shape"), py::arg("element_type"), @@ -479,18 +465,10 @@ class PyUnrankedTensorType c.def_static( "get", [](PyType &elementType, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. 
- if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); return PyUnrankedTensorType(elementType.getContext(), t); }, py::arg("element_type"), py::arg("loc") = py::none(), @@ -511,23 +489,15 @@ class PyMemRefType : public PyConcreteType { [](std::vector shape, PyType &elementType, PyAttribute *layout, PyAttribute *memorySpace, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull(); MlirAttribute memSpaceAttr = memorySpace ? *memorySpace : mlirAttributeGetNull(); MlirType t = mlirMemRefTypeGetChecked(loc, elementType, shape.size(), shape.data(), layoutAttr, memSpaceAttr); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. - if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); return PyMemRefType(elementType.getContext(), t); }, py::arg("shape"), py::arg("element_type"), @@ -570,23 +540,15 @@ class PyUnrankedMemRefType "get", [](PyType &elementType, PyAttribute *memorySpace, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); MlirAttribute memSpaceAttr = {}; if (memorySpace) memSpaceAttr = *memorySpace; MlirType t = mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); - // TODO: Rework error reporting once diagnostic engine is exposed - // in C API. 
- if (mlirTypeIsNull(t)) { - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point, integer, vector or " - "complex " - "type."); - } + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); return PyUnrankedMemRefType(elementType.getContext(), t); }, py::arg("element_type"), py::arg("memory_space"), diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 7e90d8be6..79c53084e 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -117,15 +117,16 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { .def( "run", [](PyPassManager &passManager, PyOperationBase &op) { + PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); MlirLogicalResult status = mlirPassManagerRunOnOp( passManager.get(), op.getOperation().get()); if (mlirLogicalResultIsFailure(status)) - throw SetPyError(PyExc_RuntimeError, - "Failure while executing pass pipeline."); + throw MLIRError("Failure while executing pass pipeline", + errors.take()); }, py::arg("operation"), - "Run the pass manager on the provided operation, throw a " - "RuntimeError on failure.") + "Run the pass manager on the provided operation, raising an " + "MLIRError on failure.") .def( "__str__", [](PyPassManager &self) { diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index 9ceeef818..7d3d1f6ca 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -100,8 +100,29 @@ def __init__(self, *args, **kwargs): # all dialects. It is being done here in order to preserve existing # behavior. See: https://github.com/llvm/llvm-project/issues/56037 self.load_all_available_dialects() - ir.Context = Context + class MLIRError(Exception): + """ + An exception with diagnostic information. 
Has the following fields: + message: str + error_diagnostics: List[ir.DiagnosticInfo] + """ + def __init__(self, message, error_diagnostics): + self.message = message + self.error_diagnostics = error_diagnostics + super().__init__(message, error_diagnostics) + + def __str__(self): + s = self.message + if self.error_diagnostics: + s += ':' + for diag in self.error_diagnostics: + s += "\nerror: " + str(diag.location)[4:-1] + ": " + diag.message.replace('\n', '\n ') + for note in diag.notes: + s += "\n note: " + str(note.location)[4:-1] + ": " + note.message.replace('\n', '\n ') + return s + ir.MLIRError = MLIRError + _site_initialize() From aed306cff031f78df3c7f33df72336abf443876e Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 8 Mar 2023 10:41:34 -0800 Subject: [PATCH 430/915] Make it possible to create DenseElementsAttrs with arbitrary shaped types in Python bindings Right now the bindings assume that all DenseElementsAttrs correspond to tensor values, making it impossible to create vector-typed constants. I didn't want to change the API significantly, so I opted for reusing the current signature of `.get`. Its `type` argument now accepts both element types (in which case `shape` and `signless` can be specified too), or a shaped type, which specifies the full type of the created attr (`shape` cannot be specified in that case). 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D145053 --- mlir/lib/Bindings/Python/IRAttributes.cpp | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index b0c35ffb8..c59a54b66 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -624,8 +624,17 @@ class PyDenseElementsAttribute } } if (bulkLoadElementType) { - auto shapedType = mlirRankedTensorTypeGet( - shape.size(), shape.data(), *bulkLoadElementType, encodingAttr); + MlirType shapedType; + if (mlirTypeIsAShaped(*bulkLoadElementType)) { + if (explicitShape) { + throw std::invalid_argument("Shape can only be specified explicitly " + "when the type is not a shaped type."); + } + shapedType = *bulkLoadElementType; + } else { + shapedType = mlirRankedTensorTypeGet( + shape.size(), shape.data(), *bulkLoadElementType, encodingAttr); + } size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize; MlirAttribute attr = mlirDenseElementsAttrRawBufferGet( shapedType, rawBufferSize, arrayInfo.ptr); From 7d05250180d03ffc98dac5dd797980155f4451a7 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Mon, 20 Mar 2023 02:02:23 -0700 Subject: [PATCH 431/915] [mlir][Linalg][Transform] Avoid FunctionalStyleTransformOpTrait where unnecesseary to improve usability Differential Revision: https://reviews.llvm.org/D146305 --- mlir/python/mlir/dialects/_structured_transform_ops_ext.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index e2c262ca5..f314496c6 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -325,11 +325,9 @@ def __init__(self, vectorize_padding: Union[bool, BoolAttr] = False, loc=None, ip=None): - pdl_operation_type = 
pdl.OperationType.get() if isinstance(vectorize_padding, bool): vectorize_padding = UnitAttr.get() super().__init__( - pdl_operation_type, _get_op_result_or_value(target), vectorize_padding=vectorize_padding, loc=loc, From 7ff3cabb9ebada7a69ce639c5f8f7f1026ca67ec Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Mon, 20 Mar 2023 07:06:57 -0700 Subject: [PATCH 432/915] Revert "[mlir][Linalg][Transform] Avoid FunctionalStyleTransformOpTrait where unnecesseary to improve usability" This reverts commit 7d05250180d03ffc98dac5dd797980155f4451a7. This is currently not in a good state as we have some footguns due to missing listeners. --- mlir/python/mlir/dialects/_structured_transform_ops_ext.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index f314496c6..e2c262ca5 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -325,9 +325,11 @@ def __init__(self, vectorize_padding: Union[bool, BoolAttr] = False, loc=None, ip=None): + pdl_operation_type = pdl.OperationType.get() if isinstance(vectorize_padding, bool): vectorize_padding = UnitAttr.get() super().__init__( + pdl_operation_type, _get_op_result_or_value(target), vectorize_padding=vectorize_padding, loc=loc, From dfef8e04255a4635c0096bde741c040cb6af3496 Mon Sep 17 00:00:00 2001 From: Sergio Afonso Date: Mon, 13 Mar 2023 12:51:18 +0000 Subject: [PATCH 433/915] [mlir] Support lowering of dialect attributes attached to top-level modules This patch supports the processing of dialect attributes attached to top-level module-type operations during MLIR-to-LLVMIR lowering. This approach modifies the `mlir::translateModuleToLLVMIR()` function to call `ModuleTranslation::convertOperation()` on the top-level operation, after its body has been lowered. 
This, in turn, will get the `LLVMTranslationDialectInterface` object associated to that operation's dialect before trying to use it for lowering prior to processing dialect attributes attached to the operation. Since there are no `LLVMTranslationDialectInterface`s for the builtin and GPU dialects, which define their own module-type operations, this patch also adds and registers them. The requirement for always calling `mlir::registerBuiltinDialectTranslation()` before any translation of MLIR to LLVM IR where builtin module operations are present is introduced. The purpose of these new translation interfaces is to succeed when processing module-type operations, allowing the lowering process to continue and to prevent the introduction of failures related to not finding such interfaces. Differential Revision: https://reviews.llvm.org/D145932 --- mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt | 1 + mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp | 5 ++++- mlir/lib/CAPI/RegisterEverything/CMakeLists.txt | 1 + mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp | 5 ++++- 4 files changed, 10 insertions(+), 2 deletions(-) diff --git a/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt b/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt index 105ce24dd..0be8f2af5 100644 --- a/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt +++ b/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt @@ -8,6 +8,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIExecutionEngine ExecutionEngine.cpp LINK_LIBS PUBLIC + MLIRBuiltinToLLVMIRTranslation MLIRExecutionEngine MLIRLLVMToLLVMIRTranslation ) diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index a832119ce..a0ea7f4ab 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -11,6 +11,7 @@ #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" #include "mlir/ExecutionEngine/OptUtils.h" +#include 
"mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" #include "llvm/ExecutionEngine/Orc/Mangling.h" #include "llvm/Support/TargetSelect.h" @@ -29,7 +30,9 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, }(); (void)initOnce; - mlir::registerLLVMDialectTranslation(*unwrap(op)->getContext()); + auto &ctx = *unwrap(op)->getContext(); + mlir::registerBuiltinDialectTranslation(ctx); + mlir::registerLLVMDialectTranslation(ctx); auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); if (!tmBuilderOrError) { diff --git a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt index 942bba84e..55fe49bce 100644 --- a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt +++ b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIRegisterEverything ${translation_libs} ${conversion_libs} + MLIRBuiltinToLLVMIRTranslation MLIRCAPIIR MLIRLLVMToLLVMIRTranslation MLIRCAPITransforms diff --git a/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp index 25a1a216c..e4a751643 100644 --- a/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp +++ b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp @@ -11,6 +11,7 @@ #include "mlir/CAPI/IR.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllPasses.h" +#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" void mlirRegisterAllDialects(MlirDialectRegistry registry) { @@ -18,7 +19,9 @@ void mlirRegisterAllDialects(MlirDialectRegistry registry) { } void mlirRegisterAllLLVMTranslations(MlirContext context) { - mlir::registerLLVMDialectTranslation(*unwrap(context)); + auto &ctx = *unwrap(context); + mlir::registerBuiltinDialectTranslation(ctx); + 
mlir::registerLLVMDialectTranslation(ctx); } void mlirRegisterAllPasses() { mlir::registerAllPasses(); } From 364239af646ac7a2e44a21d7e613004156848b64 Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Tue, 21 Mar 2023 08:26:06 -0700 Subject: [PATCH 434/915] Support retrieving the splat value from DenseElementsAttrs in Python This is especially convenient when trying to resize the splat. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D146510 --- mlir/lib/Bindings/Python/IRAttributes.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index c59a54b66..40598ecfd 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -777,6 +777,16 @@ class PyDenseElementsAttribute [](PyDenseElementsAttribute &self) -> bool { return mlirDenseElementsAttrIsSplat(self); }) + .def("get_splat_value", + [](PyDenseElementsAttribute &self) -> PyAttribute { + if (!mlirDenseElementsAttrIsSplat(self)) { + throw SetPyError( + PyExc_ValueError, + "get_splat_value called on a non-splat attribute"); + } + return PyAttribute(self.getContext(), + mlirDenseElementsAttrGetSplatValue(self)); + }) .def_buffer(&PyDenseElementsAttribute::accessBuffer); } From 0c906a9f9d6bdf369c9d3485f567da80d950c1ca Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Thu, 9 Mar 2023 23:10:57 +0000 Subject: [PATCH 435/915] [APFloat] Add E4M3B11FNUZ X. Sun et al. (https://dl.acm.org/doi/10.5555/3454287.3454728) published a paper showing that an FP format with 4 bits of exponent, 3 bits of significand and an exponent bias of 11 would work quite well for ML applications. Google hardware supports a variant of this format where 0x80 is used to represent NaN, as in the Float8E4M3FNUZ format. Just like the Float8E4M3FNUZ format, this format does not support -0 and values which would map to it will become +0. 
This format is proposed for inclusion in OpenXLA's StableHLO dialect: https://github.com/openxla/stablehlo/pull/1308 As part of inclusion in that dialect, APFloat needs to know how to handle this format. Differential Revision: https://reviews.llvm.org/D146441 --- mlir/include/mlir-c/BuiltinTypes.h | 7 +++++++ mlir/lib/Bindings/Python/IRTypes.cpp | 18 ++++++++++++++++++ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 8 ++++++++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 8 ++++++++ 4 files changed, 41 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index 8b855d8c3..2b7606f3d 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -95,6 +95,13 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type); /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx); +/// Checks whether the given type is an f8E4M3B11FNUZ type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type); + +/// Creates an f8E4M3B11FNUZ type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx); + /// Checks whether the given type is a bf16 type. MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 2166bab90..6d381b1c0 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -157,6 +157,24 @@ class PyFloat8E4M3FNUZType : public PyConcreteType { } }; +/// Floating Point Type subclass - Float8E4M3B11FNUZ. 
+class PyFloat8E4M3B11FNUZType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; + static constexpr const char *pyClassName = "Float8E4M3B11FNUZType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); + return PyFloat8E4M3B11FNUZType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e4m3b11fnuz type."); + } +}; + /// Floating Point Type subclass - Float8E5M2FNUZ. class PyFloat8E5M2FNUZType : public PyConcreteType { public: diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index aea122120..2468c0546 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -100,6 +100,14 @@ MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) { return wrap(FloatType::getFloat8E4M3FNUZ(unwrap(ctx))); } +bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) { + return unwrap(type).isFloat8E4M3B11FNUZ(); +} + +MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) { + return wrap(FloatType::getFloat8E4M3B11FNUZ(unwrap(ctx))); +} + bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); } MlirType mlirBF16TypeGet(MlirContext ctx) { diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 7d5ff23f6..75b25bd8c 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -53,6 +53,7 @@ __all__ = [ "Float8E4M3FNType", "Float8E5M2Type", "Float8E4M3FNUZType", + "Float8E4M3B11FNUZType", "Float8E5M2FNUZType", "F16Type", "F32Type", @@ -602,6 +603,13 @@ class Float8E4M3FNUZType(Type): @staticmethod def isinstance(arg: Any) -> bool: ... +class Float8E4M3B11FNUZType(Type): + def __init__(self, cast_from_type: Type) -> None: ... 
+ @staticmethod + def get(*args, **kwargs) -> Float8E4M3B11FNUZType: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... + class Float8E5M2FNUZType(Type): def __init__(self, cast_from_type: Type) -> None: ... @staticmethod From 78b77c2d8a853475c2aab1b5383a93aa6daab3de Mon Sep 17 00:00:00 2001 From: David Majnemer Date: Fri, 24 Mar 2023 20:42:47 +0000 Subject: [PATCH 436/915] Fix mlir/lib/Bindings/Python/IRTypes.cpp for Float8E4M3B11FNUZType --- mlir/lib/Bindings/Python/IRTypes.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 6d381b1c0..cb62a402d 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -717,6 +717,7 @@ void mlir::python::populateIRTypes(py::module &m) { PyFloat8E4M3FNType::bind(m); PyFloat8E5M2Type::bind(m); PyFloat8E4M3FNUZType::bind(m); + PyFloat8E4M3B11FNUZType::bind(m); PyFloat8E5M2FNUZType::bind(m); PyBF16Type::bind(m); PyF16Type::bind(m); From 99776490034a802127609973979b797a04c72c9e Mon Sep 17 00:00:00 2001 From: Lang Hames Date: Mon, 27 Mar 2023 19:16:45 -0700 Subject: [PATCH 437/915] [mlir] Update JitRunner, ExecutionEngine after LLVM commit 8b1771bd9f3. LLVM commit 8b1771bd9f3 replaced JITEvaluatedSymbol with ExecutorSymbolDef. 
--- mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index a0ea7f4ab..1075ec460 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -104,7 +104,8 @@ extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit, unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) { llvm::orc::SymbolMap symbolMap; symbolMap[interner(unwrap(name))] = - llvm::JITEvaluatedSymbol::fromPointer(sym); + { llvm::orc::ExecutorAddr::fromPtr(sym), + llvm::JITSymbolFlags::Exported }; return symbolMap; }); } From c891a71e246240d7c37d0ce72c19319a5b4d9d32 Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Tue, 28 Mar 2023 11:05:00 -0400 Subject: [PATCH 438/915] [mlir][python] Mark operator== overloads as const This resolves some warnings when building with C++20, e.g.: ``` llvm-project/mlir/lib/Bindings/Python/IRAffine.cpp:545:60: warning: ISO C++20 considers use of overloaded operator '==' (with operand types 'mlir::python::PyAffineExpr' and 'mlir::python::PyAffineExpr') to be ambiguous despite there being a unique best viable function [-Wambiguous-reversed-operator] PyAffineExpr &other) { return self == other; }) ~~~~ ^ ~~~~~ llvm-project/mlir/lib/Bindings/Python/IRAffine.cpp:350:20: note: ambiguity is between a regular call to this operator and a call with the argument order reversed bool PyAffineExpr::operator==(const PyAffineExpr &other) { ^ ``` Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D147018 --- mlir/lib/Bindings/Python/IRAffine.cpp | 6 +++--- mlir/lib/Bindings/Python/IRCore.cpp | 4 ++-- mlir/lib/Bindings/Python/IRModule.h | 10 +++++----- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index 
09f36b07c..9a2ea6b68 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -347,7 +347,7 @@ class PyAffineCeilDivExpr } // namespace -bool PyAffineExpr::operator==(const PyAffineExpr &other) { +bool PyAffineExpr::operator==(const PyAffineExpr &other) const { return mlirAffineExprEqual(affineExpr, other.affineExpr); } @@ -406,7 +406,7 @@ class PyAffineMapExprList }; } // namespace -bool PyAffineMap::operator==(const PyAffineMap &other) { +bool PyAffineMap::operator==(const PyAffineMap &other) const { return mlirAffineMapEqual(affineMap, other.affineMap); } @@ -483,7 +483,7 @@ class PyIntegerSetConstraintList }; } // namespace -bool PyIntegerSet::operator==(const PyIntegerSet &other) { +bool PyIntegerSet::operator==(const PyIntegerSet &other) const { return mlirIntegerSetEqual(integerSet, other.integerSet); } diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 8d637ea2b..f2d3780a4 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1736,7 +1736,7 @@ void PyInsertionPoint::contextExit(const pybind11::object &excType, // PyAttribute. //------------------------------------------------------------------------------ -bool PyAttribute::operator==(const PyAttribute &other) { +bool PyAttribute::operator==(const PyAttribute &other) const { return mlirAttributeEqual(attr, other.attr); } @@ -1768,7 +1768,7 @@ PyNamedAttribute::PyNamedAttribute(MlirAttribute attr, std::string ownedName) // PyType. 
//------------------------------------------------------------------------------ -bool PyType::operator==(const PyType &other) { +bool PyType::operator==(const PyType &other) const { return mlirTypeEqual(type, other.type); } diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index fc236b1c6..920e6f467 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -808,7 +808,7 @@ class PyType : public BaseContextObject { public: PyType(PyMlirContextRef contextRef, MlirType type) : BaseContextObject(std::move(contextRef)), type(type) {} - bool operator==(const PyType &other); + bool operator==(const PyType &other) const; operator MlirType() const { return type; } MlirType get() const { return type; } @@ -878,7 +878,7 @@ class PyAttribute : public BaseContextObject { public: PyAttribute(PyMlirContextRef contextRef, MlirAttribute attr) : BaseContextObject(std::move(contextRef)), attr(attr) {} - bool operator==(const PyAttribute &other); + bool operator==(const PyAttribute &other) const; operator MlirAttribute() const { return attr; } MlirAttribute get() const { return attr; } @@ -1003,7 +1003,7 @@ class PyAffineExpr : public BaseContextObject { public: PyAffineExpr(PyMlirContextRef contextRef, MlirAffineExpr affineExpr) : BaseContextObject(std::move(contextRef)), affineExpr(affineExpr) {} - bool operator==(const PyAffineExpr &other); + bool operator==(const PyAffineExpr &other) const; operator MlirAffineExpr() const { return affineExpr; } MlirAffineExpr get() const { return affineExpr; } @@ -1030,7 +1030,7 @@ class PyAffineMap : public BaseContextObject { public: PyAffineMap(PyMlirContextRef contextRef, MlirAffineMap affineMap) : BaseContextObject(std::move(contextRef)), affineMap(affineMap) {} - bool operator==(const PyAffineMap &other); + bool operator==(const PyAffineMap &other) const; operator MlirAffineMap() const { return affineMap; } MlirAffineMap get() const { return affineMap; } @@ -1051,7 
+1051,7 @@ class PyIntegerSet : public BaseContextObject { public: PyIntegerSet(PyMlirContextRef contextRef, MlirIntegerSet integerSet) : BaseContextObject(std::move(contextRef)), integerSet(integerSet) {} - bool operator==(const PyIntegerSet &other); + bool operator==(const PyIntegerSet &other) const; operator MlirIntegerSet() const { return integerSet; } MlirIntegerSet get() const { return integerSet; } From 3d0a306d65214cc9da92dd6d2348a38b3ab68901 Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Wed, 29 Mar 2023 18:49:08 -0400 Subject: [PATCH 439/915] [mlir][python] Support buffer protocol for splat dense attributes These can be made to work by setting the buffer strides to 0. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D147187 --- mlir/lib/Bindings/Python/IRAttributes.cpp | 24 ++++++++++------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 40598ecfd..d252044c8 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -688,13 +688,6 @@ class PyDenseElementsAttribute intptr_t dunderLen() { return mlirElementsAttrGetNumElements(*this); } py::buffer_info accessBuffer() { - if (mlirDenseElementsAttrIsSplat(*this)) { - // TODO: Currently crashes the program. - // Reported as https://github.com/pybind/pybind11/issues/3336 - throw std::invalid_argument( - "unsupported data type for conversion to Python buffer"); - } - MlirType shapedType = mlirAttributeGetType(*this); MlirType elementType = mlirShapedTypeGetElementType(shapedType); std::string format; @@ -821,15 +814,18 @@ class PyDenseElementsAttribute shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); // Prepare the strides for the buffer_info. 
SmallVector strides; - intptr_t strideFactor = 1; - for (intptr_t i = 1; i < rank; ++i) { - strideFactor = 1; - for (intptr_t j = i; j < rank; ++j) { - strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); + if (mlirDenseElementsAttrIsSplat(*this)) { + // Splats are special, only the single value is stored. + strides.assign(rank, 0); + } else { + for (intptr_t i = 1; i < rank; ++i) { + intptr_t strideFactor = 1; + for (intptr_t j = i; j < rank; ++j) + strideFactor *= mlirShapedTypeGetDimSize(shapedType, j); + strides.push_back(sizeof(Type) * strideFactor); } - strides.push_back(sizeof(Type) * strideFactor); + strides.push_back(sizeof(Type)); } - strides.push_back(sizeof(Type)); std::string format; if (explicitFormat) { format = explicitFormat; From 3efab817e738d14dbc0515b30f0bdfc60eb8ec59 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Thu, 13 Apr 2023 17:05:26 +0200 Subject: [PATCH 440/915] Use `bytes`, not `str`, to return C++ strings to Python. `str` must be valid UTF-8, which is not guaranteed for C++ strings. 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D147818 --- mlir/lib/Bindings/Python/IRAttributes.cpp | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index d252044c8..4a43ffbc0 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -490,9 +490,9 @@ class PyOpaqueAttribute : public PyConcreteAttribute { "data", [](PyOpaqueAttribute &self) { MlirStringRef stringRef = mlirOpaqueAttrGetData(self); - return py::str(stringRef.data, stringRef.length); + return py::bytes(stringRef.data, stringRef.length); }, - "Returns the data for the Opaqued attributes as a string"); + "Returns the data for the Opaqued attributes as `bytes`"); } }; @@ -528,6 +528,13 @@ class PyStringAttribute : public PyConcreteAttribute { return py::str(stringRef.data, stringRef.length); }, "Returns the value of the string attribute"); + c.def_property_readonly( + "value_bytes", + [](PyStringAttribute &self) { + MlirStringRef stringRef = mlirStringAttrGetValue(self); + return py::bytes(stringRef.data, stringRef.length); + }, + "Returns the value of the string attribute as `bytes`"); } }; From 2d7501bd828a9db92ce7a27160f3d3a331933f83 Mon Sep 17 00:00:00 2001 From: max Date: Fri, 14 Apr 2023 14:20:33 -0500 Subject: [PATCH 441/915] [MLIR][python bindings] implement `PyValue` subclassing to enable operator overloading Differential Revision: https://reviews.llvm.org/D147758 --- .../mlir/Bindings/Python/PybindAdaptors.h | 56 +++++++++++++++++++ mlir/lib/Bindings/Python/IRCore.cpp | 1 + mlir/python/mlir/dialects/python_test.py | 2 +- 3 files changed, 58 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 98d80f010..bec3fc76e 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ 
b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -453,6 +453,62 @@ class mlir_type_subclass : public pure_subclass { } }; +/// Creates a custom subclass of mlir.ir.Value, implementing a casting +/// constructor and type checking methods. +class mlir_value_subclass : public pure_subclass { +public: + using IsAFunctionTy = bool (*)(MlirValue); + + /// Subclasses by looking up the super-class dynamically. + mlir_value_subclass(py::handle scope, const char *valueClassName, + IsAFunctionTy isaFunction) + : mlir_value_subclass( + scope, valueClassName, isaFunction, + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Value")) { + } + + /// Subclasses with a provided mlir.ir.Value super-class. This must + /// be used if the subclass is being defined in the same extension module + /// as the mlir.ir class (otherwise, it will trigger a recursive + /// initialization). + mlir_value_subclass(py::handle scope, const char *valueClassName, + IsAFunctionTy isaFunction, const py::object &superCls) + : pure_subclass(scope, valueClassName, superCls) { + // Casting constructor. Note that it hard, if not impossible, to properly + // call chain to parent `__init__` in pybind11 due to its special handling + // for init functions that don't have a fully constructed self-reference, + // which makes it impossible to forward it to `__init__` of a superclass. + // Instead, provide a custom `__new__` and call that of a superclass, which + // eventually calls `__init__` of the superclass. Since attribute subclasses + // have no additional members, we can just return the instance thus created + // without amending it. + std::string captureValueName( + valueClassName); // As string in case if valueClassName is not static. 
+ py::cpp_function newCf( + [superCls, isaFunction, captureValueName](py::object cls, + py::object otherValue) { + MlirValue rawValue = py::cast(otherValue); + if (!isaFunction(rawValue)) { + auto origRepr = py::repr(otherValue).cast(); + throw std::invalid_argument((llvm::Twine("Cannot cast value to ") + + captureValueName + " (from " + + origRepr + ")") + .str()); + } + py::object self = superCls.attr("__new__")(cls, otherValue); + return self; + }, + py::name("__new__"), py::arg("cls"), py::arg("cast_from_value")); + thisClass.attr("__new__") = newCf; + + // 'isinstance' method. + def_staticmethod( + "isinstance", + [isaFunction](MlirValue other) { return isaFunction(other); }, + py::arg("other_value")); + } +}; + } // namespace adaptors } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index f2d3780a4..f3fd38677 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3260,6 +3260,7 @@ void mlir::python::populateIRCore(py::module &m) { // Mapping of Value. 
//---------------------------------------------------------------------------- py::class_(m, "Value", py::module_local()) + .def(py::init(), py::keep_alive<0, 1>(), py::arg("value")) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) .def_property_readonly( diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py index 9f560c205..5d42ddc47 100644 --- a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._python_test_ops_gen import * -from .._mlir_libs._mlirPythonTest import TestAttr, TestType +from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue def register_python_test_dialect(context, load=True): from .._mlir_libs import _mlirPythonTest From f678bf825b7590fc49bddccae5d91a3c2b50fd2a Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Mon, 17 Apr 2023 13:15:45 +0200 Subject: [PATCH 442/915] [mlir] Remove unused using llvm::Twine declaration (NFC). --- mlir/lib/Bindings/Python/IRAttributes.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 4a43ffbc0..5e7138b21 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -21,7 +21,6 @@ using namespace mlir; using namespace mlir::python; using llvm::SmallVector; -using llvm::Twine; //------------------------------------------------------------------------------ // Docstrings (trivial, non-duplicated docstrings are included inline). 
From 265cec0f83a31e02ded92458669058c245becb7d Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Tue, 18 Apr 2023 21:38:49 +0000 Subject: [PATCH 443/915] [mlir][sparse] introduce a new compressed(hi) dimension level type `compressed(hi)` is similar to `compressed`, but instead of reusing the previous position high as the current position low, it uses a pair of positions for each sparse index. The patch only introduces the definition (syntax) but does not provide codegen implementation. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D148664 --- mlir/include/mlir-c/Dialect/SparseTensor.h | 22 +++++++++++-------- .../Bindings/Python/DialectSparseTensor.cpp | 9 +++++++- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index 7d560dd80..8a6763b6c 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -26,15 +26,19 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor); /// If updating, keep them in sync and update the static_assert in the impl /// file. 
enum MlirSparseTensorDimLevelType { - MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE = 4, // 0b001_00 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED = 8, // 0b010_00 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU = 9, // 0b010_01 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO = 10, // 0b010_10 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO = 11, // 0b010_11 - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON = 16, // 0b100_00 - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU = 17, // 0b100_01 - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO = 18, // 0b100_10 - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO = 19, // 0b100_11 + MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE = 4, // 0b0001_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED = 8, // 0b0010_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU = 9, // 0b0010_01 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO = 10, // 0b0010_10 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO = 11, // 0b0010_11 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON = 16, // 0b0100_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU = 17, // 0b0100_01 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO = 18, // 0b0100_10 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO = 19, // 0b0100_11 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI = 32, // 0b1000_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU = 33, // 0b1000_01 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NO = 34, // 0b1000_10 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU_NO = 35, // 0b1000_11 }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index e84937df8..0e07f2563 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -26,7 +26,14 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) .value("singleton-nu", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU) 
.value("singleton-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO) - .value("singleton-nu-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO); + .value("singleton-nu-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO) + .value("compressed-hi", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI) + .value("compressed-hi-nu", + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU) + .value("compressed-hi-no", + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NO) + .value("compressed-hi-nu-no", + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU_NO); mlir_attribute_subclass(m, "EncodingAttr", mlirAttributeIsASparseTensorEncodingAttr) From a6319caabdb7355b3df927a436f9cf3bd091b43c Mon Sep 17 00:00:00 2001 From: max Date: Mon, 24 Apr 2023 10:08:11 -0500 Subject: [PATCH 444/915] [MLIR][python bindings] implement `replace_all_uses_with` on `PyValue` Differential Revision: https://reviews.llvm.org/D148816 --- mlir/include/mlir-c/IR.h | 6 ++++++ mlir/lib/Bindings/Python/IRCore.cpp | 23 +++++++++++++++++------ mlir/lib/CAPI/IR/IR.cpp | 4 ++++ 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 84d226b40..b45b95536 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -755,6 +755,12 @@ mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData); /// operand if there are no uses. MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value); +/// Replace all uses of 'of' value with the 'with' value, updating anything in +/// the IR that uses 'of' to use the other value instead. When this returns +/// there are zero uses of 'of'. +MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of, + MlirValue with); + //===----------------------------------------------------------------------===// // OpOperand API. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index f3fd38677..81c5cd218 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -13,11 +13,9 @@ #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/BuiltinAttributes.h" -#include "mlir-c/BuiltinTypes.h" #include "mlir-c/Debug.h" #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" -//#include "mlir-c/Registration.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" @@ -154,6 +152,11 @@ position in the argument list. If the value is an operation result, this is equivalent to printing the operation that produced it. )"; +static const char kValueReplaceAllUsesWithDocstring[] = + R"(Replace all uses of value with the new value, updating anything in +the IR that uses 'self' to use the other value instead. +)"; + //------------------------------------------------------------------------------ // Utilities. 
//------------------------------------------------------------------------------ @@ -3316,10 +3319,18 @@ void mlir::python::populateIRCore(py::module &m) { return printAccum.join(); }, kValueDunderStrDocstring) - .def_property_readonly("type", [](PyValue &self) { - return PyType(self.getParentOperation()->getContext(), - mlirValueGetType(self.get())); - }); + .def_property_readonly("type", + [](PyValue &self) { + return PyType( + self.getParentOperation()->getContext(), + mlirValueGetType(self.get())); + }) + .def( + "replace_all_uses_with", + [](PyValue &self, PyValue &with) { + mlirValueReplaceAllUsesOfWith(self.get(), with.get()); + }, + kValueReplaceAllUsesWithDocstring); PyBlockArgument::bind(m); PyOpResult::bind(m); PyOpOperand::bind(m); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 051559acd..0bbcb3083 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -751,6 +751,10 @@ MlirOpOperand mlirValueGetFirstUse(MlirValue value) { return wrap(opOperand); } +void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) { + unwrap(oldValue).replaceAllUsesWith(unwrap(newValue)); +} + //===----------------------------------------------------------------------===// // OpOperand API. //===----------------------------------------------------------------------===// From c27a4a1b6d7e0cdb47af39828d6e74fa754ed933 Mon Sep 17 00:00:00 2001 From: max Date: Tue, 25 Apr 2023 15:32:14 -0500 Subject: [PATCH 445/915] Revert "[MLIR][python bindings] implement `replace_all_uses_with` on `PyValue`" This reverts commit 3bab7cb089d92cc7025ebc57ef3a74d3ce94ecd8 because it breaks sanitizers. 
Differential Revision: https://reviews.llvm.org/D149188 --- mlir/include/mlir-c/IR.h | 6 ------ mlir/lib/Bindings/Python/IRCore.cpp | 23 ++++++----------------- mlir/lib/CAPI/IR/IR.cpp | 4 ---- 3 files changed, 6 insertions(+), 27 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index b45b95536..84d226b40 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -755,12 +755,6 @@ mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData); /// operand if there are no uses. MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value); -/// Replace all uses of 'of' value with the 'with' value, updating anything in -/// the IR that uses 'of' to use the other value instead. When this returns -/// there are zero uses of 'of'. -MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of, - MlirValue with); - //===----------------------------------------------------------------------===// // OpOperand API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 81c5cd218..f3fd38677 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -13,9 +13,11 @@ #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/BuiltinAttributes.h" +#include "mlir-c/BuiltinTypes.h" #include "mlir-c/Debug.h" #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" +//#include "mlir-c/Registration.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" @@ -152,11 +154,6 @@ position in the argument list. If the value is an operation result, this is equivalent to printing the operation that produced it. )"; -static const char kValueReplaceAllUsesWithDocstring[] = - R"(Replace all uses of value with the new value, updating anything in -the IR that uses 'self' to use the other value instead. 
-)"; - //------------------------------------------------------------------------------ // Utilities. //------------------------------------------------------------------------------ @@ -3319,18 +3316,10 @@ void mlir::python::populateIRCore(py::module &m) { return printAccum.join(); }, kValueDunderStrDocstring) - .def_property_readonly("type", - [](PyValue &self) { - return PyType( - self.getParentOperation()->getContext(), - mlirValueGetType(self.get())); - }) - .def( - "replace_all_uses_with", - [](PyValue &self, PyValue &with) { - mlirValueReplaceAllUsesOfWith(self.get(), with.get()); - }, - kValueReplaceAllUsesWithDocstring); + .def_property_readonly("type", [](PyValue &self) { + return PyType(self.getParentOperation()->getContext(), + mlirValueGetType(self.get())); + }); PyBlockArgument::bind(m); PyOpResult::bind(m); PyOpOperand::bind(m); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 0bbcb3083..051559acd 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -751,10 +751,6 @@ MlirOpOperand mlirValueGetFirstUse(MlirValue value) { return wrap(opOperand); } -void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) { - unwrap(oldValue).replaceAllUsesWith(unwrap(newValue)); -} - //===----------------------------------------------------------------------===// // OpOperand API. 
//===----------------------------------------------------------------------===// From 5c4ad5c9a932909c5022881c56d89689ae597cf7 Mon Sep 17 00:00:00 2001 From: max Date: Wed, 26 Apr 2023 09:55:27 -0500 Subject: [PATCH 446/915] [MLIR][python bindings] Reimplement `replace_all_uses_with` on `PyValue` Differential Revision: https://reviews.llvm.org/D149261 --- mlir/include/mlir-c/IR.h | 6 ++++++ mlir/lib/Bindings/Python/IRCore.cpp | 23 +++++++++++++++++------ mlir/lib/CAPI/IR/IR.cpp | 4 ++++ 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 84d226b40..b45b95536 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -755,6 +755,12 @@ mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData); /// operand if there are no uses. MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value); +/// Replace all uses of 'of' value with the 'with' value, updating anything in +/// the IR that uses 'of' to use the other value instead. When this returns +/// there are zero uses of 'of'. +MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of, + MlirValue with); + //===----------------------------------------------------------------------===// // OpOperand API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index f3fd38677..81c5cd218 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -13,11 +13,9 @@ #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/BuiltinAttributes.h" -#include "mlir-c/BuiltinTypes.h" #include "mlir-c/Debug.h" #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" -//#include "mlir-c/Registration.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" @@ -154,6 +152,11 @@ position in the argument list. 
If the value is an operation result, this is equivalent to printing the operation that produced it. )"; +static const char kValueReplaceAllUsesWithDocstring[] = + R"(Replace all uses of value with the new value, updating anything in +the IR that uses 'self' to use the other value instead. +)"; + //------------------------------------------------------------------------------ // Utilities. //------------------------------------------------------------------------------ @@ -3316,10 +3319,18 @@ void mlir::python::populateIRCore(py::module &m) { return printAccum.join(); }, kValueDunderStrDocstring) - .def_property_readonly("type", [](PyValue &self) { - return PyType(self.getParentOperation()->getContext(), - mlirValueGetType(self.get())); - }); + .def_property_readonly("type", + [](PyValue &self) { + return PyType( + self.getParentOperation()->getContext(), + mlirValueGetType(self.get())); + }) + .def( + "replace_all_uses_with", + [](PyValue &self, PyValue &with) { + mlirValueReplaceAllUsesOfWith(self.get(), with.get()); + }, + kValueReplaceAllUsesWithDocstring); PyBlockArgument::bind(m); PyOpResult::bind(m); PyOpOperand::bind(m); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 051559acd..0bbcb3083 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -751,6 +751,10 @@ MlirOpOperand mlirValueGetFirstUse(MlirValue value) { return wrap(opOperand); } +void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) { + unwrap(oldValue).replaceAllUsesWith(unwrap(newValue)); +} + //===----------------------------------------------------------------------===// // OpOperand API. //===----------------------------------------------------------------------===// From ad382a3825be0170c74c741ef2b17fa9d89a8531 Mon Sep 17 00:00:00 2001 From: max Date: Wed, 26 Apr 2023 15:27:07 -0500 Subject: [PATCH 447/915] [MLIR][python bindings] Add some AttrBuilder and port _exts to use them. 
Differential Revision: https://reviews.llvm.org/D149287 --- .../mlir/dialects/_loop_transform_ops_ext.py | 121 +++---- mlir/python/mlir/dialects/_pdl_ops_ext.py | 277 ++++++++-------- .../dialects/_structured_transform_ops_ext.py | 309 +++++++++--------- .../mlir/dialects/_transform_ops_ext.py | 148 +++++---- mlir/python/mlir/ir.py | 57 +++- 5 files changed, 488 insertions(+), 424 deletions(-) diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py index 0dc8fc074..a275ea615 100644 --- a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py @@ -11,65 +11,63 @@ from typing import Optional, Union -def _get_int64_attr(arg: Optional[Union[int, IntegerAttr]], - default_value: int = None): - if isinstance(arg, IntegerAttr): - return arg - - if arg is None: - assert default_value is not None, "must provide default value" - arg = default_value - - return IntegerAttr.get(IntegerType.get_signless(64), arg) - - class GetParentForOp: """Extension for GetParentForOp.""" - def __init__(self, - result_type: Type, - target: Union[Operation, Value], - *, - num_loops: int = 1, - ip=None, - loc=None): + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + num_loops: Optional[int] = None, + ip=None, + loc=None, + ): + if num_loops is None: + num_loops = 1 super().__init__( result_type, _get_op_result_or_value(target), - num_loops=_get_int64_attr(num_loops, default_value=1), + num_loops=num_loops, ip=ip, - loc=loc) + loc=loc, + ) class LoopOutlineOp: """Extension for LoopOutlineOp.""" - def __init__(self, - result_type: Type, - target: Union[Operation, Value], - *, - func_name: Union[str, StringAttr], - ip=None, - loc=None): + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + func_name: Union[str, StringAttr], + ip=None, + loc=None, + ): super().__init__( result_type, 
_get_op_result_or_value(target), func_name=(func_name if isinstance(func_name, StringAttr) else StringAttr.get(func_name)), ip=ip, - loc=loc) + loc=loc, + ) class LoopPeelOp: """Extension for LoopPeelOp.""" - def __init__(self, - result_type: Type, - target: Union[Operation, Value], - *, - fail_if_already_divisible: Union[bool, BoolAttr] = False, - ip=None, - loc=None): + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + fail_if_already_divisible: Union[bool, BoolAttr] = False, + ip=None, + loc=None, + ): super().__init__( result_type, _get_op_result_or_value(target), @@ -77,40 +75,51 @@ def __init__(self, fail_if_already_divisible, BoolAttr) else BoolAttr.get(fail_if_already_divisible)), ip=ip, - loc=loc) + loc=loc, + ) class LoopPipelineOp: """Extension for LoopPipelineOp.""" - def __init__(self, - result_type: Type, - target: Union[Operation, Value], - *, - iteration_interval: Optional[Union[int, IntegerAttr]] = None, - read_latency: Optional[Union[int, IntegerAttr]] = None, - ip=None, - loc=None): + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + iteration_interval: Optional[Union[int, IntegerAttr]] = None, + read_latency: Optional[Union[int, IntegerAttr]] = None, + ip=None, + loc=None, + ): + if iteration_interval is None: + iteration_interval = 1 + if read_latency is None: + read_latency = 10 super().__init__( result_type, _get_op_result_or_value(target), - iteration_interval=_get_int64_attr(iteration_interval, default_value=1), - read_latency=_get_int64_attr(read_latency, default_value=10), + iteration_interval=iteration_interval, + read_latency=read_latency, ip=ip, - loc=loc) + loc=loc, + ) class LoopUnrollOp: """Extension for LoopUnrollOp.""" - def __init__(self, - target: Union[Operation, Value], - *, - factor: Union[int, IntegerAttr], - ip=None, - loc=None): + def __init__( + self, + target: Union[Operation, Value], + *, + factor: Union[int, IntegerAttr], + ip=None, + loc=None, + ): 
super().__init__( _get_op_result_or_value(target), - factor=_get_int64_attr(factor), + factor=factor, ip=ip, - loc=loc) + loc=loc, + ) diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py index 428301b18..40ccbef63 100644 --- a/mlir/python/mlir/dialects/_pdl_ops_ext.py +++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py @@ -8,61 +8,26 @@ except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e -from typing import Union, Optional, Sequence, List, Mapping -from ._ods_common import get_op_result_or_value as _get_value, get_op_results_or_values as _get_values - - -def _get_int_attr(bits: int, value: Union[IntegerAttr, int]) -> IntegerAttr: - """Converts the given value to signless integer attribute of given bit width.""" - if isinstance(value, int): - ty = IntegerType.get_signless(bits) - return IntegerAttr.get(ty, value) - else: - return value - - -def _get_array_attr(attrs: Union[ArrayAttr, Sequence[Attribute]]) -> ArrayAttr: - """Converts the given value to array attribute.""" - if isinstance(attrs, ArrayAttr): - return attrs - else: - return ArrayAttr.get(list(attrs)) - - -def _get_str_array_attr(attrs: Union[ArrayAttr, Sequence[str]]) -> ArrayAttr: - """Converts the given value to string array attribute.""" - if isinstance(attrs, ArrayAttr): - return attrs - else: - return ArrayAttr.get([StringAttr.get(s) for s in attrs]) - - -def _get_str_attr(name: Union[StringAttr, str]) -> Optional[StringAttr]: - """Converts the given value to string attribute.""" - if isinstance(name, str): - return StringAttr.get(name) - else: - return name - - -def _get_type_attr(type: Union[TypeAttr, Type]) -> TypeAttr: - """Converts the given value to type attribute.""" - if isinstance(type, Type): - return TypeAttr.get(type) - else: - return type +from typing import Union, Optional, Sequence, Mapping +from ._ods_common import ( + get_op_result_or_value as _get_value, + get_op_results_or_values as 
_get_values, +) class ApplyNativeConstraintOp: """Specialization for PDL apply native constraint op class.""" - def __init__(self, - name: Union[str, StringAttr], - args: Sequence[Union[OpView, Operation, Value]] = [], - *, - loc=None, - ip=None): - name = _get_str_attr(name) + def __init__( + self, + name: Union[str, StringAttr], + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if args is None: + args = [] args = _get_values(args) super().__init__(name, args, loc=loc, ip=ip) @@ -70,14 +35,17 @@ def __init__(self, class ApplyNativeRewriteOp: """Specialization for PDL apply native rewrite op class.""" - def __init__(self, - results: Sequence[Type], - name: Union[str, StringAttr], - args: Sequence[Union[OpView, Operation, Value]] = [], - *, - loc=None, - ip=None): - name = _get_str_attr(name) + def __init__( + self, + results: Sequence[Type], + name: Union[str, StringAttr], + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if args is None: + args = [] args = _get_values(args) super().__init__(results, name, args, loc=loc, ip=ip) @@ -85,12 +53,14 @@ def __init__(self, class AttributeOp: """Specialization for PDL attribute op class.""" - def __init__(self, - valueType: Optional[Union[OpView, Operation, Value]] = None, - value: Optional[Attribute] = None, - *, - loc=None, - ip=None): + def __init__( + self, + valueType: Optional[Union[OpView, Operation, Value]] = None, + value: Optional[Attribute] = None, + *, + loc=None, + ip=None, + ): valueType = valueType if valueType is None else _get_value(valueType) result = pdl.AttributeType.get() super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip) @@ -99,11 +69,13 @@ def __init__(self, class EraseOp: """Specialization for PDL erase op class.""" - def __init__(self, - operation: Optional[Union[OpView, Operation, Value]] = None, - *, - loc=None, - ip=None): + def __init__( + self, + operation: 
Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None, + ): operation = _get_value(operation) super().__init__(operation, loc=loc, ip=ip) @@ -111,11 +83,13 @@ def __init__(self, class OperandOp: """Specialization for PDL operand op class.""" - def __init__(self, - type: Optional[Union[OpView, Operation, Value]] = None, - *, - loc=None, - ip=None): + def __init__( + self, + type: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None, + ): type = type if type is None else _get_value(type) result = pdl.ValueType.get() super().__init__(result, valueType=type, loc=loc, ip=ip) @@ -124,11 +98,13 @@ def __init__(self, class OperandsOp: """Specialization for PDL operands op class.""" - def __init__(self, - types: Optional[Union[OpView, Operation, Value]] = None, - *, - loc=None, - ip=None): + def __init__( + self, + types: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None, + ): types = types if types is None else _get_value(types) result = pdl.RangeType.get(pdl.ValueType.get()) super().__init__(result, valueType=types, loc=loc, ip=ip) @@ -137,15 +113,23 @@ def __init__(self, class OperationOp: """Specialization for PDL operand op class.""" - def __init__(self, - name: Optional[Union[str, StringAttr]] = None, - args: Sequence[Union[OpView, Operation, Value]] = [], - attributes: Mapping[str, Union[OpView, Operation, Value]] = {}, - types: Sequence[Union[OpView, Operation, Value]] = [], - *, - loc=None, - ip=None): - name = name if name is None else _get_str_attr(name) + def __init__( + self, + name: Optional[Union[str, StringAttr]] = None, + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + attributes: Optional[Mapping[str, Union[OpView, Operation, + Value]]] = None, + types: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if types is None: + types = [] + if attributes is None: + attributes = {} + if args is None: + args = [] args = 
_get_values(args) attrNames = [] attrValues = [] @@ -155,22 +139,29 @@ def __init__(self, attrNames = ArrayAttr.get(attrNames) types = _get_values(types) result = pdl.OperationType.get() - super().__init__(result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip) + super().__init__(result, + args, + attrValues, + attrNames, + types, + opName=name, + loc=loc, + ip=ip) class PatternOp: """Specialization for PDL pattern op class.""" - def __init__(self, - benefit: Union[IntegerAttr, int], - name: Optional[Union[StringAttr, str]] = None, - *, - loc=None, - ip=None): + def __init__( + self, + benefit: Union[IntegerAttr, int], + name: Optional[Union[StringAttr, str]] = None, + *, + loc=None, + ip=None, + ): """Creates an PDL `pattern` operation.""" - name_attr = None if name is None else _get_str_attr(name) - benefit_attr = _get_int_attr(16, benefit) - super().__init__(benefit_attr, sym_name=name_attr, loc=loc, ip=ip) + super().__init__(benefit, sym_name=name, loc=loc, ip=ip) self.regions[0].blocks.append() @property @@ -182,13 +173,17 @@ def body(self): class ReplaceOp: """Specialization for PDL replace op class.""" - def __init__(self, - op: Union[OpView, Operation, Value], - *, - with_op: Optional[Union[OpView, Operation, Value]] = None, - with_values: Sequence[Union[OpView, Operation, Value]] = [], - loc=None, - ip=None): + def __init__( + self, + op: Union[OpView, Operation, Value], + *, + with_op: Optional[Union[OpView, Operation, Value]] = None, + with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + loc=None, + ip=None, + ): + if with_values is None: + with_values = [] op = _get_value(op) with_op = with_op if with_op is None else _get_value(with_op) with_values = _get_values(with_values) @@ -198,13 +193,14 @@ def __init__(self, class ResultOp: """Specialization for PDL result op class.""" - def __init__(self, - parent: Union[OpView, Operation, Value], - index: Union[IntegerAttr, int], - *, - loc=None, - ip=None): - index = 
_get_int_attr(32, index) + def __init__( + self, + parent: Union[OpView, Operation, Value], + index: Union[IntegerAttr, int], + *, + loc=None, + ip=None, + ): parent = _get_value(parent) result = pdl.ValueType.get() super().__init__(result, parent, index, loc=loc, ip=ip) @@ -213,32 +209,36 @@ def __init__(self, class ResultsOp: """Specialization for PDL results op class.""" - def __init__(self, - result: Type, - parent: Union[OpView, Operation, Value], - index: Optional[Union[IntegerAttr, int]] = None, - *, - loc=None, - ip=None): + def __init__( + self, + result: Type, + parent: Union[OpView, Operation, Value], + index: Optional[Union[IntegerAttr, int]] = None, + *, + loc=None, + ip=None, + ): parent = _get_value(parent) - index = index if index is None else _get_int_attr(32, index) super().__init__(result, parent, index=index, loc=loc, ip=ip) class RewriteOp: """Specialization for PDL rewrite op class.""" - def __init__(self, - root: Optional[Union[OpView, Operation, Value]] = None, - name: Optional[Union[StringAttr, str]] = None, - args: Sequence[Union[OpView, Operation, Value]] = [], - *, - loc=None, - ip=None): + def __init__( + self, + root: Optional[Union[OpView, Operation, Value]] = None, + name: Optional[Union[StringAttr, str]] = None, + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if args is None: + args = [] root = root if root is None else _get_value(root) - name = name if name is None else _get_str_attr(name) args = _get_values(args) - super().__init__(args, root=root,name=name, loc=loc, ip=ip) + super().__init__(args, root=root, name=name, loc=loc, ip=ip) def add_body(self): """Add body (block) to the rewrite.""" @@ -259,8 +259,6 @@ def __init__(self, *, loc=None, ip=None): - constantType = constantType if constantType is None else _get_type_attr( - constantType) result = pdl.TypeType.get() super().__init__(result, constantType=constantType, loc=loc, ip=ip) @@ -268,13 +266,14 @@ def 
__init__(self, class TypesOp: """Specialization for PDL types op class.""" - def __init__(self, - constantTypes: Sequence[Union[TypeAttr, Type]] = [], - *, - loc=None, - ip=None): - constantTypes = _get_array_attr( - [_get_type_attr(ty) for ty in constantTypes]) - constantTypes = None if not constantTypes else constantTypes + def __init__( + self, + constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None, + *, + loc=None, + ip=None, + ): + if constantTypes is None: + constantTypes = [] result = pdl.RangeType.get(pdl.TypeType.get()) super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index e2c262ca5..9c051cd3d 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -15,180 +15,159 @@ OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]] -def _get_int64_attr(value: Union[int, Attribute]) -> IntegerAttr: - if isinstance(value, int): - return IntegerAttr.get(IntegerType.get_signless(64), value) - return value - - -def _get_array_attr( - values: Optional[Union[ArrayAttr, Sequence[Attribute]]]) -> ArrayAttr: - """Creates an array attribute from its operand.""" - if values is None: - return ArrayAttr.get([]) - if isinstance(values, ArrayAttr): - return values - - return ArrayAttr.get(values) - - -def _get_int_array_attr( - values: Optional[Union[ArrayAttr, Sequence[Union[IntegerAttr, int]]]] -) -> ArrayAttr: - """Creates an integer array attribute from its operand. - - If the operand is already an array attribute, forwards it. Otherwise treats - the operand as a list of attributes or integers, possibly intersperced, to - create a new array attribute containing integer attributes. Expects the - thread-local MLIR context to have been set by the context manager. 
- """ - if values is None: - return ArrayAttr.get([]) - if isinstance(values, ArrayAttr): - return values - - return ArrayAttr.get([_get_int64_attr(v) for v in values]) - -def _get_dense_int64_array_attr( - values: Sequence[int]) -> DenseI64ArrayAttr: - """Creates a dense integer array from a sequence of integers. - Expects the thread-local MLIR context to have been set by the context - manager. - """ - if values is None: - return DenseI64ArrayAttr.get([]) - return DenseI64ArrayAttr.get(values) - def _get_int_int_array_attr( values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]] ) -> ArrayAttr: """Creates an array attribute containing array attributes of integers. - If the operand is already an array attribute, forwards it. Otherwise treats - the operand as a list of attributes or integers, potentially interpserced, to - create a new array-of-array attribute. Expects the thread-local MLIR context - to have been set by the context manager. - """ + If the operand is already an array attribute, forwards it. Otherwise treats + the operand as a list of attributes or integers, potentially interpserced, to + create a new array-of-array attribute. Expects the thread-local MLIR context + to have been set by the context manager. 
+ """ if values is None: return ArrayAttr.get([]) if isinstance(values, ArrayAttr): return values + if isinstance(values, list): + values = [ + ArrayAttr.get( + [IntegerAttr.get(IntegerType.get_signless(64), v) + for v in value]) + for value in values + ] - return ArrayAttr.get([_get_int_array_attr(value) for value in values]) + return ArrayAttr.get(values) class DecomposeOp: """Specialization for DecomposeOp class.""" def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - super().__init__( - pdl.OperationType.get(), - _get_op_result_or_value(target), - loc=loc, - ip=ip) + super().__init__(pdl.OperationType.get(), + _get_op_result_or_value(target), + loc=loc, + ip=ip) class GeneralizeOp: """Specialization for GeneralizeOp class.""" def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - super().__init__( - pdl.OperationType.get(), - _get_op_result_or_value(target), - loc=loc, - ip=ip) + super().__init__(pdl.OperationType.get(), + _get_op_result_or_value(target), + loc=loc, + ip=ip) class InterchangeOp: """Specialization for InterchangeOp class.""" - def __init__(self, - target: Union[Operation, Value], - *, - iterator_interchange: OptionalIntList = None, - loc=None, - ip=None): + def __init__( + self, + target: Union[Operation, Value], + *, + iterator_interchange: OptionalIntList = None, + loc=None, + ip=None, + ): pdl_operation_type = pdl.OperationType.get() - interchange_attr = _get_dense_int64_array_attr(iterator_interchange) super().__init__( pdl_operation_type, _get_op_result_or_value(target), - iterator_interchange=interchange_attr, + iterator_interchange=iterator_interchange, loc=loc, - ip=ip) + ip=ip, + ) class MatchOp: """Specialization for MatchOp class.""" @classmethod - def match_op_names(MatchOp, - target: Union[Operation, Value], - names: Sequence[str], - loc=None, - ip=None): + def match_op_names( + MatchOp, + target: Union[Operation, Value], + names: Sequence[str], + loc=None, + ip=None, + ): 
pdl_operation_type = pdl.OperationType.get() return MatchOp( pdl_operation_type, _get_op_result_or_value(target), ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))), loc=loc, - ip=ip) + ip=ip, + ) class MultiTileSizesOp: """Specialization for MultitileSizesOp class.""" - def __init__(self, - result_type: Type, - target: Union[Operation, Value], - *, - dimension: Union[int, IntegerAttr], - target_size: Union[int, IntegerAttr], - divisor: Optional[Union[int, IntegerAttr]] = None, - loc=None, - ip=None): + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + dimension: Union[int, IntegerAttr], + target_size: Union[int, IntegerAttr], + divisor: Optional[Optional[Union[int, IntegerAttr]]] = None, + loc=None, + ip=None, + ): + if divisor is None: + divisor = 1 super().__init__( result_type, result_type, result_type, _get_op_result_or_value(target), - dimension=_get_int64_attr(dimension), - target_size=_get_int64_attr(target_size), - divisor=_get_int64_attr(divisor if divisor else 1), + dimension=dimension, + target_size=target_size, + divisor=divisor, loc=loc, - ip=ip) + ip=ip, + ) class PadOp: """Specialization for PadOp class.""" - def __init__(self, - target: Union[Operation, Value], - *, - padding_values: Optional[Union[ArrayAttr, - Sequence[Attribute]]] = None, - padding_dimensions: OptionalIntList = None, - pack_paddings: OptionalIntList = None, - transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[ - ArrayAttr, IntOrAttrList]]]] = None, - loc=None, - ip=None): + def __init__( + self, + target: Union[Operation, Value], + *, + padding_values: Optional[Optional[Union[ArrayAttr, + Sequence[Attribute]]]] = None, + padding_dimensions: OptionalIntList = None, + pack_paddings: OptionalIntList = None, + transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[ + ArrayAttr, IntOrAttrList]]]] = None, + loc=None, + ip=None, + ): + if transpose_paddings is None: + transpose_paddings = [] + if pack_paddings is None: + 
pack_paddings = [] + if padding_dimensions is None: + padding_dimensions = [] + if padding_values is None: + padding_values = [] pdl_operation_type = pdl.OperationType.get() - padding_values_attr = _get_array_attr(padding_values) - padding_dimensions_attr = _get_int_array_attr(padding_dimensions) - pack_paddings_attr = _get_int_array_attr(pack_paddings) transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings) super().__init__( pdl_operation_type, _get_op_result_or_value(target), - padding_values=padding_values_attr, - padding_dimensions=padding_dimensions_attr, - pack_paddings=pack_paddings_attr, + padding_values=padding_values, + padding_dimensions=padding_dimensions, + pack_paddings=pack_paddings, transpose_paddings=transpose_paddings_attr, loc=loc, - ip=ip) + ip=ip, + ) class ScalarizeOp: @@ -196,29 +175,29 @@ class ScalarizeOp: def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): pdl_operation_type = pdl.OperationType.get() - super().__init__( - pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip) + super().__init__(pdl_operation_type, + _get_op_result_or_value(target), + loc=loc, + ip=ip) class SplitOp: """Specialization for SplitOp class.""" - def __init__(self, - target: Union[Operation, Value], - dimension: Union[int, Attribute], - split_point: Union[int, Operation, Value, Attribute], - *, - loc=None, - ip=None): - dimension = _get_int64_attr(dimension) + def __init__( + self, + target: Union[Operation, Value], + dimension: Union[int, Attribute], + split_point: Union[int, Operation, Value, Attribute], + *, + loc=None, + ip=None, + ): if isinstance(split_point, int): - split_point = _get_int64_attr(split_point) - - if isinstance(split_point, Attribute): static_split_point = split_point dynamic_split_point = None else: - static_split_point = _get_int64_attr(ShapedType.get_dynamic_size()) + static_split_point = ShapedType.get_dynamic_size() dynamic_split_point = _get_op_result_or_value(split_point) target = 
_get_op_result_or_value(target) @@ -231,44 +210,53 @@ def __init__(self, static_split_point=static_split_point, dynamic_split_point=dynamic_split_point, loc=loc, - ip=ip) + ip=ip, + ) class TileOp: """Specialization for TileOp class.""" @overload - def __init__(self, - loop_types: Union[Type, List[Type]], - target: Union[Operation, Value], - *, - sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, - Value]], ArrayAttr]] = None, - interchange: OptionalIntList = None, - loc=None, - ip=None): + def __init__( + self, + loop_types: Union[Type, List[Type]], + target: Union[Operation, Value], + *, + sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]], + ArrayAttr]] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None, + ): ... @overload - def __init__(self, - target: Union[Operation, Value, OpView], - *, - sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, - Value]], ArrayAttr]] = None, - interchange: OptionalIntList = None, - loc=None, - ip=None): + def __init__( + self, + target: Union[Operation, Value, OpView], + *, + sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]], + ArrayAttr]] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None, + ): ... 
- def __init__(self, - loop_types_or_target: Union[Type, List[Type], Operation, Value], - target_or_none: Optional[Union[Operation, Value, OpView]] = None, - *, - sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, - Value]], ArrayAttr]] = None, - interchange: OptionalIntList = None, - loc=None, - ip=None): + def __init__( + self, + loop_types_or_target: Union[Type, List[Type], Operation, Value], + target_or_none: Optional[Union[Operation, Value, OpView]] = None, + *, + sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]], + ArrayAttr]] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None, + ): + if interchange is None: + interchange = [] if sizes is None: sizes = [] @@ -293,8 +281,8 @@ def __init__(self, target = loop_types_or_target assert target_or_none is None, "Cannot construct TileOp with two targets." else: - loop_types = ([loop_types_or_target] * num_loops) if isinstance( - loop_types_or_target, Type) else loop_types_or_target + loop_types = (([loop_types_or_target] * num_loops) if isinstance( + loop_types_or_target, Type) else loop_types_or_target) target = target_or_none target = _get_op_result_or_value(target) @@ -305,10 +293,10 @@ def __init__(self, target, dynamic_sizes=dynamic_sizes, static_sizes=sizes_attr, - interchange=_get_dense_int64_array_attr(interchange) - if interchange else None, + interchange=interchange, loc=loc, - ip=ip) + ip=ip, + ) def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]: if not attr: @@ -319,12 +307,14 @@ def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]: class VectorizeOp: """Specialization for VectorizeOp class.""" - def __init__(self, - target: Union[Operation, Value], - *, - vectorize_padding: Union[bool, BoolAttr] = False, - loc=None, - ip=None): + def __init__( + self, + target: Union[Operation, Value], + *, + vectorize_padding: Union[bool, BoolAttr] = False, + loc=None, + ip=None, + ): pdl_operation_type = 
pdl.OperationType.get() if isinstance(vectorize_padding, bool): vectorize_padding = UnitAttr.get() @@ -333,4 +323,5 @@ def __init__(self, _get_op_result_or_value(target), vectorize_padding=vectorize_padding, loc=loc, - ip=ip) + ip=ip, + ) diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py index 593b8855c..8651c76ea 100644 --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -4,102 +4,119 @@ try: from ..ir import * - from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values + from ._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + ) except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e -from argparse import SUPPRESS -from typing import Optional, overload, Sequence, Union - - -def _get_symbol_ref_attr(value: Union[Attribute, str]): - if isinstance(value, Attribute): - return value - return FlatSymbolRefAttr.get(value) +from typing import Optional, Sequence, Union class CastOp: - def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None): - super().__init__( - result_type, - _get_op_result_or_value(target), - loc=loc, - ip=ip) + def __init__(self, + result_type: Type, + target: Union[Operation, Value], + *, + loc=None, + ip=None): + super().__init__(result_type, + _get_op_result_or_value(target), + loc=loc, + ip=ip) class GetClosestIsolatedParentOp: - def __init__(self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None): - super().__init__( - result_type, - _get_op_result_or_value(target), - loc=loc, - ip=ip) - - -class MergeHandlesOp: - def __init__(self, - handles: Sequence[Union[Operation, Value]], + result_type: Type, + target: Union[Operation, Value], *, - deduplicate: bool = False, loc=None, 
ip=None): + super().__init__(result_type, + _get_op_result_or_value(target), + loc=loc, + ip=ip) + + +class MergeHandlesOp: + + def __init__( + self, + handles: Sequence[Union[Operation, Value]], + *, + deduplicate: bool = False, + loc=None, + ip=None, + ): super().__init__( [_get_op_result_or_value(h) for h in handles], deduplicate=deduplicate, loc=loc, - ip=ip) + ip=ip, + ) class PDLMatchOp: - def __init__(self, - result_type: Type, - target: Union[Operation, Value], - pattern_name: Union[Attribute, str], - *, - loc=None, - ip=None): + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + pattern_name: Union[Attribute, str], + *, + loc=None, + ip=None, + ): super().__init__( result_type, _get_op_result_or_value(target), - _get_symbol_ref_attr(pattern_name), + pattern_name, loc=loc, - ip=ip) + ip=ip, + ) class ReplicateOp: - def __init__(self, - pattern: Union[Operation, Value], - handles: Sequence[Union[Operation, Value]], - *, - loc=None, - ip=None): + def __init__( + self, + pattern: Union[Operation, Value], + handles: Sequence[Union[Operation, Value]], + *, + loc=None, + ip=None, + ): super().__init__( [_get_op_result_or_value(h).type for h in handles], _get_op_result_or_value(pattern), [_get_op_result_or_value(h) for h in handles], loc=loc, - ip=ip) + ip=ip, + ) class SequenceOp: - def __init__(self, failure_propagation_mode, results: Sequence[Type], - target: Union[Operation, Value, Type], - extra_bindings: Optional[Union[Sequence[Value], Sequence[Type], - Operation, OpView]] = None): - root = _get_op_result_or_value(target) if isinstance( - target, (Operation, Value)) else None + def __init__( + self, + failure_propagation_mode, + results: Sequence[Type], + target: Union[Operation, Value, Type], + extra_bindings: Optional[Union[Sequence[Value], Sequence[Type], Operation, + OpView]] = None, + ): + root = (_get_op_result_or_value(target) if isinstance( + target, (Operation, Value)) else None) root_type = root.type if not 
isinstance(target, Type) else target if not isinstance(failure_propagation_mode, Attribute): failure_propagation_mode_attr = IntegerAttr.get( IntegerType.get_signless(32), failure_propagation_mode._as_int()) else: - failure_propagation_mode = failure_propagation_mode + failure_propagation_mode_attr = failure_propagation_mode if extra_bindings is None: extra_bindings = [] @@ -114,10 +131,12 @@ def __init__(self, failure_propagation_mode, results: Sequence[Type], else: extra_binding_types = [v.type for v in extra_bindings] - super().__init__(results_=results, - failure_propagation_mode=failure_propagation_mode_attr, - root=root, - extra_bindings=extra_bindings) + super().__init__( + results_=results, + failure_propagation_mode=failure_propagation_mode_attr, + root=root, + extra_bindings=extra_bindings, + ) self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types)) @property @@ -143,10 +162,7 @@ def __init__(self, root = _get_op_result_or_value(target) if not isinstance(target, Type) else None root_type = target if isinstance(target, Type) else root.type - super().__init__( - root=root, - loc=loc, - ip=ip) + super().__init__(root=root, loc=loc, ip=ip) self.regions[0].blocks.append(root_type) @property @@ -160,9 +176,13 @@ def bodyTarget(self) -> Value: class YieldOp: - def __init__(self, - operands: Union[Operation, Sequence[Value]] = [], - *, - loc=None, - ip=None): + def __init__( + self, + operands: Optional[Union[Operation, Sequence[Value]]] = None, + *, + loc=None, + ip=None, + ): + if operands is None: + operands = [] super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip) diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 1e24fcbf9..714253426 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -8,9 +8,11 @@ # Convenience decorator for registering user-friendly Attribute builders. 
def register_attribute_builder(kind): + def decorator_builder(func): AttrBuilder.insert(kind, func) return func + return decorator_builder @@ -18,34 +20,77 @@ def decorator_builder(func): def _boolAttr(x, context): return BoolAttr.get(x, context=context) + @register_attribute_builder("IndexAttr") def _indexAttr(x, context): return IntegerAttr.get(IndexType.get(context=context), x) + +@register_attribute_builder("I16Attr") +def _i32Attr(x, context): + return IntegerAttr.get(IntegerType.get_signless(16, context=context), x) + + @register_attribute_builder("I32Attr") def _i32Attr(x, context): - return IntegerAttr.get( - IntegerType.get_signless(32, context=context), x) + return IntegerAttr.get(IntegerType.get_signless(32, context=context), x) + @register_attribute_builder("I64Attr") def _i64Attr(x, context): - return IntegerAttr.get( - IntegerType.get_signless(64, context=context), x) + return IntegerAttr.get(IntegerType.get_signless(64, context=context), x) + @register_attribute_builder("StrAttr") def _stringAttr(x, context): return StringAttr.get(x, context=context) + @register_attribute_builder("SymbolNameAttr") def _symbolNameAttr(x, context): return StringAttr.get(x, context=context) + +@register_attribute_builder("SymbolRefAttr") +def _symbolRefAttr(x, context): + return FlatSymbolRefAttr.get(x, context=context) + + +@register_attribute_builder("ArrayAttr") +def _arrayAttr(x, context): + return ArrayAttr.get(x, context=context) + + +@register_attribute_builder("I64ArrayAttr") +def _i64ArrayAttr(x, context): + return ArrayAttr.get([_i64Attr(v, context) for v in x]) + + +@register_attribute_builder("DenseI64ArrayAttr") +def _denseI64ArrayAttr(x, context): + return DenseI64ArrayAttr.get(x, context=context) + + +@register_attribute_builder("TypeAttr") +def _typeAttr(x, context): + return TypeAttr.get(x, context=context) + + +@register_attribute_builder("TypeArrayAttr") +def _typeArrayAttr(x, context): + return _arrayAttr([TypeAttr.get(t, context=context) for t in 
x], context) + + try: import numpy as np + @register_attribute_builder("IndexElementsAttr") def _indexElementsAttr(x, context): return DenseElementsAttr.get( - np.array(x, dtype=np.int64), type=IndexType.get(context=context), - context=context) + np.array(x, dtype=np.int64), + type=IndexType.get(context=context), + context=context, + ) + except ImportError: pass From b23da019742c9b4b170ec59a129a46d2b36ff627 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Sat, 29 Apr 2023 05:35:53 -0700 Subject: [PATCH 448/915] [mlir][bytecode] Allow client to specify a desired version. Add method to set a desired bytecode file format to generate. Change write method to be able to return status including the minimum bytecode version needed by reader. This enables generating an older version of the bytecode (not dialect ops, attributes or types). But this does not guarantee that an older version can always be generated, e.g., if a dialect uses a new encoding only available at later bytecode version. This clamps setting to at most current version. 
Differential Revision: https://reviews.llvm.org/D146555 --- mlir/include/mlir-c/IR.h | 44 ++++++++++++++++++++++++++--- mlir/include/mlir/CAPI/IR.h | 2 ++ mlir/lib/Bindings/Python/IRCore.cpp | 23 +++++++++++++-- mlir/lib/Bindings/Python/IRModule.h | 4 ++- mlir/lib/CAPI/IR/IR.cpp | 38 +++++++++++++++++++++++-- 5 files changed, 100 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index b45b95536..315ec3a84 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -48,6 +48,7 @@ extern "C" { }; \ typedef struct name name +DEFINE_C_API_STRUCT(MlirBytecodeWriterConfig, void); DEFINE_C_API_STRUCT(MlirContext, void); DEFINE_C_API_STRUCT(MlirDialect, void); DEFINE_C_API_STRUCT(MlirDialectRegistry, void); @@ -408,6 +409,24 @@ mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags); MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags); +//===----------------------------------------------------------------------===// +// Bytecode printing flags API. +//===----------------------------------------------------------------------===// + +/// Creates new printing flags with defaults, intended for customization. +/// Must be freed with a call to mlirBytecodeWriterConfigDestroy(). +MLIR_CAPI_EXPORTED MlirBytecodeWriterConfig +mlirBytecodeWriterConfigCreate(void); + +/// Destroys printing flags created with mlirBytecodeWriterConfigCreate. +MLIR_CAPI_EXPORTED void +mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config); + +/// Sets the version to emit in the writer config. +MLIR_CAPI_EXPORTED void +mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, + int64_t version); + //===----------------------------------------------------------------------===// // Operation API. 
//===----------------------------------------------------------------------===// @@ -546,10 +565,27 @@ MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op, MlirStringCallback callback, void *userData); -/// Same as mlirOperationPrint but writing the bytecode format out. -MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op, - MlirStringCallback callback, - void *userData); +struct MlirBytecodeWriterResult { + int64_t minVersion; +}; +typedef struct MlirBytecodeWriterResult MlirBytecodeWriterResult; + +inline static bool +mlirBytecodeWriterResultGetMinVersion(MlirBytecodeWriterResult res) { + return res.minVersion; +} + +/// Same as mlirOperationPrint but writing the bytecode format and returns the +/// minimum bytecode version the consumer needs to support. +MLIR_CAPI_EXPORTED MlirBytecodeWriterResult mlirOperationWriteBytecode( + MlirOperation op, MlirStringCallback callback, void *userData); + +/// Same as mlirOperationWriteBytecode but with writer config. +MLIR_CAPI_EXPORTED MlirBytecodeWriterResult +mlirOperationWriteBytecodeWithConfig(MlirOperation op, + MlirBytecodeWriterConfig config, + MlirStringCallback callback, + void *userData); /// Prints an operation to stderr. 
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op); diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h index 2f32c76e1..b8ccec896 100644 --- a/mlir/include/mlir/CAPI/IR.h +++ b/mlir/include/mlir/CAPI/IR.h @@ -15,11 +15,13 @@ #ifndef MLIR_CAPI_IR_H #define MLIR_CAPI_IR_H +#include "mlir/Bytecode/BytecodeWriter.h" #include "mlir/CAPI/Wrap.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" +DEFINE_C_API_PTR_METHODS(MlirBytecodeWriterConfig, mlir::BytecodeWriterConfig) DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext) DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect) DEFINE_C_API_PTR_METHODS(MlirDialectRegistry, mlir::DialectRegistry) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 81c5cd218..052998be1 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -124,6 +124,9 @@ static const char kOperationPrintBytecodeDocstring[] = Args: file: The file like object to write to. + desired_version: The version of bytecode to emit. +Returns: + The bytecode writer status. 
)"; static const char kOperationStrDunderDocstring[] = @@ -1131,12 +1134,21 @@ void PyOperationBase::print(py::object fileObject, bool binary, mlirOpPrintingFlagsDestroy(flags); } -void PyOperationBase::writeBytecode(const py::object &fileObject) { +MlirBytecodeWriterResult +PyOperationBase::writeBytecode(const py::object &fileObject, + std::optional bytecodeVersion) { PyOperation &operation = getOperation(); operation.checkValid(); PyFileAccumulator accum(fileObject, /*binary=*/true); - mlirOperationWriteBytecode(operation, accum.getCallback(), - accum.getUserData()); + + if (!bytecodeVersion.has_value()) + return mlirOperationWriteBytecode(operation, accum.getCallback(), + accum.getUserData()); + + MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate(); + mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion); + return mlirOperationWriteBytecodeWithConfig( + operation, config, accum.getCallback(), accum.getUserData()); } py::object PyOperationBase::getAsm(bool binary, @@ -2757,6 +2769,7 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("use_local_scope") = false, py::arg("assume_verified") = false, kOperationPrintDocstring) .def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"), + py::arg("desired_version") = py::none(), kOperationPrintBytecodeDocstring) .def("get_asm", &PyOperationBase::getAsm, // Careful: Lots of arguments must match up with get_asm method. @@ -3365,6 +3378,10 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("from_op"), py::arg("all_sym_uses_visible"), py::arg("callback")); + py::class_(m, "BytecodeResult", py::module_local()) + .def("min_version", + [](MlirBytecodeWriterResult &res) { return res.minVersion; }); + // Container bindings. 
PyBlockArgumentList::bind(m); PyBlockIterator::bind(m); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 920e6f467..56bb834b4 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -554,7 +554,9 @@ class PyOperationBase { bool assumeVerified); // Implement the bound 'writeBytecode' method. - void writeBytecode(const pybind11::object &fileObject); + MlirBytecodeWriterResult + writeBytecode(const pybind11::object &fileObject, + std::optional bytecodeVersion); /// Moves the operation before or after the other operation. void moveAfter(PyOperationBase &other); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 0bbcb3083..03f154965 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -145,6 +145,23 @@ void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) { unwrap(flags)->assumeVerified(); } +//===----------------------------------------------------------------------===// +// Bytecode printing flags API. +//===----------------------------------------------------------------------===// + +MlirBytecodeWriterConfig mlirBytecodeWriterConfigCreate() { + return wrap(new BytecodeWriterConfig()); +} + +void mlirBytecodeWriterConfigDestroy(MlirBytecodeWriterConfig config) { + delete unwrap(config); +} + +void mlirBytecodeWriterConfigDesiredEmitVersion(MlirBytecodeWriterConfig flags, + int64_t version) { + unwrap(flags)->setDesiredBytecodeVersion(version); +} + //===----------------------------------------------------------------------===// // Location API. 
//===----------------------------------------------------------------------===// @@ -507,10 +524,25 @@ void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, unwrap(op)->print(stream, *unwrap(flags)); } -void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, - void *userData) { +MlirBytecodeWriterResult mlirOperationWriteBytecode(MlirOperation op, + MlirStringCallback callback, + void *userData) { + detail::CallbackOstream stream(callback, userData); + MlirBytecodeWriterResult res; + BytecodeWriterResult r = writeBytecodeToFile(unwrap(op), stream); + res.minVersion = r.minVersion; + return res; +} + +MlirBytecodeWriterResult mlirOperationWriteBytecodeWithConfig( + MlirOperation op, MlirBytecodeWriterConfig config, + MlirStringCallback callback, void *userData) { detail::CallbackOstream stream(callback, userData); - writeBytecodeToFile(unwrap(op), stream); + BytecodeWriterResult r = + writeBytecodeToFile(unwrap(op), stream, *unwrap(config)); + MlirBytecodeWriterResult res; + res.minVersion = r.minVersion; + return res; } void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } From 2c8dedfc5e366d04ead2dd77da0de8599e14493e Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Sun, 30 Apr 2023 22:11:02 -0700 Subject: [PATCH 449/915] [mlir][bytecode] Return error instead of min version Can't return a well-formed IR output while enabling version to be bumped up during emission. Previously it would return min version but potentially invalid IR which was confusing, instead make it return error and abort immediately instead. 
Differential Revision: https://reviews.llvm.org/D149569 --- mlir/include/mlir-c/IR.h | 31 ++++++++++------------------- mlir/lib/Bindings/Python/IRCore.cpp | 16 +++++++-------- mlir/lib/Bindings/Python/IRModule.h | 5 ++--- mlir/lib/CAPI/IR/IR.cpp | 19 ++++++------------ 4 files changed, 26 insertions(+), 45 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 315ec3a84..90af14461 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -565,27 +565,16 @@ MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op, MlirStringCallback callback, void *userData); -struct MlirBytecodeWriterResult { - int64_t minVersion; -}; -typedef struct MlirBytecodeWriterResult MlirBytecodeWriterResult; - -inline static bool -mlirBytecodeWriterResultGetMinVersion(MlirBytecodeWriterResult res) { - return res.minVersion; -} - -/// Same as mlirOperationPrint but writing the bytecode format and returns the -/// minimum bytecode version the consumer needs to support. -MLIR_CAPI_EXPORTED MlirBytecodeWriterResult mlirOperationWriteBytecode( - MlirOperation op, MlirStringCallback callback, void *userData); - -/// Same as mlirOperationWriteBytecode but with writer config. -MLIR_CAPI_EXPORTED MlirBytecodeWriterResult -mlirOperationWriteBytecodeWithConfig(MlirOperation op, - MlirBytecodeWriterConfig config, - MlirStringCallback callback, - void *userData); +/// Same as mlirOperationPrint but writing the bytecode format. +MLIR_CAPI_EXPORTED void mlirOperationWriteBytecode(MlirOperation op, + MlirStringCallback callback, + void *userData); + +/// Same as mlirOperationWriteBytecode but with writer config and returns +/// failure only if desired bytecode could not be honored. +MLIR_CAPI_EXPORTED MlirLogicalResult mlirOperationWriteBytecodeWithConfig( + MlirOperation op, MlirBytecodeWriterConfig config, + MlirStringCallback callback, void *userData); /// Prints an operation to stderr. 
MLIR_CAPI_EXPORTED void mlirOperationDump(MlirOperation op); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 052998be1..f2e188e78 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -16,6 +16,7 @@ #include "mlir-c/Debug.h" #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" +#include "mlir-c/Support.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" @@ -1134,9 +1135,8 @@ void PyOperationBase::print(py::object fileObject, bool binary, mlirOpPrintingFlagsDestroy(flags); } -MlirBytecodeWriterResult -PyOperationBase::writeBytecode(const py::object &fileObject, - std::optional bytecodeVersion) { +void PyOperationBase::writeBytecode(const py::object &fileObject, + std::optional bytecodeVersion) { PyOperation &operation = getOperation(); operation.checkValid(); PyFileAccumulator accum(fileObject, /*binary=*/true); @@ -1147,8 +1147,12 @@ PyOperationBase::writeBytecode(const py::object &fileObject, MlirBytecodeWriterConfig config = mlirBytecodeWriterConfigCreate(); mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion); - return mlirOperationWriteBytecodeWithConfig( + MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig( operation, config, accum.getCallback(), accum.getUserData()); + if (mlirLogicalResultIsFailure(res)) + throw py::value_error((Twine("Unable to honor desired bytecode version ") + + Twine(*bytecodeVersion)) + .str()); } py::object PyOperationBase::getAsm(bool binary, @@ -3378,10 +3382,6 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("from_op"), py::arg("all_sym_uses_visible"), py::arg("callback")); - py::class_(m, "BytecodeResult", py::module_local()) - .def("min_version", - [](MlirBytecodeWriterResult &res) { return res.minVersion; }); - // Container bindings. 
PyBlockArgumentList::bind(m); PyBlockIterator::bind(m); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 56bb834b4..ade790ba0 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -554,9 +554,8 @@ class PyOperationBase { bool assumeVerified); // Implement the bound 'writeBytecode' method. - MlirBytecodeWriterResult - writeBytecode(const pybind11::object &fileObject, - std::optional bytecodeVersion); + void writeBytecode(const pybind11::object &fileObject, + std::optional bytecodeVersion); /// Moves the operation before or after the other operation. void moveAfter(PyOperationBase &other); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 03f154965..0069bf102 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -524,25 +524,18 @@ void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags, unwrap(op)->print(stream, *unwrap(flags)); } -MlirBytecodeWriterResult mlirOperationWriteBytecode(MlirOperation op, - MlirStringCallback callback, - void *userData) { +void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback, + void *userData) { detail::CallbackOstream stream(callback, userData); - MlirBytecodeWriterResult res; - BytecodeWriterResult r = writeBytecodeToFile(unwrap(op), stream); - res.minVersion = r.minVersion; - return res; + // As no desired version is set, no failure can occur. 
+ (void)writeBytecodeToFile(unwrap(op), stream); } -MlirBytecodeWriterResult mlirOperationWriteBytecodeWithConfig( +MlirLogicalResult mlirOperationWriteBytecodeWithConfig( MlirOperation op, MlirBytecodeWriterConfig config, MlirStringCallback callback, void *userData) { detail::CallbackOstream stream(callback, userData); - BytecodeWriterResult r = - writeBytecodeToFile(unwrap(op), stream, *unwrap(config)); - MlirBytecodeWriterResult res; - res.minVersion = r.minVersion; - return res; + return wrap(writeBytecodeToFile(unwrap(op), stream, *unwrap(config))); } void mlirOperationDump(MlirOperation op) { return unwrap(op)->dump(); } From 3ff828a652a7cfb8eeb0dc505b3948300eed40fe Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Sun, 26 Feb 2023 10:46:01 -0500 Subject: [PATCH 450/915] Introduce MLIR Op Properties This new features enabled to dedicate custom storage inline within operations. This storage can be used as an alternative to attributes to store data that is specific to an operation. Attribute can also be stored inside the properties storage if desired, but any kind of data can be present as well. This offers a way to store and mutate data without uniquing in the Context like Attribute. See the OpPropertiesTest.cpp for an example where a struct with a std::vector<> is attached to an operation and mutated in-place: struct TestProperties { int a = -1; float b = -1.; std::vector array = {-33}; }; More complex scheme (including reference-counting) are also possible. The only constraint to enable storing a C++ object as "properties" on an operation is to implement three functions: - convert from the candidate object to an Attribute - convert from the Attribute to the candidate object - hash the object Optional the parsing and printing can also be customized with 2 extra functions. 
A new options is introduced to ODS to allow dialects to specify: let usePropertiesForAttributes = 1; When set to true, the inherent attributes for all the ops in this dialect will be using properties instead of being stored alongside discardable attributes. The TestDialect showcases this feature. Another change is that we introduce new APIs on the Operation class to access separately the inherent attributes from the discardable ones. We envision deprecating and removing the `getAttr()`, `getAttrsDictionary()`, and other similar method which don't make the distinction explicit, leading to an entirely separate namespace for discardable attributes. Differential Revision: https://reviews.llvm.org/D141742 --- mlir/lib/CAPI/IR/IR.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 0069bf102..6ed32e1ce 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -340,7 +340,8 @@ static LogicalResult inferOperationTypes(OperationState &state) { if (succeeded(inferInterface->inferReturnTypes( context, state.location, state.operands, - state.attributes.getDictionary(context), state.regions, state.types))) + state.attributes.getDictionary(context), state.getRawProperties(), + state.regions, state.types))) return success(); // Diagnostic emitted by interface. From 723e8274663455a33c32c8c30fe2151f3e793966 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Mon, 1 May 2023 15:55:58 -0700 Subject: [PATCH 451/915] Revert "Introduce MLIR Op Properties" This reverts commit 3ff828a652a7cfb8eeb0dc505b3948300eed40fe. Some bots are broken and investigation is needed before relanding. 
--- mlir/lib/CAPI/IR/IR.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 6ed32e1ce..0069bf102 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -340,8 +340,7 @@ static LogicalResult inferOperationTypes(OperationState &state) { if (succeeded(inferInterface->inferReturnTypes( context, state.location, state.operands, - state.attributes.getDictionary(context), state.getRawProperties(), - state.regions, state.types))) + state.attributes.getDictionary(context), state.regions, state.types))) return success(); // Diagnostic emitted by interface. From e766e8126bd575922b7872e62d4bb4b09f31bc5a Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Sun, 26 Feb 2023 10:46:01 -0500 Subject: [PATCH 452/915] Introduce MLIR Op Properties This new features enabled to dedicate custom storage inline within operations. This storage can be used as an alternative to attributes to store data that is specific to an operation. Attribute can also be stored inside the properties storage if desired, but any kind of data can be present as well. This offers a way to store and mutate data without uniquing in the Context like Attribute. See the OpPropertiesTest.cpp for an example where a struct with a std::vector<> is attached to an operation and mutated in-place: struct TestProperties { int a = -1; float b = -1.; std::vector array = {-33}; }; More complex scheme (including reference-counting) are also possible. The only constraint to enable storing a C++ object as "properties" on an operation is to implement three functions: - convert from the candidate object to an Attribute - convert from the Attribute to the candidate object - hash the object Optional the parsing and printing can also be customized with 2 extra functions. 
A new options is introduced to ODS to allow dialects to specify: let usePropertiesForAttributes = 1; When set to true, the inherent attributes for all the ops in this dialect will be using properties instead of being stored alongside discardable attributes. The TestDialect showcases this feature. Another change is that we introduce new APIs on the Operation class to access separately the inherent attributes from the discardable ones. We envision deprecating and removing the `getAttr()`, `getAttrsDictionary()`, and other similar method which don't make the distinction explicit, leading to an entirely separate namespace for discardable attributes. Recommit 3ff828a652a7 after fixing python bindings build. Differential Revision: https://reviews.llvm.org/D141742 --- mlir/include/mlir-c/Interfaces.h | 4 ++-- mlir/lib/Bindings/Python/IRInterfaces.cpp | 7 ++++--- mlir/lib/CAPI/IR/IR.cpp | 3 ++- mlir/lib/CAPI/Interfaces/Interfaces.cpp | 8 +++++--- 4 files changed, 13 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir-c/Interfaces.h b/mlir/include/mlir-c/Interfaces.h index 233f828b9..405e2bb71 100644 --- a/mlir/include/mlir-c/Interfaces.h +++ b/mlir/include/mlir-c/Interfaces.h @@ -57,8 +57,8 @@ typedef void (*MlirTypesCallback)(intptr_t, MlirType *, void *); MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, - intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, - void *userData); + void *properties, intptr_t nRegions, MlirRegion *regions, + MlirTypesCallback callback, void *userData); #ifdef __cplusplus } diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index c8371dcc7..766d6f3e4 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -187,7 +187,7 @@ class PyInferTypeOpInterface /// return 
types. Throws value_error on failure. std::vector inferReturnTypes(std::optional operandList, - std::optional attributes, + std::optional attributes, void *properties, std::optional> regions, DefaultingPyMlirContext context, DefaultingPyLocation location) { @@ -255,7 +255,7 @@ class PyInferTypeOpInterface MlirLogicalResult result = mlirInferTypeOpInterfaceInferReturnTypes( opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), - mlirOperands.data(), attributeDict, mlirRegions.size(), + mlirOperands.data(), attributeDict, properties, mlirRegions.size(), mlirRegions.data(), &appendResultsCallback, &data); if (mlirLogicalResultIsFailure(result)) { @@ -268,7 +268,8 @@ class PyInferTypeOpInterface static void bindDerived(ClassTy &cls) { cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, py::arg("operands") = py::none(), - py::arg("attributes") = py::none(), py::arg("regions") = py::none(), + py::arg("attributes") = py::none(), + py::arg("properties") = py::none(), py::arg("regions") = py::none(), py::arg("context") = py::none(), py::arg("loc") = py::none(), inferReturnTypesDoc); } diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 0069bf102..6ed32e1ce 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -340,7 +340,8 @@ static LogicalResult inferOperationTypes(OperationState &state) { if (succeeded(inferInterface->inferReturnTypes( context, state.location, state.operands, - state.attributes.getDictionary(context), state.regions, state.types))) + state.attributes.getDictionary(context), state.getRawProperties(), + state.regions, state.types))) return success(); // Diagnostic emitted by interface. 
diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp index 5adccbdaf..029feed3a 100644 --- a/mlir/lib/CAPI/Interfaces/Interfaces.cpp +++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp @@ -1,3 +1,5 @@ + + //===- Interfaces.cpp - C Interface for MLIR Interfaces -------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. @@ -39,8 +41,8 @@ MlirTypeID mlirInferTypeOpInterfaceTypeID() { MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( MlirStringRef opName, MlirContext context, MlirLocation location, intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, - intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, - void *userData) { + void *properties, intptr_t nRegions, MlirRegion *regions, + MlirTypesCallback callback, void *userData) { StringRef name(opName.data, opName.length); std::optional info = RegisteredOperationName::lookup(name, unwrap(context)); @@ -72,7 +74,7 @@ MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( SmallVector inferredTypes; if (failed(info->getInterface()->inferReturnTypes( unwrap(context), maybeLocation, unwrappedOperands, attributeDict, - unwrappedRegions, inferredTypes))) + properties, unwrappedRegions, inferredTypes))) return mlirLogicalResultFailure(); SmallVector wrappedInferredTypes; From 224ce5f1d118d77afa77425d3b2e85916a04ad7c Mon Sep 17 00:00:00 2001 From: max Date: Tue, 2 May 2023 15:52:04 -0500 Subject: [PATCH 453/915] [MLIR][python bindings] Add support for DenseElementsAttr of IndexType Differential Revision: https://reviews.llvm.org/D149690 --- mlir/lib/Bindings/Python/IRAttributes.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 5e7138b21..22001957f 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -710,6 +710,10 @@ class PyDenseElementsAttribute // 
f16 return bufferInfo(shapedType, "e"); } + if (mlirTypeIsAIndex(elementType)) { + // Same as IndexType::kInternalStorageBitWidth + return bufferInfo(shapedType); + } if (mlirTypeIsAInteger(elementType) && mlirIntegerTypeGetWidth(elementType) == 32) { if (mlirIntegerTypeIsSignless(elementType) || From 6e09a0f1e2458d9c495bf28c8de28e7e5b96ec59 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 4 May 2023 14:13:06 +0000 Subject: [PATCH 454/915] [mlir] make transform.loop.outline also return the call handle Outlining is particularly interesting when the outlined function is replaced with something else, e.g., a microkernel. It is good to have a handle to the call in this case. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D149849 --- mlir/python/mlir/dialects/_loop_transform_ops_ext.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py index a275ea615..10079d32f 100644 --- a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py @@ -39,7 +39,8 @@ class LoopOutlineOp: def __init__( self, - result_type: Type, + function_type: Type, + call_type: Type, target: Union[Operation, Value], *, func_name: Union[str, StringAttr], @@ -47,7 +48,8 @@ def __init__( loc=None, ): super().__init__( - result_type, + function_type, + call_type, _get_op_result_or_value(target), func_name=(func_name if isinstance(func_name, StringAttr) else StringAttr.get(func_name)), From fe0fa03e58673158aa71a21be3ffba4ed3fd7209 Mon Sep 17 00:00:00 2001 From: max Date: Sun, 7 May 2023 18:19:46 -0500 Subject: [PATCH 455/915] [MLIR][python bindings] Add `PyValue.print_as_operand` (`Value::printAsOperand`) Useful for easier debugging (no need to regex out all of the stuff around the id). 
Differential Revision: https://reviews.llvm.org/D149902 --- mlir/include/mlir-c/IR.h | 6 ++++++ mlir/lib/Bindings/Python/IRCore.cpp | 17 +++++++++++++++++ mlir/lib/CAPI/IR/IR.cpp | 8 ++++++++ 3 files changed, 31 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 90af14461..13a3cb013 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -776,6 +776,12 @@ MLIR_CAPI_EXPORTED void mlirValueDump(MlirValue value); MLIR_CAPI_EXPORTED void mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData); +/// Prints a value as an operand (i.e., the ValueID). +MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value, + MlirOpPrintingFlags flags, + MlirStringCallback callback, + void *userData); + /// Returns an op operand representing the first use of the value, or a null op /// operand if there are no uses. MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index f2e188e78..7ffa46400 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -156,6 +156,10 @@ position in the argument list. If the value is an operation result, this is equivalent to printing the operation that produced it. )"; +static const char kGetNameAsOperand[] = + R"(Returns the string form of value as an operand (i.e., the ValueID). +)"; + static const char kValueReplaceAllUsesWithDocstring[] = R"(Replace all uses of value with the new value, updating anything in the IR that uses 'self' to use the other value instead. 
@@ -3336,6 +3340,19 @@ void mlir::python::populateIRCore(py::module &m) { return printAccum.join(); }, kValueDunderStrDocstring) + .def( + "get_name", + [](PyValue &self, bool useLocalScope) { + PyPrintAccumulator printAccum; + MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); + if (useLocalScope) + mlirOpPrintingFlagsUseLocalScope(flags); + mlirValuePrintAsOperand(self.get(), flags, printAccum.getCallback(), + printAccum.getUserData()); + mlirOpPrintingFlagsDestroy(flags); + return printAccum.join(); + }, + py::arg("use_local_scope") = false, kGetNameAsOperand) .def_property_readonly("type", [](PyValue &self) { return PyType( diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 6ed32e1ce..79386dedf 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -20,6 +20,7 @@ #include "mlir/IR/Location.h" #include "mlir/IR/Operation.h" #include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Parser/Parser.h" @@ -767,6 +768,13 @@ void mlirValuePrint(MlirValue value, MlirStringCallback callback, unwrap(value).print(stream); } +void mlirValuePrintAsOperand(MlirValue value, MlirOpPrintingFlags flags, + MlirStringCallback callback, void *userData) { + detail::CallbackOstream stream(callback, userData); + Value cppValue = unwrap(value); + cppValue.printAsOperand(stream, *unwrap(flags)); +} + MlirOpOperand mlirValueGetFirstUse(MlirValue value) { Value cppValue = unwrap(value); if (cppValue.use_empty()) From d19c93c63e7c1b84629c11a361e536f784002776 Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Sun, 7 May 2023 23:56:04 -0400 Subject: [PATCH 456/915] [mlir][python] Allow specifying block arg locations Currently blocks are always created with UnknownLoc's for their arguments. This adds an `arg_locs` argument to all block creation APIs, which takes an optional sequence of locations to use, one per block argument. 
If no locations are supplied, the current Location context is used. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D150084 --- mlir/lib/Bindings/Python/IRCore.cpp | 105 +++++++++------------ mlir/python/mlir/dialects/_func_ops_ext.py | 4 +- 2 files changed, 47 insertions(+), 62 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 7ffa46400..2158a4cb5 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -193,6 +193,31 @@ static MlirStringRef toMlirStringRef(const std::string &s) { return mlirStringRefCreate(s.data(), s.size()); } +/// Create a block, using the current location context if no locations are +/// specified. +static MlirBlock createBlock(const py::sequence &pyArgTypes, + const std::optional &pyArgLocs) { + SmallVector argTypes; + argTypes.reserve(pyArgTypes.size()); + for (const auto &pyType : pyArgTypes) + argTypes.push_back(pyType.cast()); + + SmallVector argLocs; + if (pyArgLocs) { + argLocs.reserve(pyArgLocs->size()); + for (const auto &pyLoc : *pyArgLocs) + argLocs.push_back(pyLoc.cast()); + } else if (!argTypes.empty()) { + argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve()); + } + + if (argTypes.size() != argLocs.size()) + throw py::value_error(("Expected " + Twine(argTypes.size()) + + " locations, got: " + Twine(argLocs.size())) + .str()); + return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); +} + /// Wrapper for the global LLVM debugging flag. 
struct PyGlobalDebugFlag { static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } @@ -364,21 +389,10 @@ class PyBlockList { throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); } - PyBlock appendBlock(const py::args &pyArgTypes) { + PyBlock appendBlock(const py::args &pyArgTypes, + const std::optional &pyArgLocs) { operation->checkValid(); - llvm::SmallVector argTypes; - llvm::SmallVector argLocs; - argTypes.reserve(pyArgTypes.size()); - argLocs.reserve(pyArgTypes.size()); - for (auto &pyArg : pyArgTypes) { - argTypes.push_back(pyArg.cast()); - // TODO: Pass in a proper location here. - argLocs.push_back( - mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); - } - - MlirBlock block = - mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); + MlirBlock block = createBlock(pyArgTypes, pyArgLocs); mlirRegionAppendOwnedBlock(region, block); return PyBlock(operation, block); } @@ -388,7 +402,8 @@ class PyBlockList { .def("__getitem__", &PyBlockList::dunderGetItem) .def("__iter__", &PyBlockList::dunderIter) .def("__len__", &PyBlockList::dunderLen) - .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); + .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring, + py::arg("arg_locs") = std::nullopt); } private: @@ -2966,27 +2981,17 @@ void mlir::python::populateIRCore(py::module &m) { "Returns a forward-optimized sequence of operations.") .def_static( "create_at_start", - [](PyRegion &parent, py::list pyArgTypes) { + [](PyRegion &parent, const py::list &pyArgTypes, + const std::optional &pyArgLocs) { parent.checkValid(); - llvm::SmallVector argTypes; - llvm::SmallVector argLocs; - argTypes.reserve(pyArgTypes.size()); - argLocs.reserve(pyArgTypes.size()); - for (auto &pyArg : pyArgTypes) { - argTypes.push_back(pyArg.cast()); - // TODO: Pass in a proper location here. 
- argLocs.push_back( - mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); - } - - MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), - argLocs.data()); + MlirBlock block = createBlock(pyArgTypes, pyArgLocs); mlirRegionInsertOwnedBlock(parent, 0, block); return PyBlock(parent.getParentOperation(), block); }, py::arg("parent"), py::arg("arg_types") = py::list(), + py::arg("arg_locs") = std::nullopt, "Creates and returns a new Block at the beginning of the given " - "region (with given argument types).") + "region (with given argument types and locations).") .def( "append_to", [](PyBlock &self, PyRegion ®ion) { @@ -2998,50 +3003,30 @@ void mlir::python::populateIRCore(py::module &m) { "Append this block to a region, transferring ownership if necessary") .def( "create_before", - [](PyBlock &self, py::args pyArgTypes) { + [](PyBlock &self, const py::args &pyArgTypes, + const std::optional &pyArgLocs) { self.checkValid(); - llvm::SmallVector argTypes; - llvm::SmallVector argLocs; - argTypes.reserve(pyArgTypes.size()); - argLocs.reserve(pyArgTypes.size()); - for (auto &pyArg : pyArgTypes) { - argTypes.push_back(pyArg.cast()); - // TODO: Pass in a proper location here. 
- argLocs.push_back( - mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); - } - - MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), - argLocs.data()); + MlirBlock block = createBlock(pyArgTypes, pyArgLocs); MlirRegion region = mlirBlockGetParentRegion(self.get()); mlirRegionInsertOwnedBlockBefore(region, self.get(), block); return PyBlock(self.getParentOperation(), block); }, + py::arg("arg_locs") = std::nullopt, "Creates and returns a new Block before this block " - "(with given argument types).") + "(with given argument types and locations).") .def( "create_after", - [](PyBlock &self, py::args pyArgTypes) { + [](PyBlock &self, const py::args &pyArgTypes, + const std::optional &pyArgLocs) { self.checkValid(); - llvm::SmallVector argTypes; - llvm::SmallVector argLocs; - argTypes.reserve(pyArgTypes.size()); - argLocs.reserve(pyArgTypes.size()); - for (auto &pyArg : pyArgTypes) { - argTypes.push_back(pyArg.cast()); - - // TODO: Pass in a proper location here. - argLocs.push_back( - mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); - } - MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), - argLocs.data()); + MlirBlock block = createBlock(pyArgTypes, pyArgLocs); MlirRegion region = mlirBlockGetParentRegion(self.get()); mlirRegionInsertOwnedBlockAfter(region, self.get(), block); return PyBlock(self.getParentOperation(), block); }, + py::arg("arg_locs") = std::nullopt, "Creates and returns a new Block after this block " - "(with given argument types).") + "(with given argument types and locations).") .def( "__iter__", [](PyBlock &self) { diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py index 79577463d..56df423d3 100644 --- a/mlir/python/mlir/dialects/_func_ops_ext.py +++ b/mlir/python/mlir/dialects/_func_ops_ext.py @@ -90,7 +90,7 @@ def entry_block(self): raise IndexError('External function does not have a body') return self.regions[0].blocks[0] - def 
add_entry_block(self): + def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None): """ Add an entry block to the function body using the function signature to infer block arguments. @@ -98,7 +98,7 @@ def add_entry_block(self): """ if not self.is_external: raise IndexError('The function already has an entry block!') - self.body.blocks.append(*self.type.inputs) + self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs) return self.body.blocks[0] @property From 6d453d1c8cdae7254de0498b018f3690caece8ad Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Tue, 9 May 2023 18:01:44 -0400 Subject: [PATCH 457/915] Revert "[mlir][python] Allow specifying block arg locations" This reverts commit d19c93c63e7c1b84629c11a361e536f784002776. This caused a buildbot failure: https://lab.llvm.org/buildbot/#/builders/61/builds/43479 --- mlir/lib/Bindings/Python/IRCore.cpp | 105 ++++++++++++--------- mlir/python/mlir/dialects/_func_ops_ext.py | 4 +- 2 files changed, 62 insertions(+), 47 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 2158a4cb5..7ffa46400 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -193,31 +193,6 @@ static MlirStringRef toMlirStringRef(const std::string &s) { return mlirStringRefCreate(s.data(), s.size()); } -/// Create a block, using the current location context if no locations are -/// specified. 
-static MlirBlock createBlock(const py::sequence &pyArgTypes, - const std::optional &pyArgLocs) { - SmallVector argTypes; - argTypes.reserve(pyArgTypes.size()); - for (const auto &pyType : pyArgTypes) - argTypes.push_back(pyType.cast()); - - SmallVector argLocs; - if (pyArgLocs) { - argLocs.reserve(pyArgLocs->size()); - for (const auto &pyLoc : *pyArgLocs) - argLocs.push_back(pyLoc.cast()); - } else if (!argTypes.empty()) { - argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve()); - } - - if (argTypes.size() != argLocs.size()) - throw py::value_error(("Expected " + Twine(argTypes.size()) + - " locations, got: " + Twine(argLocs.size())) - .str()); - return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); -} - /// Wrapper for the global LLVM debugging flag. struct PyGlobalDebugFlag { static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } @@ -389,10 +364,21 @@ class PyBlockList { throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); } - PyBlock appendBlock(const py::args &pyArgTypes, - const std::optional &pyArgLocs) { + PyBlock appendBlock(const py::args &pyArgTypes) { operation->checkValid(); - MlirBlock block = createBlock(pyArgTypes, pyArgLocs); + llvm::SmallVector argTypes; + llvm::SmallVector argLocs; + argTypes.reserve(pyArgTypes.size()); + argLocs.reserve(pyArgTypes.size()); + for (auto &pyArg : pyArgTypes) { + argTypes.push_back(pyArg.cast()); + // TODO: Pass in a proper location here. 
+ argLocs.push_back( + mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); + } + + MlirBlock block = + mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); mlirRegionAppendOwnedBlock(region, block); return PyBlock(operation, block); } @@ -402,8 +388,7 @@ class PyBlockList { .def("__getitem__", &PyBlockList::dunderGetItem) .def("__iter__", &PyBlockList::dunderIter) .def("__len__", &PyBlockList::dunderLen) - .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring, - py::arg("arg_locs") = std::nullopt); + .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); } private: @@ -2981,17 +2966,27 @@ void mlir::python::populateIRCore(py::module &m) { "Returns a forward-optimized sequence of operations.") .def_static( "create_at_start", - [](PyRegion &parent, const py::list &pyArgTypes, - const std::optional &pyArgLocs) { + [](PyRegion &parent, py::list pyArgTypes) { parent.checkValid(); - MlirBlock block = createBlock(pyArgTypes, pyArgLocs); + llvm::SmallVector argTypes; + llvm::SmallVector argLocs; + argTypes.reserve(pyArgTypes.size()); + argLocs.reserve(pyArgTypes.size()); + for (auto &pyArg : pyArgTypes) { + argTypes.push_back(pyArg.cast()); + // TODO: Pass in a proper location here. 
+ argLocs.push_back( + mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); + } + + MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), + argLocs.data()); mlirRegionInsertOwnedBlock(parent, 0, block); return PyBlock(parent.getParentOperation(), block); }, py::arg("parent"), py::arg("arg_types") = py::list(), - py::arg("arg_locs") = std::nullopt, "Creates and returns a new Block at the beginning of the given " - "region (with given argument types and locations).") + "region (with given argument types).") .def( "append_to", [](PyBlock &self, PyRegion ®ion) { @@ -3003,30 +2998,50 @@ void mlir::python::populateIRCore(py::module &m) { "Append this block to a region, transferring ownership if necessary") .def( "create_before", - [](PyBlock &self, const py::args &pyArgTypes, - const std::optional &pyArgLocs) { + [](PyBlock &self, py::args pyArgTypes) { self.checkValid(); - MlirBlock block = createBlock(pyArgTypes, pyArgLocs); + llvm::SmallVector argTypes; + llvm::SmallVector argLocs; + argTypes.reserve(pyArgTypes.size()); + argLocs.reserve(pyArgTypes.size()); + for (auto &pyArg : pyArgTypes) { + argTypes.push_back(pyArg.cast()); + // TODO: Pass in a proper location here. 
+ argLocs.push_back( + mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); + } + + MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), + argLocs.data()); MlirRegion region = mlirBlockGetParentRegion(self.get()); mlirRegionInsertOwnedBlockBefore(region, self.get(), block); return PyBlock(self.getParentOperation(), block); }, - py::arg("arg_locs") = std::nullopt, "Creates and returns a new Block before this block " - "(with given argument types and locations).") + "(with given argument types).") .def( "create_after", - [](PyBlock &self, const py::args &pyArgTypes, - const std::optional &pyArgLocs) { + [](PyBlock &self, py::args pyArgTypes) { self.checkValid(); - MlirBlock block = createBlock(pyArgTypes, pyArgLocs); + llvm::SmallVector argTypes; + llvm::SmallVector argLocs; + argTypes.reserve(pyArgTypes.size()); + argLocs.reserve(pyArgTypes.size()); + for (auto &pyArg : pyArgTypes) { + argTypes.push_back(pyArg.cast()); + + // TODO: Pass in a proper location here. + argLocs.push_back( + mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); + } + MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), + argLocs.data()); MlirRegion region = mlirBlockGetParentRegion(self.get()); mlirRegionInsertOwnedBlockAfter(region, self.get(), block); return PyBlock(self.getParentOperation(), block); }, - py::arg("arg_locs") = std::nullopt, "Creates and returns a new Block after this block " - "(with given argument types and locations).") + "(with given argument types).") .def( "__iter__", [](PyBlock &self) { diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py index 56df423d3..79577463d 100644 --- a/mlir/python/mlir/dialects/_func_ops_ext.py +++ b/mlir/python/mlir/dialects/_func_ops_ext.py @@ -90,7 +90,7 @@ def entry_block(self): raise IndexError('External function does not have a body') return self.regions[0].blocks[0] - def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = 
None): + def add_entry_block(self): """ Add an entry block to the function body using the function signature to infer block arguments. @@ -98,7 +98,7 @@ def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None): """ if not self.is_external: raise IndexError('The function already has an entry block!') - self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs) + self.body.blocks.append(*self.type.inputs) return self.body.blocks[0] @property From 480f26fbef87f9d3f7c24babc1b0a6da5f9d3c8d Mon Sep 17 00:00:00 2001 From: Arash Taheri-Dezfouli Date: Thu, 11 May 2023 14:29:16 -0500 Subject: [PATCH 458/915] [MLIR] Add InferShapedTypeOpInterface bindings Add C and python bindings for InferShapedTypeOpInterface and ShapedTypeComponents. This allows users to invoke InferShapedTypeOpInterface for ops that implement it. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D149494 --- mlir/include/mlir-c/Interfaces.h | 27 ++ mlir/lib/Bindings/Python/IRInterfaces.cpp | 305 ++++++++++++++++++---- mlir/lib/CAPI/Interfaces/Interfaces.cpp | 127 +++++++-- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 24 +- 4 files changed, 403 insertions(+), 80 deletions(-) diff --git a/mlir/include/mlir-c/Interfaces.h b/mlir/include/mlir-c/Interfaces.h index 405e2bb71..a5a3473ea 100644 --- a/mlir/include/mlir-c/Interfaces.h +++ b/mlir/include/mlir-c/Interfaces.h @@ -60,6 +60,33 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( void *properties, intptr_t nRegions, MlirRegion *regions, MlirTypesCallback callback, void *userData); +//===----------------------------------------------------------------------===// +// InferShapedTypeOpInterface. +//===----------------------------------------------------------------------===// + +/// Returns the interface TypeID of the InferShapedTypeOpInterface. 
+MLIR_CAPI_EXPORTED MlirTypeID mlirInferShapedTypeOpInterfaceTypeID(); + +/// These callbacks are used to return multiple shaped type components from +/// functions while transferring ownership to the caller. The first argument is +/// the has rank boolean followed by the the rank and a pointer to the shape +/// (if applicable). The next argument is the element type, then the attribute. +/// The last argument is an opaque pointer forwarded to the callback by the +/// caller. This callback will be called potentially multiple times for each +/// shaped type components. +typedef void (*MlirShapedTypeComponentsCallback)(bool, intptr_t, + const int64_t *, MlirType, + MlirAttribute, void *); + +/// Infers the return shaped type components of the operation. Calls `callback` +/// with the types of inferred arguments on success. Returns failure otherwise. +MLIR_CAPI_EXPORTED MlirLogicalResult +mlirInferShapedTypeOpInterfaceInferReturnTypes( + MlirStringRef opName, MlirContext context, MlirLocation location, + intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, + void *properties, intptr_t nRegions, MlirRegion *regions, + MlirShapedTypeComponentsCallback callback, void *userData); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index 766d6f3e4..0a7a25c00 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#include #include +#include #include "IRModule.h" #include "mlir-c/BuiltinAttributes.h" @@ -35,6 +35,83 @@ constexpr static const char *inferReturnTypesDoc = R"(Given the arguments required to build an operation, attempts to infer its return types. 
Raises ValueError on failure.)"; +constexpr static const char *inferReturnTypeComponentsDoc = + R"(Given the arguments required to build an operation, attempts to infer +its return shaped type components. Raises ValueError on failure.)"; + +namespace { + +/// Takes in an optional ist of operands and converts them into a SmallVector +/// of MlirVlaues. Returns an empty SmallVector if the list is empty. +llvm::SmallVector wrapOperands(std::optional operandList) { + llvm::SmallVector mlirOperands; + + if (!operandList || operandList->empty()) { + return mlirOperands; + } + + // Note: as the list may contain other lists this may not be final size. + mlirOperands.reserve(operandList->size()); + for (const auto &&it : llvm::enumerate(*operandList)) { + PyValue *val; + try { + val = py::cast(it.value()); + if (!val) + throw py::cast_error(); + mlirOperands.push_back(val->get()); + continue; + } catch (py::cast_error &err) { + // Intentionally unhandled to try sequence below first. + (void)err; + } + + try { + auto vals = py::cast(it.value()); + for (py::object v : vals) { + try { + val = py::cast(v); + if (!val) + throw py::cast_error(); + mlirOperands.push_back(val->get()); + } catch (py::cast_error &err) { + throw py::value_error( + (llvm::Twine("Operand ") + llvm::Twine(it.index()) + + " must be a Value or Sequence of Values (" + err.what() + ")") + .str()); + } + } + continue; + } catch (py::cast_error &err) { + throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + + " must be a Value or Sequence of Values (" + + err.what() + ")") + .str()); + } + + throw py::cast_error(); + } + + return mlirOperands; +} + +/// Takes in an optional vector of PyRegions and returns a SmallVector of +/// MlirRegion. Returns an empty SmallVector if the list is empty. 
+llvm::SmallVector +wrapRegions(std::optional> regions) { + llvm::SmallVector mlirRegions; + + if (regions) { + mlirRegions.reserve(regions->size()); + for (PyRegion ®ion : *regions) { + mlirRegions.push_back(region); + } + } + + return mlirRegions; +} + +} // namespace + /// CRTP base class for Python classes representing MLIR Op interfaces. /// Interface hierarchies are flat so no base class is expected here. The /// derived class is expected to define the following static fields: @@ -104,7 +181,7 @@ class PyConcreteOpInterface { /// Creates the Python bindings for this class in the given module. static void bind(py::module &m) { - py::class_ cls(m, "InferTypeOpInterface", + py::class_ cls(m, ConcreteIface::pyClassName, py::module_local()); cls.def(py::init(), py::arg("object"), py::arg("context") = py::none(), constructorDoc) @@ -155,7 +232,7 @@ class PyConcreteOpInterface { py::object obj; }; -/// Python wrapper for InterTypeOpInterface. This interface has only static +/// Python wrapper for InferTypeOpInterface. This interface has only static /// methods. class PyInferTypeOpInterface : public PyConcreteOpInterface { @@ -191,59 +268,8 @@ class PyInferTypeOpInterface std::optional> regions, DefaultingPyMlirContext context, DefaultingPyLocation location) { - llvm::SmallVector mlirOperands; - llvm::SmallVector mlirRegions; - - if (operandList && !operandList->empty()) { - // Note: as the list may contain other lists this may not be final size. - mlirOperands.reserve(operandList->size()); - for (const auto& it : llvm::enumerate(*operandList)) { - PyValue* val; - try { - val = py::cast(it.value()); - if (!val) - throw py::cast_error(); - mlirOperands.push_back(val->get()); - continue; - } catch (py::cast_error &err) { - // Intentionally unhandled to try sequence below first. 
- (void)err; - } - - try { - auto vals = py::cast(it.value()); - for (py::object v : vals) { - try { - val = py::cast(v); - if (!val) - throw py::cast_error(); - mlirOperands.push_back(val->get()); - } catch (py::cast_error &err) { - throw py::value_error( - (llvm::Twine("Operand ") + llvm::Twine(it.index()) + - " must be a Value or Sequence of Values (" + err.what() + - ")") - .str()); - } - } - continue; - } catch (py::cast_error &err) { - throw py::value_error( - (llvm::Twine("Operand ") + llvm::Twine(it.index()) + - " must be a Value or Sequence of Values (" + err.what() + ")") - .str()); - } - - throw py::cast_error(); - } - } - - if (regions) { - mlirRegions.reserve(regions->size()); - for (PyRegion ®ion : *regions) { - mlirRegions.push_back(region); - } - } + llvm::SmallVector mlirOperands = wrapOperands(operandList); + llvm::SmallVector mlirRegions = wrapRegions(regions); std::vector inferredTypes; PyMlirContext &pyContext = context.resolve(); @@ -275,7 +301,172 @@ class PyInferTypeOpInterface } }; -void populateIRInterfaces(py::module &m) { PyInferTypeOpInterface::bind(m); } +/// Wrapper around an shaped type components. 
+class PyShapedTypeComponents { +public: + PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {} + PyShapedTypeComponents(py::list shape, MlirType elementType) + : shape(shape), elementType(elementType), ranked(true) {} + PyShapedTypeComponents(py::list shape, MlirType elementType, + MlirAttribute attribute) + : shape(shape), elementType(elementType), attribute(attribute), + ranked(true) {} + PyShapedTypeComponents(PyShapedTypeComponents &) = delete; + PyShapedTypeComponents(PyShapedTypeComponents &&other) + : shape(other.shape), elementType(other.elementType), + attribute(other.attribute), ranked(other.ranked) {} + + static void bind(py::module &m) { + py::class_(m, "ShapedTypeComponents", + py::module_local()) + .def_property_readonly( + "element_type", + [](PyShapedTypeComponents &self) { + return PyType(PyMlirContext::forContext( + mlirTypeGetContext(self.elementType)), + self.elementType); + }, + "Returns the element type of the shaped type components.") + .def_static( + "get", + [](PyType &elementType) { + return PyShapedTypeComponents(elementType); + }, + py::arg("element_type"), + "Create an shaped type components object with only the element " + "type.") + .def_static( + "get", + [](py::list shape, PyType &elementType) { + return PyShapedTypeComponents(shape, elementType); + }, + py::arg("shape"), py::arg("element_type"), + "Create a ranked shaped type components object.") + .def_static( + "get", + [](py::list shape, PyType &elementType, PyAttribute &attribute) { + return PyShapedTypeComponents(shape, elementType, attribute); + }, + py::arg("shape"), py::arg("element_type"), py::arg("attribute"), + "Create a ranked shaped type components object with attribute.") + .def_property_readonly( + "has_rank", + [](PyShapedTypeComponents &self) -> bool { return self.ranked; }, + "Returns whether the given shaped type component is ranked.") + .def_property_readonly( + "rank", + [](PyShapedTypeComponents &self) -> py::object { + if (!self.ranked) 
{ + return py::none(); + } + return py::int_(self.shape.size()); + }, + "Returns the rank of the given ranked shaped type components. If " + "the shaped type components does not have a rank, None is " + "returned.") + .def_property_readonly( + "shape", + [](PyShapedTypeComponents &self) -> py::object { + if (!self.ranked) { + return py::none(); + } + return py::list(self.shape); + }, + "Returns the shape of the ranked shaped type components as a list " + "of integers. Returns none if the shaped type component does not " + "have a rank."); + } + + pybind11::object getCapsule(); + static PyShapedTypeComponents createFromCapsule(pybind11::object capsule); + +private: + py::list shape; + MlirType elementType; + MlirAttribute attribute; + bool ranked{false}; +}; + +/// Python wrapper for InferShapedTypeOpInterface. This interface has only +/// static methods. +class PyInferShapedTypeOpInterface + : public PyConcreteOpInterface { +public: + using PyConcreteOpInterface< + PyInferShapedTypeOpInterface>::PyConcreteOpInterface; + + constexpr static const char *pyClassName = "InferShapedTypeOpInterface"; + constexpr static GetTypeIDFunctionTy getInterfaceID = + &mlirInferShapedTypeOpInterfaceTypeID; + + /// C-style user-data structure for type appending callback. + struct AppendResultsCallbackData { + std::vector &inferredShapedTypeComponents; + }; + + /// Appends the shaped type components provided as unpacked shape, element + /// type, attribute to the user-data. 
+ static void appendResultsCallback(bool hasRank, intptr_t rank, + const int64_t *shape, MlirType elementType, + MlirAttribute attribute, void *userData) { + auto *data = static_cast(userData); + if (!hasRank) { + data->inferredShapedTypeComponents.emplace_back(elementType); + } else { + py::list shapeList; + for (intptr_t i = 0; i < rank; ++i) { + shapeList.append(shape[i]); + } + data->inferredShapedTypeComponents.emplace_back(shapeList, elementType, + attribute); + } + } + + /// Given the arguments required to build an operation, attempts to infer the + /// shaped type components. Throws value_error on failure. + std::vector inferReturnTypeComponents( + std::optional operandList, + std::optional attributes, void *properties, + std::optional> regions, + DefaultingPyMlirContext context, DefaultingPyLocation location) { + llvm::SmallVector mlirOperands = wrapOperands(operandList); + llvm::SmallVector mlirRegions = wrapRegions(regions); + + std::vector inferredShapedTypeComponents; + PyMlirContext &pyContext = context.resolve(); + AppendResultsCallbackData data{inferredShapedTypeComponents}; + MlirStringRef opNameRef = + mlirStringRefCreate(getOpName().data(), getOpName().length()); + MlirAttribute attributeDict = + attributes ? 
attributes->get() : mlirAttributeGetNull(); + + MlirLogicalResult result = mlirInferShapedTypeOpInterfaceInferReturnTypes( + opNameRef, pyContext.get(), location.resolve(), mlirOperands.size(), + mlirOperands.data(), attributeDict, properties, mlirRegions.size(), + mlirRegions.data(), &appendResultsCallback, &data); + + if (mlirLogicalResultIsFailure(result)) { + throw py::value_error("Failed to infer result shape type components"); + } + + return inferredShapedTypeComponents; + } + + static void bindDerived(ClassTy &cls) { + cls.def("inferReturnTypeComponents", + &PyInferShapedTypeOpInterface::inferReturnTypeComponents, + py::arg("operands") = py::none(), + py::arg("attributes") = py::none(), py::arg("regions") = py::none(), + py::arg("properties") = py::none(), py::arg("context") = py::none(), + py::arg("loc") = py::none(), inferReturnTypeComponentsDoc); + } +}; + +void populateIRInterfaces(py::module &m) { + PyInferTypeOpInterface::bind(m); + PyShapedTypeComponents::bind(m); + PyInferShapedTypeOpInterface::bind(m); +} } // namespace python } // namespace mlir diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp index 029feed3a..e597a7bcb 100644 --- a/mlir/lib/CAPI/Interfaces/Interfaces.cpp +++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp @@ -11,14 +11,65 @@ #include "mlir-c/Interfaces.h" #include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Interfaces.h" #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Wrap.h" +#include "mlir/IR/ValueRange.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "llvm/ADT/ScopeExit.h" #include using namespace mlir; +namespace { + +std::optional +getRegisteredOperationName(MlirContext context, MlirStringRef opName) { + StringRef name(opName.data, opName.length); + std::optional info = + RegisteredOperationName::lookup(name, unwrap(context)); + return info; +} + +std::optional maybeGetLocation(MlirLocation location) { + std::optional maybeLocation; + if (!mlirLocationIsNull(location)) + 
maybeLocation = unwrap(location); + return maybeLocation; +} + +SmallVector unwrapOperands(intptr_t nOperands, MlirValue *operands) { + SmallVector unwrappedOperands; + (void)unwrapList(nOperands, operands, unwrappedOperands); + return unwrappedOperands; +} + +DictionaryAttr unwrapAttributes(MlirAttribute attributes) { + DictionaryAttr attributeDict; + if (!mlirAttributeIsNull(attributes)) + attributeDict = unwrap(attributes).cast(); + return attributeDict; +} + +SmallVector> unwrapRegions(intptr_t nRegions, + MlirRegion *regions) { + // Create a vector of unique pointers to regions and make sure they are not + // deleted when exiting the scope. This is a hack caused by C++ API expecting + // an list of unique pointers to regions (without ownership transfer + // semantics) and C API making ownership transfer explicit. + SmallVector> unwrappedRegions; + unwrappedRegions.reserve(nRegions); + for (intptr_t i = 0; i < nRegions; ++i) + unwrappedRegions.emplace_back(unwrap(*(regions + i))); + auto cleaner = llvm::make_scope_exit([&]() { + for (auto ®ion : unwrappedRegions) + region.release(); + }); + return unwrappedRegions; +} + +} // namespace + bool mlirOperationImplementsInterface(MlirOperation operation, MlirTypeID interfaceTypeID) { std::optional info = @@ -45,31 +96,15 @@ MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( MlirTypesCallback callback, void *userData) { StringRef name(opName.data, opName.length); std::optional info = - RegisteredOperationName::lookup(name, unwrap(context)); + getRegisteredOperationName(context, opName); if (!info) return mlirLogicalResultFailure(); - std::optional maybeLocation; - if (!mlirLocationIsNull(location)) - maybeLocation = unwrap(location); - SmallVector unwrappedOperands; - (void)unwrapList(nOperands, operands, unwrappedOperands); - DictionaryAttr attributeDict; - if (!mlirAttributeIsNull(attributes)) - attributeDict = unwrap(attributes).cast(); - - // Create a vector of unique pointers to regions and make sure 
they are not - // deleted when exiting the scope. This is a hack caused by C++ API expecting - // an list of unique pointers to regions (without ownership transfer - // semantics) and C API making ownership transfer explicit. - SmallVector> unwrappedRegions; - unwrappedRegions.reserve(nRegions); - for (intptr_t i = 0; i < nRegions; ++i) - unwrappedRegions.emplace_back(unwrap(*(regions + i))); - auto cleaner = llvm::make_scope_exit([&]() { - for (auto ®ion : unwrappedRegions) - region.release(); - }); + std::optional maybeLocation = maybeGetLocation(location); + SmallVector unwrappedOperands = unwrapOperands(nOperands, operands); + DictionaryAttr attributeDict = unwrapAttributes(attributes); + SmallVector> unwrappedRegions = + unwrapRegions(nRegions, regions); SmallVector inferredTypes; if (failed(info->getInterface()->inferReturnTypes( @@ -84,3 +119,51 @@ MlirLogicalResult mlirInferTypeOpInterfaceInferReturnTypes( callback(wrappedInferredTypes.size(), wrappedInferredTypes.data(), userData); return mlirLogicalResultSuccess(); } + +MlirTypeID mlirInferShapedTypeOpInterfaceTypeID() { + return wrap(InferShapedTypeOpInterface::getInterfaceID()); +} + +MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes( + MlirStringRef opName, MlirContext context, MlirLocation location, + intptr_t nOperands, MlirValue *operands, MlirAttribute attributes, + void *properties, intptr_t nRegions, MlirRegion *regions, + MlirShapedTypeComponentsCallback callback, void *userData) { + std::optional info = + getRegisteredOperationName(context, opName); + if (!info) + return mlirLogicalResultFailure(); + + std::optional maybeLocation = maybeGetLocation(location); + SmallVector unwrappedOperands = unwrapOperands(nOperands, operands); + DictionaryAttr attributeDict = unwrapAttributes(attributes); + SmallVector> unwrappedRegions = + unwrapRegions(nRegions, regions); + + SmallVector inferredTypeComponents; + if (failed(info->getInterface() + ->inferReturnTypeComponents( + 
unwrap(context), maybeLocation, + mlir::ValueRange(llvm::ArrayRef(unwrappedOperands)), + attributeDict, properties, unwrappedRegions, + inferredTypeComponents))) + return mlirLogicalResultFailure(); + + bool hasRank; + intptr_t rank; + const int64_t *shapeData; + for (ShapedTypeComponents t : inferredTypeComponents) { + if (t.hasRank()) { + hasRank = true; + rank = t.getDims().size(); + shapeData = t.getDims().data(); + } else { + hasRank = false; + rank = 0; + shapeData = nullptr; + } + callback(hasRank, rank, shapeData, wrap(t.getElementType()), + wrap(t.getAttribute()), userData); + } + return mlirLogicalResultSuccess(); +} diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 75b25bd8c..714935fe1 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -62,6 +62,7 @@ __all__ = [ "FloatAttr", "FunctionType", "IndexType", + "InferShapedTypeOpInterface", "InferTypeOpInterface", "InsertionPoint", "IntegerAttr", @@ -88,6 +89,7 @@ __all__ = [ "RegionIterator", "RegionSequence", "ShapedType", + "ShapedTypeComponents", "StringAttr", "SymbolTable", "TupleType", @@ -689,9 +691,17 @@ class IndexType(Type): @staticmethod def isinstance(arg: Any) -> bool: ... +class InferShapedTypeOpInterface: + def __init__(self, object: object, context: Optional[Context] = None) -> None: ... + def inferReturnTypeComponents(self, operands: Optional[List] = None, attributes: Optional[Attribute] = None, properties = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[ShapedTypeComponents]: ... + @property + def operation(self) -> Operation: ... + @property + def opview(self) -> OpView: ... + class InferTypeOpInterface: def __init__(self, object: object, context: Optional[Context] = None) -> None: ... 
- def inferReturnTypes(self, operands: Optional[List] = None, attributes: Optional[Attribute] = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ... + def inferReturnTypes(self, operands: Optional[List] = None, attributes: Optional[Attribute] = None, properties = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ... @property def operation(self) -> Operation: ... @property @@ -1016,6 +1026,18 @@ class ShapedType(Type): @property def shape(self) -> List[int]: ... +class ShapedTypeComponents: + @property + def element_type(self) -> Type: ... + @staticmethod + def get(*args, **kwargs) -> ShapedTypeComponents: ... + @property + def has_rank(self) -> bool: ... + @property + def rank(self) -> int: ... + @property + def shape(self) -> List[int]: ... + # TODO: Auto-generated. Audit and fix. class StringAttr(Attribute): def __init__(self, cast_from_attr: Attribute) -> None: ... From 68263871a147d71a588ee6da139bfeb1b62bef4c Mon Sep 17 00:00:00 2001 From: Tres Popp Date: Mon, 8 May 2023 16:33:54 +0200 Subject: [PATCH 459/915] [mlir] Move casting calls from methods to function calls MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The MLIR classes Type/Attribute/Operation/Op/Value support cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast functionality in addition to defining methods with the same name. This change begins the migration of uses of the method to the corresponding function call as has been decided as more consistent. Note that there still exist classes that only define methods directly, such as AffineExpr, and this does not include work currently to support a functional cast/isa call. Caveats include: - This clang-tidy script probably has more problems. - This only touches C++ code, so nothing that is being generated. 
Context: - https://mlir.llvm.org/deprecation/ at "Use the free function variants for dyn_cast/cast/isa/…" - Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443 Implementation: This first patch was created with the following steps. The intention is to only do automated changes at first, so I waste less time if it's reverted, and so the first mass change is more clear as an example to other teams that will need to follow similar steps. Steps are described per line, as comments are removed by git: 0. Retrieve the change from the following to build clang-tidy with an additional check: https://github.com/llvm/llvm-project/compare/main...tpopp:llvm-project:tidy-cast-check 1. Build clang-tidy 2. Run clang-tidy over your entire codebase while disabling all checks and enabling the one relevant one. Run on all header files also. 3. Delete .inc files that were also modified, so the next build rebuilds them to a pure state. 4. Some changes have been deleted for the following reasons: - Some files had a variable also named cast - Some files had not included a header file that defines the cast functions - Some files are definitions of the classes that have the casting methods, so the code still refers to the method instead of the function without adding a prefix or removing the method declaration at the same time. 
``` ninja -C $BUILD_DIR clang-tidy run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\ -header-filter=mlir/ mlir/* -fix rm -rf $BUILD_DIR/tools/mlir/**/*.inc git restore mlir/lib/IR mlir/lib/Dialect/DLTI/DLTI.cpp\ mlir/lib/Dialect/Complex/IR/ComplexDialect.cpp\ mlir/lib/**/IR/\ mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp\ mlir/lib/Dialect/Vector/Transforms/LowerVectorMultiReduction.cpp\ mlir/test/lib/Dialect/Test/TestTypes.cpp\ mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp\ mlir/test/lib/Dialect/Test/TestAttributes.cpp\ mlir/unittests/TableGen/EnumsGenTest.cpp\ mlir/test/python/lib/PythonTestCAPI.cpp\ mlir/include/mlir/IR/ ``` Differential Revision: https://reviews.llvm.org/D150123 --- mlir/lib/CAPI/Dialect/PDL.cpp | 14 +++--- mlir/lib/CAPI/Dialect/Quant.cpp | 69 ++++++++++++-------------- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 15 +++--- mlir/lib/CAPI/Dialect/Transform.cpp | 6 +-- 4 files changed, 48 insertions(+), 56 deletions(-) diff --git a/mlir/lib/CAPI/Dialect/PDL.cpp b/mlir/lib/CAPI/Dialect/PDL.cpp index 497b2cb1f..bd8b13c65 100644 --- a/mlir/lib/CAPI/Dialect/PDL.cpp +++ b/mlir/lib/CAPI/Dialect/PDL.cpp @@ -21,7 +21,7 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(PDL, pdl, pdl::PDLDialect) //===---------------------------------------------------------------------===// bool mlirTypeIsAPDLType(MlirType type) { - return unwrap(type).isa(); + return isa(unwrap(type)); } //===---------------------------------------------------------------------===// @@ -29,7 +29,7 @@ bool mlirTypeIsAPDLType(MlirType type) { //===---------------------------------------------------------------------===// bool mlirTypeIsAPDLAttributeType(MlirType type) { - return unwrap(type).isa(); + return isa(unwrap(type)); } MlirType mlirPDLAttributeTypeGet(MlirContext ctx) { @@ -41,7 +41,7 @@ MlirType mlirPDLAttributeTypeGet(MlirContext ctx) { 
//===---------------------------------------------------------------------===// bool mlirTypeIsAPDLOperationType(MlirType type) { - return unwrap(type).isa(); + return isa(unwrap(type)); } MlirType mlirPDLOperationTypeGet(MlirContext ctx) { @@ -53,7 +53,7 @@ MlirType mlirPDLOperationTypeGet(MlirContext ctx) { //===---------------------------------------------------------------------===// bool mlirTypeIsAPDLRangeType(MlirType type) { - return unwrap(type).isa(); + return isa(unwrap(type)); } MlirType mlirPDLRangeTypeGet(MlirType elementType) { @@ -61,7 +61,7 @@ MlirType mlirPDLRangeTypeGet(MlirType elementType) { } MlirType mlirPDLRangeTypeGetElementType(MlirType type) { - return wrap(unwrap(type).cast().getElementType()); + return wrap(cast(unwrap(type)).getElementType()); } //===---------------------------------------------------------------------===// @@ -69,7 +69,7 @@ MlirType mlirPDLRangeTypeGetElementType(MlirType type) { //===---------------------------------------------------------------------===// bool mlirTypeIsAPDLTypeType(MlirType type) { - return unwrap(type).isa(); + return isa(unwrap(type)); } MlirType mlirPDLTypeTypeGet(MlirContext ctx) { @@ -81,7 +81,7 @@ MlirType mlirPDLTypeTypeGet(MlirContext ctx) { //===---------------------------------------------------------------------===// bool mlirTypeIsAPDLValueType(MlirType type) { - return unwrap(type).isa(); + return isa(unwrap(type)); } MlirType mlirPDLValueTypeGet(MlirContext ctx) { diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp index 065ab3e36..0a7181d8b 100644 --- a/mlir/lib/CAPI/Dialect/Quant.cpp +++ b/mlir/lib/CAPI/Dialect/Quant.cpp @@ -20,7 +20,7 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantizationDialect) //===---------------------------------------------------------------------===// bool mlirTypeIsAQuantizedType(MlirType type) { - return unwrap(type).isa(); + return isa(unwrap(type)); } unsigned mlirQuantizedTypeGetSignedFlag() { @@ -40,39 
+40,37 @@ int64_t mlirQuantizedTypeGetDefaultMaximumForInteger(bool isSigned, } MlirType mlirQuantizedTypeGetExpressedType(MlirType type) { - return wrap(unwrap(type).cast().getExpressedType()); + return wrap(cast(unwrap(type)).getExpressedType()); } unsigned mlirQuantizedTypeGetFlags(MlirType type) { - return unwrap(type).cast().getFlags(); + return cast(unwrap(type)).getFlags(); } bool mlirQuantizedTypeIsSigned(MlirType type) { - return unwrap(type).cast().isSigned(); + return cast(unwrap(type)).isSigned(); } MlirType mlirQuantizedTypeGetStorageType(MlirType type) { - return wrap(unwrap(type).cast().getStorageType()); + return wrap(cast(unwrap(type)).getStorageType()); } int64_t mlirQuantizedTypeGetStorageTypeMin(MlirType type) { - return unwrap(type).cast().getStorageTypeMin(); + return cast(unwrap(type)).getStorageTypeMin(); } int64_t mlirQuantizedTypeGetStorageTypeMax(MlirType type) { - return unwrap(type).cast().getStorageTypeMax(); + return cast(unwrap(type)).getStorageTypeMax(); } unsigned mlirQuantizedTypeGetStorageTypeIntegralWidth(MlirType type) { - return unwrap(type) - .cast() - .getStorageTypeIntegralWidth(); + return cast(unwrap(type)).getStorageTypeIntegralWidth(); } bool mlirQuantizedTypeIsCompatibleExpressedType(MlirType type, MlirType candidate) { - return unwrap(type).cast().isCompatibleExpressedType( - unwrap(candidate)); + return cast(unwrap(type)) + .isCompatibleExpressedType(unwrap(candidate)); } MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) { @@ -81,19 +79,19 @@ MlirType mlirQuantizedTypeGetQuantizedElementType(MlirType type) { MlirType mlirQuantizedTypeCastFromStorageType(MlirType type, MlirType candidate) { - return wrap(unwrap(type).cast().castFromStorageType( - unwrap(candidate))); + return wrap(cast(unwrap(type)) + .castFromStorageType(unwrap(candidate))); } MlirType mlirQuantizedTypeCastToStorageType(MlirType type) { return wrap(quant::QuantizedType::castToStorageType( - unwrap(type).cast())); + 
cast(unwrap(type)))); } MlirType mlirQuantizedTypeCastFromExpressedType(MlirType type, MlirType candidate) { - return wrap(unwrap(type).cast().castFromExpressedType( - unwrap(candidate))); + return wrap(cast(unwrap(type)) + .castFromExpressedType(unwrap(candidate))); } MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) { @@ -102,9 +100,8 @@ MlirType mlirQuantizedTypeCastToExpressedType(MlirType type) { MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type, MlirType candidate) { - return wrap( - unwrap(type).cast().castExpressedToStorageType( - unwrap(candidate))); + return wrap(cast(unwrap(type)) + .castExpressedToStorageType(unwrap(candidate))); } //===---------------------------------------------------------------------===// @@ -112,7 +109,7 @@ MlirType mlirQuantizedTypeCastExpressedToStorageType(MlirType type, //===---------------------------------------------------------------------===// bool mlirTypeIsAAnyQuantizedType(MlirType type) { - return unwrap(type).isa(); + return isa(unwrap(type)); } MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType, @@ -128,7 +125,7 @@ MlirType mlirAnyQuantizedTypeGet(unsigned flags, MlirType storageType, //===---------------------------------------------------------------------===// bool mlirTypeIsAUniformQuantizedType(MlirType type) { - return unwrap(type).isa(); + return isa(unwrap(type)); } MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType, @@ -141,15 +138,15 @@ MlirType mlirUniformQuantizedTypeGet(unsigned flags, MlirType storageType, } double mlirUniformQuantizedTypeGetScale(MlirType type) { - return unwrap(type).cast().getScale(); + return cast(unwrap(type)).getScale(); } int64_t mlirUniformQuantizedTypeGetZeroPoint(MlirType type) { - return unwrap(type).cast().getZeroPoint(); + return cast(unwrap(type)).getZeroPoint(); } bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) { - return unwrap(type).cast().isFixedPoint(); + return 
cast(unwrap(type)).isFixedPoint(); } //===---------------------------------------------------------------------===// @@ -157,7 +154,7 @@ bool mlirUniformQuantizedTypeIsFixedPoint(MlirType type) { //===---------------------------------------------------------------------===// bool mlirTypeIsAUniformQuantizedPerAxisType(MlirType type) { - return unwrap(type).isa(); + return isa(unwrap(type)); } MlirType mlirUniformQuantizedPerAxisTypeGet( @@ -172,33 +169,29 @@ MlirType mlirUniformQuantizedPerAxisTypeGet( } intptr_t mlirUniformQuantizedPerAxisTypeGetNumDims(MlirType type) { - return unwrap(type) - .cast() + return cast(unwrap(type)) .getScales() .size(); } double mlirUniformQuantizedPerAxisTypeGetScale(MlirType type, intptr_t pos) { - return unwrap(type) - .cast() + return cast(unwrap(type)) .getScales()[pos]; } int64_t mlirUniformQuantizedPerAxisTypeGetZeroPoint(MlirType type, intptr_t pos) { - return unwrap(type) - .cast() + return cast(unwrap(type)) .getZeroPoints()[pos]; } int32_t mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type) { - return unwrap(type) - .cast() + return cast(unwrap(type)) .getQuantizedDimension(); } bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) { - return unwrap(type).cast().isFixedPoint(); + return cast(unwrap(type)).isFixedPoint(); } //===---------------------------------------------------------------------===// @@ -206,7 +199,7 @@ bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) { //===---------------------------------------------------------------------===// bool mlirTypeIsACalibratedQuantizedType(MlirType type) { - return unwrap(type).isa(); + return isa(unwrap(type)); } MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min, @@ -216,9 +209,9 @@ MlirType mlirCalibratedQuantizedTypeGet(MlirType expressedType, double min, } double mlirCalibratedQuantizedTypeGetMin(MlirType type) { - return unwrap(type).cast().getMin(); + return cast(unwrap(type)).getMin(); } double 
mlirCalibratedQuantizedTypeGetMax(MlirType type) { - return unwrap(type).cast().getMax(); + return cast(unwrap(type)).getMax(); } diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index 1aa6d329d..795ce51ff 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -42,7 +42,7 @@ static_assert( "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch"); bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { - return unwrap(attr).isa(); + return isa(unwrap(attr)); } MlirAttribute mlirSparseTensorEncodingAttrGet( @@ -60,29 +60,28 @@ MlirAttribute mlirSparseTensorEncodingAttrGet( } MlirAffineMap mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getDimOrdering()); + return wrap(cast(unwrap(attr)).getDimOrdering()); } MlirAffineMap mlirSparseTensorEncodingAttrGetHigherOrdering(MlirAttribute attr) { - return wrap( - unwrap(attr).cast().getHigherOrdering()); + return wrap(cast(unwrap(attr)).getHigherOrdering()); } intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) { - return unwrap(attr).cast().getLvlRank(); + return cast(unwrap(attr)).getLvlRank(); } MlirSparseTensorDimLevelType mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t lvl) { return static_cast( - unwrap(attr).cast().getLvlType(lvl)); + cast(unwrap(attr)).getLvlType(lvl)); } int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) { - return unwrap(attr).cast().getPosWidth(); + return cast(unwrap(attr)).getPosWidth(); } int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) { - return unwrap(attr).cast().getCrdWidth(); + return cast(unwrap(attr)).getCrdWidth(); } diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp index 606b301cc..90594b67a 100644 --- a/mlir/lib/CAPI/Dialect/Transform.cpp +++ b/mlir/lib/CAPI/Dialect/Transform.cpp @@ -22,7 +22,7 @@ 
MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Transform, transform, //===---------------------------------------------------------------------===// bool mlirTypeIsATransformAnyOpType(MlirType type) { - return unwrap(type).isa(); + return isa(unwrap(type)); } MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) { @@ -34,7 +34,7 @@ MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) { //===---------------------------------------------------------------------===// bool mlirTypeIsATransformOperationType(MlirType type) { - return unwrap(type).isa(); + return isa(unwrap(type)); } MlirType mlirTransformOperationTypeGet(MlirContext ctx, @@ -44,5 +44,5 @@ MlirType mlirTransformOperationTypeGet(MlirContext ctx, } MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) { - return wrap(unwrap(type).cast().getOperationName()); + return wrap(cast(unwrap(type)).getOperationName()); } From 9ccc4107d0f8996460f796acb9b10ee52d4be91a Mon Sep 17 00:00:00 2001 From: Tres Popp Date: Thu, 11 May 2023 11:10:46 +0200 Subject: [PATCH 460/915] [mlir] Update method cast calls to function calls MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The MLIR classes Type/Attribute/Operation/Op/Value support cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast functionality in addition to defining methods with the same name. This change begins the migration of uses of the method to the corresponding function call as has been decided as more consistent. Note that there still exist classes that only define methods directly, such as AffineExpr, and this does not include work currently to support a functional cast/isa call. Context: * https://mlir.llvm.org/deprecation/ at "Use the free function variants for dyn_cast/cast/isa/…" * Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443 Implementation: This follows a previous patch that updated calls `op.cast()-> cast(op)`. 
However some cases could not handle an unprefixed `cast` call due to occurrences of variables named cast, or occurring inside of class definitions which would resolve to the method. All C++ files that did not work automatically with `cast()` are updated here to `llvm::cast` and similar with the intention that they can be easily updated after the methods are removed through a find-replace. See https://github.com/llvm/llvm-project/compare/main...tpopp:llvm-project:tidy-cast-check for the clang-tidy check that is used and then update printed occurrences of the function to include `llvm::` before. One can then run the following: ``` ninja -C $BUILD_DIR clang-tidy run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\ -export-fixes /tmp/cast/casts.yaml mlir/*\ -header-filter=mlir/ -fix rm -rf $BUILD_DIR/tools/mlir/**/*.inc ``` Differential Revision: https://reviews.llvm.org/D150348 --- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 253 +++++++++++++------------ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 105 +++++----- mlir/lib/CAPI/IR/IR.cpp | 18 +- 3 files changed, 199 insertions(+), 177 deletions(-) diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 66d291edd..f2441e0b0 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -25,7 +25,7 @@ MlirAttribute mlirAttributeGetNull() { return {nullptr}; } //===----------------------------------------------------------------------===// bool mlirAttributeIsALocation(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } //===----------------------------------------------------------------------===// @@ -33,7 +33,7 @@ bool mlirAttributeIsALocation(MlirAttribute attr) { //===----------------------------------------------------------------------===// bool mlirAttributeIsAAffineMap(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute 
mlirAffineMapAttrGet(MlirAffineMap map) { @@ -41,7 +41,7 @@ MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map) { } MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getValue()); + return wrap(llvm::cast(unwrap(attr)).getValue()); } //===----------------------------------------------------------------------===// @@ -49,7 +49,7 @@ MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) { //===----------------------------------------------------------------------===// bool mlirAttributeIsAArray(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, @@ -61,11 +61,11 @@ MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, } intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) { - return static_cast(unwrap(attr).cast().size()); + return static_cast(llvm::cast(unwrap(attr)).size()); } MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) { - return wrap(unwrap(attr).cast().getValue()[pos]); + return wrap(llvm::cast(unwrap(attr)).getValue()[pos]); } //===----------------------------------------------------------------------===// @@ -73,7 +73,7 @@ MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) { //===----------------------------------------------------------------------===// bool mlirAttributeIsADictionary(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, @@ -87,19 +87,19 @@ MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, } intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) { - return static_cast(unwrap(attr).cast().size()); + return static_cast(llvm::cast(unwrap(attr)).size()); } MlirNamedAttribute mlirDictionaryAttrGetElement(MlirAttribute attr, intptr_t pos) { NamedAttribute attribute = - 
unwrap(attr).cast().getValue()[pos]; + llvm::cast(unwrap(attr)).getValue()[pos]; return {wrap(attribute.getName()), wrap(attribute.getValue())}; } MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, MlirStringRef name) { - return wrap(unwrap(attr).cast().get(unwrap(name))); + return wrap(llvm::cast(unwrap(attr)).get(unwrap(name))); } //===----------------------------------------------------------------------===// @@ -107,7 +107,7 @@ MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, //===----------------------------------------------------------------------===// bool mlirAttributeIsAFloat(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirFloatAttrDoubleGet(MlirContext ctx, MlirType type, @@ -121,7 +121,7 @@ MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, MlirType type, } double mlirFloatAttrGetValueDouble(MlirAttribute attr) { - return unwrap(attr).cast().getValueAsDouble(); + return llvm::cast(unwrap(attr)).getValueAsDouble(); } //===----------------------------------------------------------------------===// @@ -129,7 +129,7 @@ double mlirFloatAttrGetValueDouble(MlirAttribute attr) { //===----------------------------------------------------------------------===// bool mlirAttributeIsAInteger(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) { @@ -137,15 +137,15 @@ MlirAttribute mlirIntegerAttrGet(MlirType type, int64_t value) { } int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) { - return unwrap(attr).cast().getInt(); + return llvm::cast(unwrap(attr)).getInt(); } int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr) { - return unwrap(attr).cast().getSInt(); + return llvm::cast(unwrap(attr)).getSInt(); } uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) { - return unwrap(attr).cast().getUInt(); + return llvm::cast(unwrap(attr)).getUInt(); } 
//===----------------------------------------------------------------------===// @@ -153,7 +153,7 @@ uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) { //===----------------------------------------------------------------------===// bool mlirAttributeIsABool(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) { @@ -161,7 +161,7 @@ MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) { } bool mlirBoolAttrGetValue(MlirAttribute attr) { - return unwrap(attr).cast().getValue(); + return llvm::cast(unwrap(attr)).getValue(); } //===----------------------------------------------------------------------===// @@ -169,7 +169,7 @@ bool mlirBoolAttrGetValue(MlirAttribute attr) { //===----------------------------------------------------------------------===// bool mlirAttributeIsAIntegerSet(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } //===----------------------------------------------------------------------===// @@ -177,7 +177,7 @@ bool mlirAttributeIsAIntegerSet(MlirAttribute attr) { //===----------------------------------------------------------------------===// bool mlirAttributeIsAOpaque(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace, @@ -189,11 +189,12 @@ MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace, } MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getDialectNamespace().strref()); + return wrap( + llvm::cast(unwrap(attr)).getDialectNamespace().strref()); } MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getAttrData()); + return wrap(llvm::cast(unwrap(attr)).getAttrData()); } //===----------------------------------------------------------------------===// @@ -201,7 +202,7 @@ 
MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) { //===----------------------------------------------------------------------===// bool mlirAttributeIsAString(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) { @@ -213,7 +214,7 @@ MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) { } MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getValue()); + return wrap(llvm::cast(unwrap(attr)).getValue()); } //===----------------------------------------------------------------------===// @@ -221,7 +222,7 @@ MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) { //===----------------------------------------------------------------------===// bool mlirAttributeIsASymbolRef(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol, @@ -230,27 +231,30 @@ MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol, SmallVector refs; refs.reserve(numReferences); for (intptr_t i = 0; i < numReferences; ++i) - refs.push_back(unwrap(references[i]).cast()); + refs.push_back(llvm::cast(unwrap(references[i]))); auto symbolAttr = StringAttr::get(unwrap(ctx), unwrap(symbol)); return wrap(SymbolRefAttr::get(symbolAttr, refs)); } MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getRootReference().getValue()); + return wrap( + llvm::cast(unwrap(attr)).getRootReference().getValue()); } MlirStringRef mlirSymbolRefAttrGetLeafReference(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getLeafReference().getValue()); + return wrap( + llvm::cast(unwrap(attr)).getLeafReference().getValue()); } intptr_t mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr) { return static_cast( - unwrap(attr).cast().getNestedReferences().size()); + 
llvm::cast(unwrap(attr)).getNestedReferences().size()); } MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, intptr_t pos) { - return wrap(unwrap(attr).cast().getNestedReferences()[pos]); + return wrap( + llvm::cast(unwrap(attr)).getNestedReferences()[pos]); } //===----------------------------------------------------------------------===// @@ -258,7 +262,7 @@ MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, //===----------------------------------------------------------------------===// bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) { @@ -266,7 +270,7 @@ MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) { } MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getValue()); + return wrap(llvm::cast(unwrap(attr)).getValue()); } //===----------------------------------------------------------------------===// @@ -274,7 +278,7 @@ MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) { //===----------------------------------------------------------------------===// bool mlirAttributeIsAType(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirTypeAttrGet(MlirType type) { @@ -282,7 +286,7 @@ MlirAttribute mlirTypeAttrGet(MlirType type) { } MlirType mlirTypeAttrGetValue(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getValue()); + return wrap(llvm::cast(unwrap(attr)).getValue()); } //===----------------------------------------------------------------------===// @@ -290,7 +294,7 @@ MlirType mlirTypeAttrGetValue(MlirAttribute attr) { //===----------------------------------------------------------------------===// bool mlirAttributeIsAUnit(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute 
mlirUnitAttrGet(MlirContext ctx) { @@ -302,24 +306,23 @@ MlirAttribute mlirUnitAttrGet(MlirContext ctx) { //===----------------------------------------------------------------------===// bool mlirAttributeIsAElements(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirElementsAttrGetValue(MlirAttribute attr, intptr_t rank, uint64_t *idxs) { - return wrap(unwrap(attr) - .cast() + return wrap(llvm::cast(unwrap(attr)) .getValues()[llvm::ArrayRef(idxs, rank)]); } bool mlirElementsAttrIsValidIndex(MlirAttribute attr, intptr_t rank, uint64_t *idxs) { - return unwrap(attr).cast().isValidIndex( - llvm::ArrayRef(idxs, rank)); + return llvm::cast(unwrap(attr)) + .isValidIndex(llvm::ArrayRef(idxs, rank)); } int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) { - return unwrap(attr).cast().getNumElements(); + return llvm::cast(unwrap(attr)).getNumElements(); } //===----------------------------------------------------------------------===// @@ -330,25 +333,25 @@ int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) { // IsA support. 
bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseI8Array(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseI16Array(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseI32Array(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseI64Array(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseF32Array(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseF64Array(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } //===----------------------------------------------------------------------===// @@ -394,32 +397,32 @@ MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size, // Accessors. intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) { - return unwrap(attr).cast().size(); + return llvm::cast(unwrap(attr)).size(); } //===----------------------------------------------------------------------===// // Indexed accessors. 
bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast()[pos]; + return llvm::cast(unwrap(attr))[pos]; } int8_t mlirDenseI8ArrayGetElement(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast()[pos]; + return llvm::cast(unwrap(attr))[pos]; } int16_t mlirDenseI16ArrayGetElement(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast()[pos]; + return llvm::cast(unwrap(attr))[pos]; } int32_t mlirDenseI32ArrayGetElement(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast()[pos]; + return llvm::cast(unwrap(attr))[pos]; } int64_t mlirDenseI64ArrayGetElement(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast()[pos]; + return llvm::cast(unwrap(attr))[pos]; } float mlirDenseF32ArrayGetElement(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast()[pos]; + return llvm::cast(unwrap(attr))[pos]; } double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast()[pos]; + return llvm::cast(unwrap(attr))[pos]; } //===----------------------------------------------------------------------===// @@ -430,13 +433,13 @@ double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) { // IsA support. 
bool mlirAttributeIsADenseElements(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseIntElements(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } bool mlirAttributeIsADenseFPElements(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } //===----------------------------------------------------------------------===// @@ -447,14 +450,14 @@ MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType, MlirAttribute const *elements) { SmallVector attributes; return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), + DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), unwrapList(numElements, elements, attributes))); } MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType, size_t rawBufferSize, const void *rawBuffer) { - auto shapedTypeCpp = unwrap(shapedType).cast(); + auto shapedTypeCpp = llvm::cast(unwrap(shapedType)); ArrayRef rawBufferCpp(static_cast(rawBuffer), rawBufferSize); bool isSplat = false; @@ -466,61 +469,61 @@ MlirAttribute mlirDenseElementsAttrRawBufferGet(MlirType shapedType, MlirAttribute mlirDenseElementsAttrSplatGet(MlirType shapedType, MlirAttribute element) { - return wrap(DenseElementsAttr::get(unwrap(shapedType).cast(), + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), unwrap(element))); } MlirAttribute mlirDenseElementsAttrBoolSplatGet(MlirType shapedType, bool element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrUInt8SplatGet(MlirType shapedType, uint8_t element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrInt8SplatGet(MlirType shapedType, int8_t element) { - return 
wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrUInt32SplatGet(MlirType shapedType, uint32_t element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrInt32SplatGet(MlirType shapedType, int32_t element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrUInt64SplatGet(MlirType shapedType, uint64_t element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrInt64SplatGet(MlirType shapedType, int64_t element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrFloatSplatGet(MlirType shapedType, float element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrDoubleSplatGet(MlirType shapedType, double element) { - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), element)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + element)); } MlirAttribute mlirDenseElementsAttrBoolGet(MlirType shapedType, intptr_t numElements, const int *elements) { SmallVector values(elements, elements + numElements); - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), values)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + values)); } /// Creates a 
dense attribute with elements of the type deduced by templates. @@ -528,7 +531,7 @@ template static MlirAttribute getDenseAttribute(MlirType shapedType, intptr_t numElements, const T *elements) { - return wrap(DenseElementsAttr::get(unwrap(shapedType).cast(), + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), llvm::ArrayRef(elements, numElements))); } @@ -605,99 +608,99 @@ MlirAttribute mlirDenseElementsAttrStringGet(MlirType shapedType, for (intptr_t i = 0; i < numElements; ++i) values.push_back(unwrap(strs[i])); - return wrap( - DenseElementsAttr::get(unwrap(shapedType).cast(), values)); + return wrap(DenseElementsAttr::get(llvm::cast(unwrap(shapedType)), + values)); } MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr, MlirType shapedType) { - return wrap(unwrap(attr).cast().reshape( - unwrap(shapedType).cast())); + return wrap(llvm::cast(unwrap(attr)) + .reshape(llvm::cast(unwrap(shapedType)))); } //===----------------------------------------------------------------------===// // Splat accessors. 
bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) { - return unwrap(attr).cast().isSplat(); + return llvm::cast(unwrap(attr)).isSplat(); } MlirAttribute mlirDenseElementsAttrGetSplatValue(MlirAttribute attr) { return wrap( - unwrap(attr).cast().getSplatValue()); + llvm::cast(unwrap(attr)).getSplatValue()); } int mlirDenseElementsAttrGetBoolSplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } int8_t mlirDenseElementsAttrGetInt8SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } uint8_t mlirDenseElementsAttrGetUInt8SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } int32_t mlirDenseElementsAttrGetInt32SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } uint32_t mlirDenseElementsAttrGetUInt32SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } int64_t mlirDenseElementsAttrGetInt64SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } uint64_t mlirDenseElementsAttrGetUInt64SplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } float mlirDenseElementsAttrGetFloatSplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } double mlirDenseElementsAttrGetDoubleSplatValue(MlirAttribute attr) { - return unwrap(attr).cast().getSplatValue(); + return llvm::cast(unwrap(attr)).getSplatValue(); } MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) { return wrap( - unwrap(attr).cast().getSplatValue()); + 
llvm::cast(unwrap(attr)).getSplatValue()); } //===----------------------------------------------------------------------===// // Indexed accessors. bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } int8_t mlirDenseElementsAttrGetInt8Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } uint8_t mlirDenseElementsAttrGetUInt8Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } int16_t mlirDenseElementsAttrGetInt16Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } uint16_t mlirDenseElementsAttrGetUInt16Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } int32_t mlirDenseElementsAttrGetInt32Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } uint32_t mlirDenseElementsAttrGetUInt32Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } double 
mlirDenseElementsAttrGetDoubleValue(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getValues()[pos]; + return llvm::cast(unwrap(attr)).getValues()[pos]; } MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr, intptr_t pos) { return wrap( - unwrap(attr).cast().getValues()[pos]); + llvm::cast(unwrap(attr)).getValues()[pos]); } //===----------------------------------------------------------------------===// @@ -705,7 +708,7 @@ MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr, const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) { return static_cast( - unwrap(attr).cast().getRawData().data()); + llvm::cast(unwrap(attr)).getRawData().data()); } //===----------------------------------------------------------------------===// @@ -715,7 +718,7 @@ const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) { template static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name, intptr_t numElements, const T *elements) { - return wrap(U::get(unwrap(shapedType).cast(), unwrap(name), + return wrap(U::get(llvm::cast(unwrap(shapedType)), unwrap(name), UnmanagedAsmResourceBlob::allocateInferAlign( llvm::ArrayRef(elements, numElements)))); } @@ -797,7 +800,7 @@ mlirUnmanagedDenseDoubleResourceElementsAttrGet(MlirType shapedType, template static T getDenseResourceVal(MlirAttribute attr, intptr_t pos) { - return (*unwrap(attr).cast().tryGetAsArrayRef())[pos]; + return (*llvm::cast(unwrap(attr)).tryGetAsArrayRef())[pos]; } MLIR_CAPI_EXPORTED bool @@ -853,24 +856,24 @@ mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { //===----------------------------------------------------------------------===// bool mlirAttributeIsASparseElements(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirSparseElementsAttribute(MlirType shapedType, MlirAttribute denseIndices, MlirAttribute denseValues) { - return wrap( - 
SparseElementsAttr::get(unwrap(shapedType).cast(), - unwrap(denseIndices).cast(), - unwrap(denseValues).cast())); + return wrap(SparseElementsAttr::get( + llvm::cast(unwrap(shapedType)), + llvm::cast(unwrap(denseIndices)), + llvm::cast(unwrap(denseValues)))); } MlirAttribute mlirSparseElementsAttrGetIndices(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getIndices()); + return wrap(llvm::cast(unwrap(attr)).getIndices()); } MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) { - return wrap(unwrap(attr).cast().getValues()); + return wrap(llvm::cast(unwrap(attr)).getValues()); } //===----------------------------------------------------------------------===// @@ -878,7 +881,7 @@ MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) { //===----------------------------------------------------------------------===// bool mlirAttributeIsAStridedLayout(MlirAttribute attr) { - return unwrap(attr).isa(); + return llvm::isa(unwrap(attr)); } MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, @@ -889,14 +892,14 @@ MlirAttribute mlirStridedLayoutAttrGet(MlirContext ctx, int64_t offset, } int64_t mlirStridedLayoutAttrGetOffset(MlirAttribute attr) { - return unwrap(attr).cast().getOffset(); + return llvm::cast(unwrap(attr)).getOffset(); } intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr) { return static_cast( - unwrap(attr).cast().getStrides().size()); + llvm::cast(unwrap(attr)).getStrides().size()); } int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) { - return unwrap(attr).cast().getStrides()[pos]; + return llvm::cast(unwrap(attr)).getStrides()[pos]; } diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 2468c0546..90ab84760 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -23,7 +23,7 @@ using namespace mlir; //===----------------------------------------------------------------------===// bool 
mlirTypeIsAInteger(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } MlirType mlirIntegerTypeGet(MlirContext ctx, unsigned bitwidth) { @@ -39,26 +39,28 @@ MlirType mlirIntegerTypeUnsignedGet(MlirContext ctx, unsigned bitwidth) { } unsigned mlirIntegerTypeGetWidth(MlirType type) { - return unwrap(type).cast().getWidth(); + return llvm::cast(unwrap(type)).getWidth(); } bool mlirIntegerTypeIsSignless(MlirType type) { - return unwrap(type).cast().isSignless(); + return llvm::cast(unwrap(type)).isSignless(); } bool mlirIntegerTypeIsSigned(MlirType type) { - return unwrap(type).cast().isSigned(); + return llvm::cast(unwrap(type)).isSigned(); } bool mlirIntegerTypeIsUnsigned(MlirType type) { - return unwrap(type).cast().isUnsigned(); + return llvm::cast(unwrap(type)).isUnsigned(); } //===----------------------------------------------------------------------===// // Index type. //===----------------------------------------------------------------------===// -bool mlirTypeIsAIndex(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsAIndex(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirIndexTypeGet(MlirContext ctx) { return wrap(IndexType::get(unwrap(ctx))); @@ -136,7 +138,9 @@ MlirType mlirF64TypeGet(MlirContext ctx) { // None type. 
//===----------------------------------------------------------------------===// -bool mlirTypeIsANone(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsANone(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirNoneTypeGet(MlirContext ctx) { return wrap(NoneType::get(unwrap(ctx))); @@ -147,7 +151,7 @@ MlirType mlirNoneTypeGet(MlirContext ctx) { //===----------------------------------------------------------------------===// bool mlirTypeIsAComplex(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } MlirType mlirComplexTypeGet(MlirType elementType) { @@ -155,38 +159,41 @@ MlirType mlirComplexTypeGet(MlirType elementType) { } MlirType mlirComplexTypeGetElementType(MlirType type) { - return wrap(unwrap(type).cast().getElementType()); + return wrap(llvm::cast(unwrap(type)).getElementType()); } //===----------------------------------------------------------------------===// // Shaped type. //===----------------------------------------------------------------------===// -bool mlirTypeIsAShaped(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsAShaped(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirShapedTypeGetElementType(MlirType type) { - return wrap(unwrap(type).cast().getElementType()); + return wrap(llvm::cast(unwrap(type)).getElementType()); } bool mlirShapedTypeHasRank(MlirType type) { - return unwrap(type).cast().hasRank(); + return llvm::cast(unwrap(type)).hasRank(); } int64_t mlirShapedTypeGetRank(MlirType type) { - return unwrap(type).cast().getRank(); + return llvm::cast(unwrap(type)).getRank(); } bool mlirShapedTypeHasStaticShape(MlirType type) { - return unwrap(type).cast().hasStaticShape(); + return llvm::cast(unwrap(type)).hasStaticShape(); } bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) { - return unwrap(type).cast().isDynamicDim( - static_cast(dim)); + return llvm::cast(unwrap(type)) + .isDynamicDim(static_cast(dim)); } int64_t 
mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) { - return unwrap(type).cast().getDimSize(static_cast(dim)); + return llvm::cast(unwrap(type)) + .getDimSize(static_cast(dim)); } int64_t mlirShapedTypeGetDynamicSize() { return ShapedType::kDynamic; } @@ -207,7 +214,9 @@ int64_t mlirShapedTypeGetDynamicStrideOrOffset() { // Vector type. //===----------------------------------------------------------------------===// -bool mlirTypeIsAVector(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsAVector(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirVectorTypeGet(intptr_t rank, const int64_t *shape, MlirType elementType) { @@ -226,14 +235,16 @@ MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank, // Ranked / Unranked tensor type. //===----------------------------------------------------------------------===// -bool mlirTypeIsATensor(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsATensor(MlirType type) { + return llvm::isa(unwrap(type)); +} bool mlirTypeIsARankedTensor(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } bool mlirTypeIsAUnrankedTensor(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } MlirType mlirRankedTensorTypeGet(intptr_t rank, const int64_t *shape, @@ -253,7 +264,7 @@ MlirType mlirRankedTensorTypeGetChecked(MlirLocation loc, intptr_t rank, } MlirAttribute mlirRankedTensorTypeGetEncoding(MlirType type) { - return wrap(unwrap(type).cast().getEncoding()); + return wrap(llvm::cast(unwrap(type)).getEncoding()); } MlirType mlirUnrankedTensorTypeGet(MlirType elementType) { @@ -269,7 +280,9 @@ MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc, // Ranked / Unranked MemRef type. 
//===----------------------------------------------------------------------===// -bool mlirTypeIsAMemRef(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsAMemRef(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, const int64_t *shape, MlirAttribute layout, @@ -278,7 +291,7 @@ MlirType mlirMemRefTypeGet(MlirType elementType, intptr_t rank, llvm::ArrayRef(shape, static_cast(rank)), unwrap(elementType), mlirAttributeIsNull(layout) ? MemRefLayoutAttrInterface() - : unwrap(layout).cast(), + : llvm::cast(unwrap(layout)), unwrap(memorySpace))); } @@ -291,7 +304,7 @@ MlirType mlirMemRefTypeGetChecked(MlirLocation loc, MlirType elementType, unwrap(elementType), mlirAttributeIsNull(layout) ? MemRefLayoutAttrInterface() - : unwrap(layout).cast(), + : llvm::cast(unwrap(layout)), unwrap(memorySpace))); } @@ -313,19 +326,19 @@ MlirType mlirMemRefTypeContiguousGetChecked(MlirLocation loc, } MlirAttribute mlirMemRefTypeGetLayout(MlirType type) { - return wrap(unwrap(type).cast().getLayout()); + return wrap(llvm::cast(unwrap(type)).getLayout()); } MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type) { - return wrap(unwrap(type).cast().getLayout().getAffineMap()); + return wrap(llvm::cast(unwrap(type)).getLayout().getAffineMap()); } MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) { - return wrap(unwrap(type).cast().getMemorySpace()); + return wrap(llvm::cast(unwrap(type)).getMemorySpace()); } bool mlirTypeIsAUnrankedMemRef(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } MlirType mlirUnrankedMemRefTypeGet(MlirType elementType, @@ -342,14 +355,16 @@ MlirType mlirUnrankedMemRefTypeGetChecked(MlirLocation loc, } MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) { - return wrap(unwrap(type).cast().getMemorySpace()); + return wrap(llvm::cast(unwrap(type)).getMemorySpace()); } 
//===----------------------------------------------------------------------===// // Tuple type. //===----------------------------------------------------------------------===// -bool mlirTypeIsATuple(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsATuple(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements, MlirType const *elements) { @@ -359,11 +374,12 @@ MlirType mlirTupleTypeGet(MlirContext ctx, intptr_t numElements, } intptr_t mlirTupleTypeGetNumTypes(MlirType type) { - return unwrap(type).cast().size(); + return llvm::cast(unwrap(type)).size(); } MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) { - return wrap(unwrap(type).cast().getType(static_cast(pos))); + return wrap( + llvm::cast(unwrap(type)).getType(static_cast(pos))); } //===----------------------------------------------------------------------===// @@ -371,7 +387,7 @@ MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) { //===----------------------------------------------------------------------===// bool mlirTypeIsAFunction(MlirType type) { - return unwrap(type).isa(); + return llvm::isa(unwrap(type)); } MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs, @@ -385,30 +401,32 @@ MlirType mlirFunctionTypeGet(MlirContext ctx, intptr_t numInputs, } intptr_t mlirFunctionTypeGetNumInputs(MlirType type) { - return unwrap(type).cast().getNumInputs(); + return llvm::cast(unwrap(type)).getNumInputs(); } intptr_t mlirFunctionTypeGetNumResults(MlirType type) { - return unwrap(type).cast().getNumResults(); + return llvm::cast(unwrap(type)).getNumResults(); } MlirType mlirFunctionTypeGetInput(MlirType type, intptr_t pos) { assert(pos >= 0 && "pos in array must be positive"); - return wrap( - unwrap(type).cast().getInput(static_cast(pos))); + return wrap(llvm::cast(unwrap(type)) + .getInput(static_cast(pos))); } MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) { assert(pos >= 0 && "pos in 
array must be positive"); - return wrap( - unwrap(type).cast().getResult(static_cast(pos))); + return wrap(llvm::cast(unwrap(type)) + .getResult(static_cast(pos))); } //===----------------------------------------------------------------------===// // Opaque type. //===----------------------------------------------------------------------===// -bool mlirTypeIsAOpaque(MlirType type) { return unwrap(type).isa(); } +bool mlirTypeIsAOpaque(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace, MlirStringRef typeData) { @@ -418,9 +436,10 @@ MlirType mlirOpaqueTypeGet(MlirContext ctx, MlirStringRef dialectNamespace, } MlirStringRef mlirOpaqueTypeGetDialectNamespace(MlirType type) { - return wrap(unwrap(type).cast().getDialectNamespace().strref()); + return wrap( + llvm::cast(unwrap(type)).getDialectNamespace().strref()); } MlirStringRef mlirOpaqueTypeGetData(MlirType type) { - return wrap(unwrap(type).cast().getTypeData()); + return wrap(llvm::cast(unwrap(type)).getTypeData()); } diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 79386dedf..c0cf59777 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -172,7 +172,7 @@ MlirAttribute mlirLocationGetAttribute(MlirLocation location) { } MlirLocation mlirLocationFromAttribute(MlirAttribute attribute) { - return wrap(Location(unwrap(attribute).cast())); + return wrap(Location(llvm::cast(unwrap(attribute)))); } MlirLocation mlirLocationFileLineColGet(MlirContext context, @@ -727,33 +727,33 @@ bool mlirValueEqual(MlirValue value1, MlirValue value2) { } bool mlirValueIsABlockArgument(MlirValue value) { - return unwrap(value).isa(); + return llvm::isa(unwrap(value)); } bool mlirValueIsAOpResult(MlirValue value) { - return unwrap(value).isa(); + return llvm::isa(unwrap(value)); } MlirBlock mlirBlockArgumentGetOwner(MlirValue value) { - return wrap(unwrap(value).cast().getOwner()); + return 
wrap(llvm::cast(unwrap(value)).getOwner()); } intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) { return static_cast( - unwrap(value).cast().getArgNumber()); + llvm::cast(unwrap(value)).getArgNumber()); } void mlirBlockArgumentSetType(MlirValue value, MlirType type) { - unwrap(value).cast().setType(unwrap(type)); + llvm::cast(unwrap(value)).setType(unwrap(type)); } MlirOperation mlirOpResultGetOwner(MlirValue value) { - return wrap(unwrap(value).cast().getOwner()); + return wrap(llvm::cast(unwrap(value)).getOwner()); } intptr_t mlirOpResultGetResultNumber(MlirValue value) { return static_cast( - unwrap(value).cast().getResultNumber()); + llvm::cast(unwrap(value)).getResultNumber()); } MlirType mlirValueGetType(MlirValue value) { @@ -857,7 +857,7 @@ MlirContext mlirAttributeGetContext(MlirAttribute attribute) { MlirType mlirAttributeGetType(MlirAttribute attribute) { Attribute attr = unwrap(attribute); - if (auto typedAttr = attr.dyn_cast()) + if (auto typedAttr = llvm::dyn_cast(attr)) return wrap(typedAttr.getType()); return wrap(NoneType::get(attr.getContext())); } From 222cc4cd3d23dddce97f899eeab5b3c95f53b9d0 Mon Sep 17 00:00:00 2001 From: kon72 Date: Fri, 12 May 2023 16:42:12 +0200 Subject: [PATCH 461/915] [mlir][linalg] Add channel-first variants of convolution This change adds the following three operations and unit tests for them: - conv_3d_ncdhw_fcdhw - depthwise_conv_1d_ncw_cw - depthwise_conv_3d_ncdhw_cdhw Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D150054 --- .../linalg/opdsl/ops/core_named_ops.py | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 4402624c1..9c96868c1 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -493,6 +493,33 @@ def conv_3d_ndhwc_dhwcf_q(I=TensorDef(T1, 
S.N, S.OD * S.SD + S.KD * S.DD, TypeFn.cast_signed(U, KZp)) +@linalg_structured_op +def conv_3d_ncdhw_fcdhw(I=TensorDef(T1, S.N, S.C, S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.F, S.C, S.KD, S.KH, S.KW), + O=TensorDef(U, S.N, S.F, S.OD, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SD, + S.SH, + S.SW, + default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, + S.DH, + S.DW, + default=[1, 1, 1])): + """Performs 3-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) + O[D.n, D.f, D.od, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.c, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed( + U, K[D.f, D.c, D.kd, D.kh, D.kw]) + + @linalg_structured_op def depthwise_conv_1d_nwc_wc(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC), @@ -513,6 +540,26 @@ def depthwise_conv_1d_nwc_wc(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, TypeFn.cast_signed(U, K[D.kw, D.ic]) +@linalg_structured_op +def depthwise_conv_1d_ncw_cw(I=TensorDef(T1, S.N, S.IC, + S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.IC, S.KW), + O=TensorDef(U, S.N, S.IC, S.OW, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1])): + """Performs depth-wise 1-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.ic, D.kw) + O[D.n, D.ic, D.ow] += \ + TypeFn.cast_signed(U, I[D.n, D.ic, D.ow * S.SW + D.kw * S.DW]) * \ + TypeFn.cast_signed(U, K[D.ic, D.kw]) + + @linalg_structured_op def depthwise_conv_1d_nwc_wcm(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC), @@ -716,6 +763,41 @@ def depthwise_conv_3d_ndhwc_dhwc(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, U, K[D.kd, D.kh, D.kw, D.ic]) +@linalg_structured_op +def depthwise_conv_3d_ncdhw_cdhw(I=TensorDef(T1, S.N, S.IC, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.IC, S.KD, S.KH, S.KW), + O=TensorDef(U, + S.N, + S.IC, + S.OD, + S.OH, + S.OW, + output=True), + strides=IndexAttrDef(S.SD, + S.SH, + S.SW, + default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, + S.DH, + S.DW, + default=[1, 1, 1])): + """Performs depth-wise 3-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic) + O[D.n, D.ic, D.od, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.ic, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed( + U, K[D.ic, D.kd, D.kh, D.kw]) + + @linalg_structured_op def depthwise_conv_3d_ndhwc_dhwcm(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, @@ -749,6 +831,7 @@ def depthwise_conv_3d_ndhwc_dhwcm(I=TensorDef(T1, S.N, D.ow * S.SW + D.kw * S.DW, D.ic]) * TypeFn.cast_signed( U, K[D.kd, D.kh, D.kw, D.ic, D.cm]) + @linalg_structured_op def pooling_nhwc_sum(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), From 91d8a4e8248b8f60b0be4f783402041bba813d2b Mon Sep 17 00:00:00 2001 From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Wed, 17 May 2023 13:09:53 -0700 Subject: [PATCH 462/915] [mlir][sparse] Renaming the STEA field `dimLevelType` to `lvlTypes` This commit is part of the migration of towards the new STEA syntax/design. In particular, this commit includes the following changes: * Renaming compiler-internal functions/methods: * `SparseTensorEncodingAttr::{getDimLevelType => getLvlTypes}` * `Merger::{getDimLevelType => getLvlType}` (for consistency) * `sparse_tensor::{getDimLevelType => buildLevelType}` (to help reduce confusion vs actual getter methods) * Renaming external facets to match: * the STEA parser and printer * the C and Python bindings * PyTACO However, the actual renaming of the `DimLevelType` itself (along with all the "dlt" names) will be handled in a separate commit. 
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D150330 --- mlir/include/mlir-c/Dialect/SparseTensor.h | 4 ++-- .../lib/Bindings/Python/DialectSparseTensor.cpp | 12 +++++------- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 17 ++++++++--------- 3 files changed, 15 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index 8a6763b6c..1ff6dc1b8 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -52,7 +52,7 @@ mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr); /// Creates a `sparse_tensor.encoding` attribute with the given parameters. MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet( MlirContext ctx, intptr_t lvlRank, - enum MlirSparseTensorDimLevelType const *dimLevelTypes, + enum MlirSparseTensorDimLevelType const *lvlTypes, MlirAffineMap dimOrdering, MlirAffineMap higherOrdering, int posWidth, int crdWidth); @@ -62,7 +62,7 @@ mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr); /// Returns a specified level-type of the `sparse_tensor.encoding` attribute. MLIR_CAPI_EXPORTED enum MlirSparseTensorDimLevelType -mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t lvl); +mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl); /// Returns the dimension-ordering of the `sparse_tensor.encoding` attribute. 
MLIR_CAPI_EXPORTED MlirAffineMap diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 0e07f2563..0f0e67604 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -39,30 +39,28 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { mlirAttributeIsASparseTensorEncodingAttr) .def_classmethod( "get", - [](py::object cls, - std::vector dimLevelTypes, + [](py::object cls, std::vector lvlTypes, std::optional dimOrdering, std::optional higherOrdering, int posWidth, int crdWidth, MlirContext context) { return cls(mlirSparseTensorEncodingAttrGet( - context, dimLevelTypes.size(), dimLevelTypes.data(), + context, lvlTypes.size(), lvlTypes.data(), dimOrdering ? *dimOrdering : MlirAffineMap{nullptr}, higherOrdering ? *higherOrdering : MlirAffineMap{nullptr}, posWidth, crdWidth)); }, - py::arg("cls"), py::arg("dim_level_types"), py::arg("dim_ordering"), + py::arg("cls"), py::arg("lvl_types"), py::arg("dim_ordering"), py::arg("higher_ordering"), py::arg("pos_width"), py::arg("crd_width"), py::arg("context") = py::none(), "Gets a sparse_tensor.encoding from parameters.") .def_property_readonly( - "dim_level_types", + "lvl_types", [](MlirAttribute self) { const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self); std::vector ret; ret.reserve(lvlRank); for (int l = 0; l < lvlRank; ++l) - ret.push_back( - mlirSparseTensorEncodingAttrGetDimLevelType(self, l)); + ret.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, l)); return ret; }) .def_property_readonly( diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index 795ce51ff..8569acf43 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -47,16 +47,15 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { MlirAttribute mlirSparseTensorEncodingAttrGet( MlirContext ctx, intptr_t 
lvlRank, - MlirSparseTensorDimLevelType const *dimLevelTypes, - MlirAffineMap dimOrdering, MlirAffineMap higherOrdering, int posWidth, - int crdWidth) { - SmallVector cppDimLevelTypes; - cppDimLevelTypes.reserve(lvlRank); + MlirSparseTensorDimLevelType const *lvlTypes, MlirAffineMap dimOrdering, + MlirAffineMap higherOrdering, int posWidth, int crdWidth) { + SmallVector cppLvlTypes; + cppLvlTypes.reserve(lvlRank); for (intptr_t l = 0; l < lvlRank; ++l) - cppDimLevelTypes.push_back(static_cast(dimLevelTypes[l])); + cppLvlTypes.push_back(static_cast(lvlTypes[l])); return wrap(SparseTensorEncodingAttr::get( - unwrap(ctx), cppDimLevelTypes, unwrap(dimOrdering), - unwrap(higherOrdering), posWidth, crdWidth)); + unwrap(ctx), cppLvlTypes, unwrap(dimOrdering), unwrap(higherOrdering), + posWidth, crdWidth)); } MlirAffineMap mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr) { @@ -73,7 +72,7 @@ intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) { } MlirSparseTensorDimLevelType -mlirSparseTensorEncodingAttrGetDimLevelType(MlirAttribute attr, intptr_t lvl) { +mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) { return static_cast( cast(unwrap(attr)).getLvlType(lvl)); } From 8d525d08cf30e6e58cf66ea5bc984607beb32018 Mon Sep 17 00:00:00 2001 From: "pengchao.hu" Date: Mon, 22 May 2023 18:35:33 +0200 Subject: [PATCH 463/915] [MLIR][python bindings] Add more basic AttrBuilder for _ops_gen.py files Add more attribute builders, such as "F32Attr", "F64Attr" and "F64ArrayAttr", which are useful to create operations by python bindings. For example, tosa.clamp in _tosa_ops_gen.py need 'F32Attr'. 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D150757 --- mlir/python/mlir/ir.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 714253426..be065d463 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -27,7 +27,7 @@ def _indexAttr(x, context): @register_attribute_builder("I16Attr") -def _i32Attr(x, context): +def _i16Attr(x, context): return IntegerAttr.get(IntegerType.get_signless(16, context=context), x) @@ -41,6 +41,26 @@ def _i64Attr(x, context): return IntegerAttr.get(IntegerType.get_signless(64, context=context), x) +@register_attribute_builder("SI16Attr") +def _si16Attr(x, context): + return IntegerAttr.get(IntegerType.get_signed(16, context=context), x) + + +@register_attribute_builder("SI32Attr") +def _si32Attr(x, context): + return IntegerAttr.get(IntegerType.get_signed(32, context=context), x) + + +@register_attribute_builder("F32Attr") +def _f32Attr(x, context): + return FloatAttr.get_f32(x, context=context) + + +@register_attribute_builder("F64Attr") +def _f64Attr(x, context): + return FloatAttr.get_f64(x, context=context) + + @register_attribute_builder("StrAttr") def _stringAttr(x, context): return StringAttr.get(x, context=context) @@ -61,11 +81,26 @@ def _arrayAttr(x, context): return ArrayAttr.get(x, context=context) +@register_attribute_builder("I32ArrayAttr") +def _i32ArrayAttr(x, context): + return ArrayAttr.get([_i32Attr(v, context) for v in x]) + + @register_attribute_builder("I64ArrayAttr") def _i64ArrayAttr(x, context): return ArrayAttr.get([_i64Attr(v, context) for v in x]) +@register_attribute_builder("F32ArrayAttr") +def _f32ArrayAttr(x, context): + return ArrayAttr.get([_f32Attr(v, context) for v in x]) + + +@register_attribute_builder("F64ArrayAttr") +def _f64ArrayAttr(x, context): + return ArrayAttr.get([_f64Attr(v, context) for v in x]) + + 
@register_attribute_builder("DenseI64ArrayAttr") def _denseI64ArrayAttr(x, context): return DenseI64ArrayAttr.get(x, context=context) From 193d744af7e77fd3841ae9711dbdf3ee10531c21 Mon Sep 17 00:00:00 2001 From: max Date: Mon, 22 May 2023 11:12:53 -0500 Subject: [PATCH 464/915] [MLIR][python bindings] Expose TypeIDs in python This diff adds python bindings for `MlirTypeID`. It paves the way for returning accurately typed `Type`s from python APIs (see D150927) and then further along building type "conscious" `Value` APIs (see D150413). Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D150839 --- mlir/include/mlir-c/Bindings/Python/Interop.h | 21 ++++++ mlir/include/mlir-c/BuiltinTypes.h | 66 +++++++++++++++++ .../mlir/Bindings/Python/PybindAdaptors.h | 21 ++++++ mlir/lib/Bindings/Python/IRCore.cpp | 70 ++++++++++++++++--- mlir/lib/Bindings/Python/IRModule.h | 50 ++++++++++++- mlir/lib/Bindings/Python/IRTypes.cpp | 42 +++++++++++ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 62 ++++++++++++++++ mlir/python/mlir/dialects/python_test.py | 2 +- 8 files changed, 322 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h index b877f94aa..6ebb45808 100644 --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -80,6 +80,8 @@ #define MLIR_PYTHON_CAPSULE_PASS_MANAGER \ MAKE_MLIR_PYTHON_QUALNAME("passmanager.PassManager._CAPIPtr") #define MLIR_PYTHON_CAPSULE_VALUE MAKE_MLIR_PYTHON_QUALNAME("ir.Value._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_TYPEID \ + MAKE_MLIR_PYTHON_QUALNAME("ir.TypeID._CAPIPtr") /** Attribute on MLIR Python objects that expose their C-API pointer. * This will be a type-specific capsule created as per one of the helpers @@ -268,6 +270,25 @@ static inline MlirOperation mlirPythonCapsuleToOperation(PyObject *capsule) { return op; } +/** Creates a capsule object encapsulating the raw C-API MlirTypeID. 
+ * The returned capsule does not extend or affect ownership of any Python + * objects that reference the type in any way. + */ +static inline PyObject *mlirPythonTypeIDToCapsule(MlirTypeID typeID) { + return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(typeID), + MLIR_PYTHON_CAPSULE_TYPEID, NULL); +} + +/** Extracts an MlirTypeID from a capsule as produced from + * mlirPythonTypeIDToCapsule. If the capsule is not of the right type, then + * a null type is returned (as checked via mlirTypeIDIsNull). In such a + * case, the Python APIs will have already set an error. */ +static inline MlirTypeID mlirPythonCapsuleToTypeID(PyObject *capsule) { + void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_TYPEID); + MlirTypeID typeID = {ptr}; + return typeID; +} + /** Creates a capsule object encapsulating the raw C-API MlirType. * The returned capsule does not extend or affect ownership of any Python * objects that reference the type in any way. diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index 2b7606f3d..4348c5ba1 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -22,6 +22,9 @@ extern "C" { // Integer types. //===----------------------------------------------------------------------===// +/// Returns the typeID of an Integer type. +MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerTypeGetTypeID(void); + /// Checks whether the given type is an integer type. MLIR_CAPI_EXPORTED bool mlirTypeIsAInteger(MlirType type); @@ -56,6 +59,9 @@ MLIR_CAPI_EXPORTED bool mlirIntegerTypeIsUnsigned(MlirType type); // Index type. //===----------------------------------------------------------------------===// +/// Returns the typeID of an Index type. +MLIR_CAPI_EXPORTED MlirTypeID mlirIndexTypeGetTypeID(void); + /// Checks whether the given type is an index type. 
MLIR_CAPI_EXPORTED bool mlirTypeIsAIndex(MlirType type); @@ -67,6 +73,9 @@ MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx); // Floating-point types. //===----------------------------------------------------------------------===// +/// Returns the typeID of an Float8E5M2 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void); + /// Checks whether the given type is an f8E5M2 type. MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type); @@ -74,6 +83,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type); /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx); +/// Returns the typeID of an Float8E4M3FN type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNTypeGetTypeID(void); + /// Checks whether the given type is an f8E4M3FN type. MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type); @@ -81,6 +93,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FN(MlirType type); /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx); +/// Returns the typeID of an Float8E5M2FNUZ type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID(void); + /// Checks whether the given type is an f8E5M2FNUZ type. MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type); @@ -88,6 +103,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type); /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx); +/// Returns the typeID of an Float8E4M3FNUZ type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID(void); + /// Checks whether the given type is an f8E4M3FNUZ type. MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type); @@ -95,6 +113,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type); /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx); +/// Returns the typeID of an Float8E4M3B11FNUZ type. 
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID(void); + /// Checks whether the given type is an f8E4M3B11FNUZ type. MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type); @@ -102,6 +123,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type); /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx); +/// Returns the typeID of an BFloat16 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void); + /// Checks whether the given type is a bf16 type. MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type); @@ -109,6 +133,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsABF16(MlirType type); /// context. MLIR_CAPI_EXPORTED MlirType mlirBF16TypeGet(MlirContext ctx); +/// Returns the typeID of an Float16 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat16TypeGetTypeID(void); + /// Checks whether the given type is an f16 type. MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type); @@ -116,6 +143,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAF16(MlirType type); /// context. MLIR_CAPI_EXPORTED MlirType mlirF16TypeGet(MlirContext ctx); +/// Returns the typeID of an Float32 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat32TypeGetTypeID(void); + /// Checks whether the given type is an f32 type. MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type); @@ -123,6 +153,9 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAF32(MlirType type); /// context. MLIR_CAPI_EXPORTED MlirType mlirF32TypeGet(MlirContext ctx); +/// Returns the typeID of an Float64 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat64TypeGetTypeID(void); + /// Checks whether the given type is an f64 type. MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type); @@ -134,6 +167,9 @@ MLIR_CAPI_EXPORTED MlirType mlirF64TypeGet(MlirContext ctx); // None type. //===----------------------------------------------------------------------===// +/// Returns the typeID of an None type. 
+MLIR_CAPI_EXPORTED MlirTypeID mlirNoneTypeGetTypeID(void); + /// Checks whether the given type is a None type. MLIR_CAPI_EXPORTED bool mlirTypeIsANone(MlirType type); @@ -145,6 +181,9 @@ MLIR_CAPI_EXPORTED MlirType mlirNoneTypeGet(MlirContext ctx); // Complex type. //===----------------------------------------------------------------------===// +/// Returns the typeID of an Complex type. +MLIR_CAPI_EXPORTED MlirTypeID mlirComplexTypeGetTypeID(void); + /// Checks whether the given type is a Complex type. MLIR_CAPI_EXPORTED bool mlirTypeIsAComplex(MlirType type); @@ -159,6 +198,9 @@ MLIR_CAPI_EXPORTED MlirType mlirComplexTypeGetElementType(MlirType type); // Shaped type. //===----------------------------------------------------------------------===// +/// Returns the typeID of an Shaped type. +MLIR_CAPI_EXPORTED MlirTypeID mlirShapedTypeGetTypeID(void); + /// Checks whether the given type is a Shaped type. MLIR_CAPI_EXPORTED bool mlirTypeIsAShaped(MlirType type); @@ -202,6 +244,9 @@ MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void); // Vector type. //===----------------------------------------------------------------------===// +/// Returns the typeID of an Vector type. +MLIR_CAPI_EXPORTED MlirTypeID mlirVectorTypeGetTypeID(void); + /// Checks whether the given type is a Vector type. MLIR_CAPI_EXPORTED bool mlirTypeIsAVector(MlirType type); @@ -226,9 +271,15 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc, /// Checks whether the given type is a Tensor type. MLIR_CAPI_EXPORTED bool mlirTypeIsATensor(MlirType type); +/// Returns the typeID of an RankedTensor type. +MLIR_CAPI_EXPORTED MlirTypeID mlirRankedTensorTypeGetTypeID(void); + /// Checks whether the given type is a ranked tensor type. MLIR_CAPI_EXPORTED bool mlirTypeIsARankedTensor(MlirType type); +/// Returns the typeID of an UnrankedTensor type. 
+MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedTensorTypeGetTypeID(void); + /// Checks whether the given type is an unranked tensor type. MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedTensor(MlirType type); @@ -264,9 +315,15 @@ mlirUnrankedTensorTypeGetChecked(MlirLocation loc, MlirType elementType); // Ranked / Unranked MemRef type. //===----------------------------------------------------------------------===// +/// Returns the typeID of an MemRef type. +MLIR_CAPI_EXPORTED MlirTypeID mlirMemRefTypeGetTypeID(void); + /// Checks whether the given type is a MemRef type. MLIR_CAPI_EXPORTED bool mlirTypeIsAMemRef(MlirType type); +/// Returns the typeID of an UnrankedMemRef type. +MLIR_CAPI_EXPORTED MlirTypeID mlirUnrankedMemRefTypeGetTypeID(void); + /// Checks whether the given type is an UnrankedMemRef type. MLIR_CAPI_EXPORTED bool mlirTypeIsAUnrankedMemRef(MlirType type); @@ -326,6 +383,9 @@ mlirUnrankedMemrefGetMemorySpace(MlirType type); // Tuple type. //===----------------------------------------------------------------------===// +/// Returns the typeID of an Tuple type. +MLIR_CAPI_EXPORTED MlirTypeID mlirTupleTypeGetTypeID(void); + /// Checks whether the given type is a tuple type. MLIR_CAPI_EXPORTED bool mlirTypeIsATuple(MlirType type); @@ -345,6 +405,9 @@ MLIR_CAPI_EXPORTED MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos); // Function type. //===----------------------------------------------------------------------===// +/// Returns the typeID of an Function type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFunctionTypeGetTypeID(void); + /// Checks whether the given type is a function type. MLIR_CAPI_EXPORTED bool mlirTypeIsAFunction(MlirType type); @@ -373,6 +436,9 @@ MLIR_CAPI_EXPORTED MlirType mlirFunctionTypeGetResult(MlirType type, // Opaque type. //===----------------------------------------------------------------------===// +/// Returns the typeID of an Opaque type. 
+MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueTypeGetTypeID(void); + /// Checks whether the given type is an opaque type. MLIR_CAPI_EXPORTED bool mlirTypeIsAOpaque(MlirType type); diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index bec3fc76e..ccca3aa01 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -236,6 +236,27 @@ struct type_caster { } }; +/// Casts object <-> MlirTypeID. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirTypeID, _("MlirTypeID")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToTypeID(capsule.ptr()); + return !mlirTypeIDIsNull(value); + } + static handle cast(MlirTypeID v, return_value_policy, handle) { + if (v.ptr == nullptr) + return py::none(); + py::object capsule = + py::reinterpret_steal(mlirPythonTypeIDToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("TypeID") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + /// Casts object <-> MlirType. template <> struct type_caster { diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 7ffa46400..db8390abe 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -17,6 +17,7 @@ #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" @@ -1807,6 +1808,24 @@ PyType PyType::createFromCapsule(py::object capsule) { rawType); } +//------------------------------------------------------------------------------ +// PyTypeID. 
+//------------------------------------------------------------------------------ + +py::object PyTypeID::getCapsule() { + return py::reinterpret_steal(mlirPythonTypeIDToCapsule(*this)); +} + +PyTypeID PyTypeID::createFromCapsule(py::object capsule) { + MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr()); + if (mlirTypeIDIsNull(mlirTypeID)) + throw py::error_already_set(); + return PyTypeID(mlirTypeID); +} +bool PyTypeID::operator==(const PyTypeID &other) const { + return mlirTypeIDEqual(typeID, other.typeID); +} + //------------------------------------------------------------------------------ // PyValue and subclases. //------------------------------------------------------------------------------ @@ -3268,16 +3287,47 @@ void mlir::python::populateIRCore(py::module &m) { return printAccum.join(); }, "Returns the assembly form of the type.") - .def("__repr__", [](PyType &self) { - // Generally, assembly formats are not printed for __repr__ because - // this can cause exceptionally long debug output and exceptions. - // However, types are an exception as they typically have compact - // assembly forms and printing them is useful. - PyPrintAccumulator printAccum; - printAccum.parts.append("Type("); - mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); + .def("__repr__", + [](PyType &self) { + // Generally, assembly formats are not printed for __repr__ because + // this can cause exceptionally long debug output and exceptions. + // However, types are an exception as they typically have compact + // assembly forms and printing them is useful. 
+ PyPrintAccumulator printAccum; + printAccum.parts.append("Type("); + mlirTypePrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }) + .def_property_readonly("typeid", [](PyType &self) -> MlirTypeID { + MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); + if (!mlirTypeIDIsNull(mlirTypeID)) + return mlirTypeID; + auto origRepr = + pybind11::repr(pybind11::cast(self)).cast(); + throw py::value_error( + (origRepr + llvm::Twine(" has no typeid.")).str()); + }); + + //---------------------------------------------------------------------------- + // Mapping of PyTypeID. + //---------------------------------------------------------------------------- + py::class_(m, "TypeID", py::module_local()) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule) + // Note, this tests whether the underlying TypeIDs are the same, + // not whether the wrapper MlirTypeIDs are the same, nor whether + // the Python objects are the same (i.e., PyTypeID is a value type). + .def("__eq__", + [](PyTypeID &self, PyTypeID &other) { return self == other; }) + .def("__eq__", + [](PyTypeID &self, const py::object &other) { return false; }) + // Note, this gives the hash value of the underlying TypeID, not the + // hash value of the Python object, nor the hash value of the + // MlirTypeID wrapper. 
+ .def("__hash__", [](PyTypeID &self) { + return static_cast(mlirTypeIDHashValue(self)); }); //---------------------------------------------------------------------------- diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index ade790ba0..fa529c434 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -20,6 +20,7 @@ #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" #include "llvm/ADT/DenseMap.h" namespace mlir { @@ -826,6 +827,29 @@ class PyType : public BaseContextObject { MlirType type; }; +/// A TypeID provides an efficient and unique identifier for a specific C++ +/// type. This allows for a C++ type to be compared, hashed, and stored in an +/// opaque context. This class wraps around the generic MlirTypeID. +class PyTypeID { +public: + PyTypeID(MlirTypeID typeID) : typeID(typeID) {} + // Note, this tests whether the underlying TypeIDs are the same, + // not whether the wrapper MlirTypeIDs are the same, nor whether + // the PyTypeID objects are the same (i.e., PyTypeID is a value type). + bool operator==(const PyTypeID &other) const; + operator MlirTypeID() const { return typeID; } + MlirTypeID get() { return typeID; } + + /// Gets a capsule wrapping the void* within the MlirTypeID. + pybind11::object getCapsule(); + + /// Creates a PyTypeID from the MlirTypeID wrapped by a capsule. + static PyTypeID createFromCapsule(pybind11::object capsule); + +private: + MlirTypeID typeID; +}; + /// CRTP base classes for Python types that subclass Type and should be /// castable from it (i.e. via something like IntegerType(t)). /// By default, type class hierarchies are one level deep (i.e. 
a @@ -839,10 +863,14 @@ class PyConcreteType : public BaseTy { // const char *pyClassName using ClassTy = pybind11::class_; using IsAFunctionTy = bool (*)(MlirType); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; PyConcreteType() = default; PyConcreteType(PyMlirContextRef contextRef, MlirType t) - : BaseTy(std::move(contextRef), t) {} + : BaseTy(std::move(contextRef), t) { + pybind11::implicitly_convertible(); + } PyConcreteType(PyType &orig) : PyConcreteType(orig.getContext(), castFrom(orig)) {} @@ -866,6 +894,26 @@ class PyConcreteType : public BaseTy { return DerivedTy::isaFunction(otherType); }, pybind11::arg("other")); + cls.def_property_readonly_static( + "static_typeid", [](py::object & /*class*/) -> MlirTypeID { + if (DerivedTy::getTypeIdFunction) + return DerivedTy::getTypeIdFunction(); + throw SetPyError(PyExc_AttributeError, + DerivedTy::pyClassName + + llvm::Twine(" has no typeid.")); + }); + cls.def_property_readonly("typeid", [](PyType &self) { + return py::cast(self).attr("typeid").cast(); + }); + cls.def("__repr__", [](DerivedTy &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append(DerivedTy::pyClassName); + printAccum.parts.append("("); + mlirTypePrint(self, printAccum.getCallback(), printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }); + DerivedTy::bindDerived(cls); } diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index cb62a402d..f45b30250 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -32,6 +32,8 @@ static int mlirTypeIsAIntegerOrFloat(MlirType type) { class PyIntegerType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirIntegerTypeGetTypeID; static constexpr const char *pyClassName = "IntegerType"; using 
PyConcreteType::PyConcreteType; @@ -89,6 +91,8 @@ class PyIntegerType : public PyConcreteType { class PyIndexType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirIndexTypeGetTypeID; static constexpr const char *pyClassName = "IndexType"; using PyConcreteType::PyConcreteType; @@ -107,6 +111,8 @@ class PyIndexType : public PyConcreteType { class PyFloat8E4M3FNType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3FNTypeGetTypeID; static constexpr const char *pyClassName = "Float8E4M3FNType"; using PyConcreteType::PyConcreteType; @@ -125,6 +131,8 @@ class PyFloat8E4M3FNType : public PyConcreteType { class PyFloat8E5M2Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E5M2TypeGetTypeID; static constexpr const char *pyClassName = "Float8E5M2Type"; using PyConcreteType::PyConcreteType; @@ -143,6 +151,8 @@ class PyFloat8E5M2Type : public PyConcreteType { class PyFloat8E4M3FNUZType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3FNUZTypeGetTypeID; static constexpr const char *pyClassName = "Float8E4M3FNUZType"; using PyConcreteType::PyConcreteType; @@ -161,6 +171,8 @@ class PyFloat8E4M3FNUZType : public PyConcreteType { class PyFloat8E4M3B11FNUZType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3B11FNUZTypeGetTypeID; static constexpr const char *pyClassName = "Float8E4M3B11FNUZType"; using PyConcreteType::PyConcreteType; @@ -179,6 +191,8 @@ 
class PyFloat8E4M3B11FNUZType : public PyConcreteType { class PyFloat8E5M2FNUZType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E5M2FNUZTypeGetTypeID; static constexpr const char *pyClassName = "Float8E5M2FNUZType"; using PyConcreteType::PyConcreteType; @@ -197,6 +211,8 @@ class PyFloat8E5M2FNUZType : public PyConcreteType { class PyBF16Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirBFloat16TypeGetTypeID; static constexpr const char *pyClassName = "BF16Type"; using PyConcreteType::PyConcreteType; @@ -215,6 +231,8 @@ class PyBF16Type : public PyConcreteType { class PyF16Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat16TypeGetTypeID; static constexpr const char *pyClassName = "F16Type"; using PyConcreteType::PyConcreteType; @@ -233,6 +251,8 @@ class PyF16Type : public PyConcreteType { class PyF32Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat32TypeGetTypeID; static constexpr const char *pyClassName = "F32Type"; using PyConcreteType::PyConcreteType; @@ -251,6 +271,8 @@ class PyF32Type : public PyConcreteType { class PyF64Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat64TypeGetTypeID; static constexpr const char *pyClassName = "F64Type"; using PyConcreteType::PyConcreteType; @@ -269,6 +291,8 @@ class PyF64Type : public PyConcreteType { class PyNoneType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; + static 
constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirNoneTypeGetTypeID; static constexpr const char *pyClassName = "NoneType"; using PyConcreteType::PyConcreteType; @@ -287,6 +311,8 @@ class PyNoneType : public PyConcreteType { class PyComplexType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirComplexTypeGetTypeID; static constexpr const char *pyClassName = "ComplexType"; using PyConcreteType::PyConcreteType; @@ -417,6 +443,8 @@ class PyShapedType : public PyConcreteType { class PyVectorType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirVectorTypeGetTypeID; static constexpr const char *pyClassName = "VectorType"; using PyConcreteType::PyConcreteType; @@ -442,6 +470,8 @@ class PyRankedTensorType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirRankedTensorTypeGetTypeID; static constexpr const char *pyClassName = "RankedTensorType"; using PyConcreteType::PyConcreteType; @@ -476,6 +506,8 @@ class PyUnrankedTensorType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUnrankedTensorTypeGetTypeID; static constexpr const char *pyClassName = "UnrankedTensorType"; using PyConcreteType::PyConcreteType; @@ -498,6 +530,8 @@ class PyUnrankedTensorType class PyMemRefType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirMemRefTypeGetTypeID; static constexpr const char *pyClassName = "MemRefType"; using PyConcreteType::PyConcreteType; @@ -550,6 +584,8 @@ class PyUnrankedMemRefType : public 
PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUnrankedMemRefTypeGetTypeID; static constexpr const char *pyClassName = "UnrankedMemRefType"; using PyConcreteType::PyConcreteType; @@ -585,6 +621,8 @@ class PyUnrankedMemRefType class PyTupleType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirTupleTypeGetTypeID; static constexpr const char *pyClassName = "TupleType"; using PyConcreteType::PyConcreteType; @@ -622,6 +660,8 @@ class PyTupleType : public PyConcreteType { class PyFunctionType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFunctionTypeGetTypeID; static constexpr const char *pyClassName = "FunctionType"; using PyConcreteType::PyConcreteType; @@ -676,6 +716,8 @@ static MlirStringRef toMlirStringRef(const std::string &s) { class PyOpaqueType : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirOpaqueTypeGetTypeID; static constexpr const char *pyClassName = "OpaqueType"; using PyConcreteType::PyConcreteType; diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 90ab84760..1925478c6 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -22,6 +22,8 @@ using namespace mlir; // Integer types. //===----------------------------------------------------------------------===// +MlirTypeID mlirIntegerTypeGetTypeID() { return wrap(IntegerType::getTypeID()); } + bool mlirTypeIsAInteger(MlirType type) { return llvm::isa(unwrap(type)); } @@ -58,6 +60,8 @@ bool mlirIntegerTypeIsUnsigned(MlirType type) { // Index type. 
//===----------------------------------------------------------------------===// +MlirTypeID mlirIndexTypeGetTypeID() { return wrap(IndexType::getTypeID()); } + bool mlirTypeIsAIndex(MlirType type) { return llvm::isa(unwrap(type)); } @@ -70,6 +74,10 @@ MlirType mlirIndexTypeGet(MlirContext ctx) { // Floating-point types. //===----------------------------------------------------------------------===// +MlirTypeID mlirFloat8E5M2TypeGetTypeID() { + return wrap(Float8E5M2Type::getTypeID()); +} + bool mlirTypeIsAFloat8E5M2(MlirType type) { return unwrap(type).isFloat8E5M2(); } @@ -78,6 +86,10 @@ MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) { return wrap(FloatType::getFloat8E5M2(unwrap(ctx))); } +MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() { + return wrap(Float8E4M3FNType::getTypeID()); +} + bool mlirTypeIsAFloat8E4M3FN(MlirType type) { return unwrap(type).isFloat8E4M3FN(); } @@ -86,6 +98,10 @@ MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) { return wrap(FloatType::getFloat8E4M3FN(unwrap(ctx))); } +MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() { + return wrap(Float8E5M2FNUZType::getTypeID()); +} + bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) { return unwrap(type).isFloat8E5M2FNUZ(); } @@ -94,6 +110,10 @@ MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) { return wrap(FloatType::getFloat8E5M2FNUZ(unwrap(ctx))); } +MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() { + return wrap(Float8E4M3FNUZType::getTypeID()); +} + bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) { return unwrap(type).isFloat8E4M3FNUZ(); } @@ -102,6 +122,10 @@ MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) { return wrap(FloatType::getFloat8E4M3FNUZ(unwrap(ctx))); } +MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() { + return wrap(Float8E4M3B11FNUZType::getTypeID()); +} + bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) { return unwrap(type).isFloat8E4M3B11FNUZ(); } @@ -110,24 +134,34 @@ MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) { return wrap(FloatType::getFloat8E4M3B11FNUZ(unwrap(ctx))); 
} +MlirTypeID mlirBFloat16TypeGetTypeID() { + return wrap(BFloat16Type::getTypeID()); +} + bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); } MlirType mlirBF16TypeGet(MlirContext ctx) { return wrap(FloatType::getBF16(unwrap(ctx))); } +MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); } + bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); } MlirType mlirF16TypeGet(MlirContext ctx) { return wrap(FloatType::getF16(unwrap(ctx))); } +MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); } + bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); } MlirType mlirF32TypeGet(MlirContext ctx) { return wrap(FloatType::getF32(unwrap(ctx))); } +MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); } + bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); } MlirType mlirF64TypeGet(MlirContext ctx) { @@ -138,6 +172,8 @@ MlirType mlirF64TypeGet(MlirContext ctx) { // None type. //===----------------------------------------------------------------------===// +MlirTypeID mlirNoneTypeGetTypeID() { return wrap(NoneType::getTypeID()); } + bool mlirTypeIsANone(MlirType type) { return llvm::isa(unwrap(type)); } @@ -150,6 +186,8 @@ MlirType mlirNoneTypeGet(MlirContext ctx) { // Complex type. //===----------------------------------------------------------------------===// +MlirTypeID mlirComplexTypeGetTypeID() { return wrap(ComplexType::getTypeID()); } + bool mlirTypeIsAComplex(MlirType type) { return llvm::isa(unwrap(type)); } @@ -214,6 +252,8 @@ int64_t mlirShapedTypeGetDynamicStrideOrOffset() { // Vector type. 
//===----------------------------------------------------------------------===// +MlirTypeID mlirVectorTypeGetTypeID() { return wrap(VectorType::getTypeID()); } + bool mlirTypeIsAVector(MlirType type) { return llvm::isa(unwrap(type)); } @@ -239,10 +279,18 @@ bool mlirTypeIsATensor(MlirType type) { return llvm::isa(unwrap(type)); } +MlirTypeID mlirRankedTensorTypeGetTypeID() { + return wrap(RankedTensorType::getTypeID()); +} + bool mlirTypeIsARankedTensor(MlirType type) { return llvm::isa(unwrap(type)); } +MlirTypeID mlirUnrankedTensorTypeGetTypeID() { + return wrap(UnrankedTensorType::getTypeID()); +} + bool mlirTypeIsAUnrankedTensor(MlirType type) { return llvm::isa(unwrap(type)); } @@ -280,6 +328,8 @@ MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc, // Ranked / Unranked MemRef type. //===----------------------------------------------------------------------===// +MlirTypeID mlirMemRefTypeGetTypeID() { return wrap(MemRefType::getTypeID()); } + bool mlirTypeIsAMemRef(MlirType type) { return llvm::isa(unwrap(type)); } @@ -337,6 +387,10 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) { return wrap(llvm::cast(unwrap(type)).getMemorySpace()); } +MlirTypeID mlirUnrankedMemRefTypeGetTypeID() { + return wrap(UnrankedMemRefType::getTypeID()); +} + bool mlirTypeIsAUnrankedMemRef(MlirType type) { return llvm::isa(unwrap(type)); } @@ -362,6 +416,8 @@ MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type) { // Tuple type. //===----------------------------------------------------------------------===// +MlirTypeID mlirTupleTypeGetTypeID() { return wrap(TupleType::getTypeID()); } + bool mlirTypeIsATuple(MlirType type) { return llvm::isa(unwrap(type)); } @@ -386,6 +442,10 @@ MlirType mlirTupleTypeGetType(MlirType type, intptr_t pos) { // Function type. 
//===----------------------------------------------------------------------===// +MlirTypeID mlirFunctionTypeGetTypeID() { + return wrap(FunctionType::getTypeID()); +} + bool mlirTypeIsAFunction(MlirType type) { return llvm::isa(unwrap(type)); } @@ -424,6 +484,8 @@ MlirType mlirFunctionTypeGetResult(MlirType type, intptr_t pos) { // Opaque type. //===----------------------------------------------------------------------===// +MlirTypeID mlirOpaqueTypeGetTypeID() { return wrap(OpaqueType::getTypeID()); } + bool mlirTypeIsAOpaque(MlirType type) { return llvm::isa(unwrap(type)); } diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py index 5d42ddc47..ca0d479f1 100644 --- a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._python_test_ops_gen import * -from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue +from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestTensorType def register_python_test_dialect(context, load=True): from .._mlir_libs import _mlirPythonTest From 7dadb45343b9d8d236fe03e3e07ba568f8187e0d Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 23 May 2023 08:52:53 +0200 Subject: [PATCH 465/915] [mlir] Apply ClangTidy performance finding (NFC) --- mlir/lib/CAPI/Interfaces/Interfaces.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp index e597a7bcb..3144a338f 100644 --- a/mlir/lib/CAPI/Interfaces/Interfaces.cpp +++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp @@ -152,7 +152,7 @@ MlirLogicalResult mlirInferShapedTypeOpInterfaceInferReturnTypes( bool hasRank; intptr_t rank; const int64_t *shapeData; - for (ShapedTypeComponents t : inferredTypeComponents) { + for (const ShapedTypeComponents &t : inferredTypeComponents) { if (t.hasRank()) { hasRank 
= true; rank = t.getDims().size(); From 7b5a71e9f695db01743fbf37263071aee29a88b0 Mon Sep 17 00:00:00 2001 From: max Date: Mon, 22 May 2023 17:30:12 -0500 Subject: [PATCH 466/915] [MLIR][python bindings] use pybind C++ APIs for throwing python errors. Differential Revision: https://reviews.llvm.org/D151167 --- mlir/lib/Bindings/Python/IRAffine.cpp | 7 +- mlir/lib/Bindings/Python/IRAttributes.cpp | 27 +++---- mlir/lib/Bindings/Python/IRCore.cpp | 99 +++++++++++------------ mlir/lib/Bindings/Python/IRModule.cpp | 14 ++-- mlir/lib/Bindings/Python/IRModule.h | 19 ++--- mlir/lib/Bindings/Python/IRTypes.cpp | 13 ++- mlir/lib/Bindings/Python/Pass.cpp | 4 +- mlir/lib/Bindings/Python/PybindUtils.cpp | 16 ---- mlir/lib/Bindings/Python/PybindUtils.h | 7 -- mlir/python/CMakeLists.txt | 1 - 10 files changed, 86 insertions(+), 121 deletions(-) delete mode 100644 mlir/lib/Bindings/Python/PybindUtils.cpp diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index 9a2ea6b68..75f86a49e 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -93,9 +93,10 @@ class PyConcreteAffineExpr : public BaseTy { static MlirAffineExpr castFrom(PyAffineExpr &orig) { if (!DerivedTy::isaFunction(orig)) { auto origRepr = py::repr(py::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, - Twine("Cannot cast affine expression to ") + - DerivedTy::pyClassName + " (from " + origRepr + ")"); + throw py::value_error((Twine("Cannot cast affine expression to ") + + DerivedTy::pyClassName + " (from " + origRepr + + ")") + .str()); } return orig; } diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 22001957f..0ab47cc24 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// -#include #include +#include #include "IRModule.h" @@ 
-666,14 +666,14 @@ class PyDenseElementsAttribute !mlirAttributeIsAFloat(elementAttr)) { std::string message = "Illegal element type for DenseElementsAttr: "; message.append(py::repr(py::cast(elementAttr))); - throw SetPyError(PyExc_ValueError, message); + throw py::value_error(message); } if (!mlirTypeIsAShaped(shapedType) || !mlirShapedTypeHasStaticShape(shapedType)) { std::string message = "Expected a static ShapedType for the shaped_type parameter: "; message.append(py::repr(py::cast(shapedType))); - throw SetPyError(PyExc_ValueError, message); + throw py::value_error(message); } MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); MlirType attrType = mlirAttributeGetType(elementAttr); @@ -683,7 +683,7 @@ class PyDenseElementsAttribute message.append(py::repr(py::cast(shapedType))); message.append(", element="); message.append(py::repr(py::cast(elementAttr))); - throw SetPyError(PyExc_ValueError, message); + throw py::value_error(message); } MlirAttribute elements = @@ -783,8 +783,7 @@ class PyDenseElementsAttribute .def("get_splat_value", [](PyDenseElementsAttribute &self) -> PyAttribute { if (!mlirDenseElementsAttrIsSplat(self)) { - throw SetPyError( - PyExc_ValueError, + throw py::value_error( "get_splat_value called on a non-splat attribute"); } return PyAttribute(self.getContext(), @@ -861,8 +860,7 @@ class PyDenseIntElementsAttribute /// out of range. 
py::int_ dunderGetItem(intptr_t pos) { if (pos < 0 || pos >= dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds element"); + throw py::index_error("attempt to access out of bounds element"); } MlirType type = mlirAttributeGetType(*this); @@ -909,7 +907,7 @@ class PyDenseIntElementsAttribute return mlirDenseElementsAttrGetInt64Value(*this, pos); } } - throw SetPyError(PyExc_TypeError, "Unsupported integer type"); + throw py::type_error("Unsupported integer type"); } static void bindDerived(ClassTy &c) { @@ -957,15 +955,13 @@ class PyDictAttribute : public PyConcreteAttribute { MlirAttribute attr = mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); if (mlirAttributeIsNull(attr)) { - throw SetPyError(PyExc_KeyError, - "attempt to access a non-existent attribute"); + throw py::key_error("attempt to access a non-existent attribute"); } return PyAttribute(self.getContext(), attr); }); c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { if (index < 0 || index >= self.dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds attribute"); + throw py::index_error("attempt to access out of bounds attribute"); } MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); return PyNamedAttribute( @@ -987,8 +983,7 @@ class PyDenseFPElementsAttribute py::float_ dunderGetItem(intptr_t pos) { if (pos < 0 || pos >= dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds element"); + throw py::index_error("attempt to access out of bounds element"); } MlirType type = mlirAttributeGetType(*this); @@ -1004,7 +999,7 @@ class PyDenseFPElementsAttribute if (mlirTypeIsAF64(type)) { return mlirDenseElementsAttrGetDoubleValue(*this, pos); } - throw SetPyError(PyExc_TypeError, "Unsupported floating-point type"); + throw py::type_error("Unsupported floating-point type"); } static void bindDerived(ClassTy &c) { diff --git a/mlir/lib/Bindings/Python/IRCore.cpp 
b/mlir/lib/Bindings/Python/IRCore.cpp index db8390abe..27c30683d 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -278,8 +278,7 @@ class PyRegionList { PyRegion dunderGetItem(intptr_t index) { // dunderLen checks validity. if (index < 0 || index >= dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds region"); + throw py::index_error("attempt to access out of bounds region"); } MlirRegion region = mlirOperationGetRegion(operation->get(), index); return PyRegion(operation, region); @@ -351,8 +350,7 @@ class PyBlockList { PyBlock dunderGetItem(intptr_t index) { operation->checkValid(); if (index < 0) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds block"); + throw py::index_error("attempt to access out of bounds block"); } MlirBlock block = mlirRegionGetFirstBlock(region); while (!mlirBlockIsNull(block)) { @@ -362,7 +360,7 @@ class PyBlockList { block = mlirBlockGetNextInRegion(block); index -= 1; } - throw SetPyError(PyExc_IndexError, "attempt to access out of bounds block"); + throw py::index_error("attempt to access out of bounds block"); } PyBlock appendBlock(const py::args &pyArgTypes) { @@ -456,8 +454,7 @@ class PyOperationList { py::object dunderGetItem(intptr_t index) { parentOperation->checkValid(); if (index < 0) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds operation"); + throw py::index_error("attempt to access out of bounds operation"); } MlirOperation childOp = mlirBlockGetFirstOperation(block); while (!mlirOperationIsNull(childOp)) { @@ -468,8 +465,7 @@ class PyOperationList { childOp = mlirOperationGetNextInBlock(childOp); index -= 1; } - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds operation"); + throw py::index_error("attempt to access out of bounds operation"); } static void bind(py::module &m) { @@ -684,8 +680,7 @@ MlirLogicalResult PyMlirContext::ErrorCapture::handler(MlirDiagnostic diag, 
PyMlirContext &DefaultingPyMlirContext::resolve() { PyMlirContext *context = PyThreadContextEntry::getDefaultContext(); if (!context) { - throw SetPyError( - PyExc_RuntimeError, + throw std::runtime_error( "An MLIR function requires a Context but none was provided in the call " "or from the surrounding environment. Either pass to the function with " "a 'context=' argument or establish a default using 'with Context():'"); @@ -775,10 +770,10 @@ py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { void PyThreadContextEntry::popContext(PyMlirContext &context) { auto &stack = getStack(); if (stack.empty()) - throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); + throw std::runtime_error("Unbalanced Context enter/exit"); auto &tos = stack.back(); if (tos.frameKind != FrameKind::Context && tos.getContext() != &context) - throw SetPyError(PyExc_RuntimeError, "Unbalanced Context enter/exit"); + throw std::runtime_error("Unbalanced Context enter/exit"); stack.pop_back(); } @@ -797,13 +792,11 @@ PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { auto &stack = getStack(); if (stack.empty()) - throw SetPyError(PyExc_RuntimeError, - "Unbalanced InsertionPoint enter/exit"); + throw std::runtime_error("Unbalanced InsertionPoint enter/exit"); auto &tos = stack.back(); if (tos.frameKind != FrameKind::InsertionPoint && tos.getInsertionPoint() != &insertionPoint) - throw SetPyError(PyExc_RuntimeError, - "Unbalanced InsertionPoint enter/exit"); + throw std::runtime_error("Unbalanced InsertionPoint enter/exit"); stack.pop_back(); } @@ -819,10 +812,10 @@ py::object PyThreadContextEntry::pushLocation(PyLocation &location) { void PyThreadContextEntry::popLocation(PyLocation &location) { auto &stack = getStack(); if (stack.empty()) - throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); + throw std::runtime_error("Unbalanced 
Location enter/exit"); auto &tos = stack.back(); if (tos.frameKind != FrameKind::Location && tos.getLocation() != &location) - throw SetPyError(PyExc_RuntimeError, "Unbalanced Location enter/exit"); + throw std::runtime_error("Unbalanced Location enter/exit"); stack.pop_back(); } @@ -913,8 +906,11 @@ MlirDialect PyDialects::getDialectForKey(const std::string &key, MlirDialect dialect = mlirContextGetOrLoadDialect(getContext()->get(), {key.data(), key.size()}); if (mlirDialectIsNull(dialect)) { - throw SetPyError(attrError ? PyExc_AttributeError : PyExc_IndexError, - Twine("Dialect '") + key + "' not found"); + std::string msg = (Twine("Dialect '") + key + "' not found").str(); + if (attrError) + throw py::attribute_error(msg); + else + throw py::index_error(msg); } return dialect; } @@ -961,8 +957,7 @@ void PyLocation::contextExit(const pybind11::object &excType, PyLocation &DefaultingPyLocation::resolve() { auto *location = PyThreadContextEntry::getDefaultLocation(); if (!location) { - throw SetPyError( - PyExc_RuntimeError, + throw std::runtime_error( "An MLIR function requires a Location but none was provided in the " "call or from the surrounding environment. 
Either pass to the function " "with a 'loc=' argument or establish a default using 'with loc:'"); @@ -1107,7 +1102,7 @@ PyOperationRef PyOperation::parse(PyMlirContextRef contextRef, void PyOperation::checkValid() const { if (!valid) { - throw SetPyError(PyExc_RuntimeError, "the operation has been invalidated"); + throw std::runtime_error("the operation has been invalidated"); } } @@ -1211,7 +1206,7 @@ bool PyOperationBase::verify() { std::optional PyOperation::getParentOperation() { checkValid(); if (!isAttached()) - throw SetPyError(PyExc_ValueError, "Detached operations have no parent"); + throw py::value_error("Detached operations have no parent"); MlirOperation operation = mlirOperationGetParentOperation(get()); if (mlirOperationIsNull(operation)) return {}; @@ -1270,14 +1265,14 @@ py::object PyOperation::create(const std::string &name, // General parameter validation. if (regions < 0) - throw SetPyError(PyExc_ValueError, "number of regions must be >= 0"); + throw py::value_error("number of regions must be >= 0"); // Unpack/validate operands. if (operands) { mlirOperands.reserve(operands->size()); for (PyValue *operand : *operands) { if (!operand) - throw SetPyError(PyExc_ValueError, "operand value cannot be None"); + throw py::value_error("operand value cannot be None"); mlirOperands.push_back(operand->get()); } } @@ -1288,7 +1283,7 @@ py::object PyOperation::create(const std::string &name, for (PyType *result : *results) { // TODO: Verify result type originate from the same context. if (!result) - throw SetPyError(PyExc_ValueError, "result type cannot be None"); + throw py::value_error("result type cannot be None"); mlirResults.push_back(*result); } } @@ -1329,7 +1324,7 @@ py::object PyOperation::create(const std::string &name, for (auto *successor : *successors) { // TODO: Verify successor originate from the same context. 
if (!successor) - throw SetPyError(PyExc_ValueError, "successor block cannot be None"); + throw py::value_error("successor block cannot be None"); mlirSuccessors.push_back(successor->get()); } } @@ -1701,8 +1696,8 @@ PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) void PyInsertionPoint::insert(PyOperationBase &operationBase) { PyOperation &operation = operationBase.getOperation(); if (operation.isAttached()) - throw SetPyError(PyExc_ValueError, - "Attempt to insert operation that is already attached"); + throw py::value_error( + "Attempt to insert operation that is already attached"); block.getParentOperation()->checkValid(); MlirOperation beforeOp = {nullptr}; if (refOperation) { @@ -1740,7 +1735,7 @@ PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { MlirOperation terminator = mlirBlockGetTerminator(block.get()); if (mlirOperationIsNull(terminator)) - throw SetPyError(PyExc_ValueError, "Block has no terminator"); + throw py::value_error("Block has no terminator"); PyOperationRef terminatorOpRef = PyOperation::forOperation( block.getParentOperation()->getContext(), terminator); return PyInsertionPoint{block, std::move(terminatorOpRef)}; @@ -2033,9 +2028,10 @@ class PyConcreteValue : public PyValue { static MlirValue castFrom(PyValue &orig) { if (!DerivedTy::isaFunction(orig.get())) { auto origRepr = py::repr(py::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, Twine("Cannot cast value to ") + - DerivedTy::pyClassName + - " (from " + origRepr + ")"); + throw py::value_error((Twine("Cannot cast value to ") + + DerivedTy::pyClassName + " (from " + origRepr + + ")") + .str()); } return orig.get(); } @@ -2273,16 +2269,14 @@ class PyOpAttributeMap { MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name)); if (mlirAttributeIsNull(attr)) { - throw SetPyError(PyExc_KeyError, - "attempt to access a non-existent 
attribute"); + throw py::key_error("attempt to access a non-existent attribute"); } return PyAttribute(operation->getContext(), attr); } PyNamedAttribute dunderGetItemIndexed(intptr_t index) { if (index < 0 || index >= dunderLen()) { - throw SetPyError(PyExc_IndexError, - "attempt to access out of bounds attribute"); + throw py::index_error("attempt to access out of bounds attribute"); } MlirNamedAttribute namedAttr = mlirOperationGetAttribute(operation->get(), index); @@ -2301,8 +2295,7 @@ class PyOpAttributeMap { int removed = mlirOperationRemoveAttributeByName(operation->get(), toMlirStringRef(name)); if (!removed) - throw SetPyError(PyExc_KeyError, - "attempt to delete a non-existent attribute"); + throw py::key_error("attempt to delete a non-existent attribute"); } intptr_t dunderLen() { @@ -2402,7 +2395,7 @@ void mlir::python::populateIRCore(py::module &m) { [](py::object & /*class*/) { auto *context = PyThreadContextEntry::getDefaultContext(); if (!context) - throw SetPyError(PyExc_ValueError, "No current Context"); + throw py::value_error("No current Context"); return context; }, "Gets the Context bound to the current thread or raises ValueError") @@ -2419,8 +2412,8 @@ void mlir::python::populateIRCore(py::module &m) { MlirDialect dialect = mlirContextGetOrLoadDialect( self.get(), {name.data(), name.size()}); if (mlirDialectIsNull(dialect)) { - throw SetPyError(PyExc_ValueError, - Twine("Dialect '") + name + "' not found"); + throw py::value_error( + (Twine("Dialect '") + name + "' not found").str()); } return PyDialectDescriptor(self.getRef(), dialect); }, @@ -2545,7 +2538,7 @@ void mlir::python::populateIRCore(py::module &m) { [](py::object & /*class*/) { auto *loc = PyThreadContextEntry::getDefaultLocation(); if (!loc) - throw SetPyError(PyExc_ValueError, "No current Location"); + throw py::value_error("No current Location"); return loc; }, "Gets the Location bound to the current thread or raises ValueError") @@ -2752,13 +2745,13 @@ void 
mlir::python::populateIRCore(py::module &m) { auto numResults = mlirOperationGetNumResults(operation); if (numResults != 1) { auto name = mlirIdentifierStr(mlirOperationGetName(operation)); - throw SetPyError( - PyExc_ValueError, - Twine("Cannot call .result on operation ") + - StringRef(name.data, name.length) + " which has " + - Twine(numResults) + - " results (it is only valid for operations with a " - "single result)"); + throw py::value_error( + (Twine("Cannot call .result on operation ") + + StringRef(name.data, name.length) + " which has " + + Twine(numResults) + + " results (it is only valid for operations with a " + "single result)") + .str()); } return PyOpResult(operation.getRef(), mlirOperationGetResult(operation, 0)); @@ -3119,7 +3112,7 @@ void mlir::python::populateIRCore(py::module &m) { [](py::object & /*class*/) { auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); if (!ip) - throw SetPyError(PyExc_ValueError, "No current InsertionPoint"); + throw py::value_error("No current InsertionPoint"); return ip; }, "Gets the InsertionPoint bound to the current thread or raises " diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index 7221442e4..7c49f20f1 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -10,8 +10,8 @@ #include "Globals.h" #include "PybindUtils.h" -#include #include +#include #include "mlir-c/Bindings/Python/Interop.h" @@ -76,9 +76,9 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, py::object pyClass) { py::object &found = dialectClassMap[dialectNamespace]; if (found) { - throw SetPyError(PyExc_RuntimeError, llvm::Twine("Dialect namespace '") + - dialectNamespace + - "' is already registered."); + throw std::runtime_error((llvm::Twine("Dialect namespace '") + + dialectNamespace + "' is already registered.") + .str()); } found = std::move(pyClass); } @@ -87,9 +87,9 @@ void PyGlobals::registerOperationImpl(const 
std::string &operationName, py::object pyClass) { py::object &found = operationClassMap[operationName]; if (found) { - throw SetPyError(PyExc_RuntimeError, llvm::Twine("Operation '") + - operationName + - "' is already registered."); + throw std::runtime_error((llvm::Twine("Operation '") + operationName + + "' is already registered.") + .str()); } found = std::move(pyClass); } diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index fa529c434..cfa3737cf 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -877,9 +877,10 @@ class PyConcreteType : public BaseTy { static MlirType castFrom(PyType &orig) { if (!DerivedTy::isaFunction(orig)) { auto origRepr = pybind11::repr(pybind11::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, llvm::Twine("Cannot cast type to ") + - DerivedTy::pyClassName + - " (from " + origRepr + ")"); + throw py::value_error((llvm::Twine("Cannot cast type to ") + + DerivedTy::pyClassName + " (from " + origRepr + + ")") + .str()); } return orig; } @@ -898,9 +899,8 @@ class PyConcreteType : public BaseTy { "static_typeid", [](py::object & /*class*/) -> MlirTypeID { if (DerivedTy::getTypeIdFunction) return DerivedTy::getTypeIdFunction(); - throw SetPyError(PyExc_AttributeError, - DerivedTy::pyClassName + - llvm::Twine(" has no typeid.")); + throw py::attribute_error( + (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str()); }); cls.def_property_readonly("typeid", [](PyType &self) { return py::cast(self).attr("typeid").cast(); @@ -990,9 +990,10 @@ class PyConcreteAttribute : public BaseTy { static MlirAttribute castFrom(PyAttribute &orig) { if (!DerivedTy::isaFunction(orig)) { auto origRepr = pybind11::repr(pybind11::cast(orig)).cast(); - throw SetPyError(PyExc_ValueError, - llvm::Twine("Cannot cast attribute to ") + - DerivedTy::pyClassName + " (from " + origRepr + ")"); + throw py::value_error((llvm::Twine("Cannot cast attribute to ") + + 
DerivedTy::pyClassName + " (from " + origRepr + + ")") + .str()); } return orig; } diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index f45b30250..5c089b2f2 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -325,11 +325,11 @@ class PyComplexType : public PyConcreteType { MlirType t = mlirComplexTypeGet(elementType); return PyComplexType(elementType.getContext(), t); } - throw SetPyError( - PyExc_ValueError, - Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + - "' and expected floating point or integer type."); + throw py::value_error( + (Twine("invalid '") + + py::repr(py::cast(elementType)).cast() + + "' and expected floating point or integer type.") + .str()); }, "Create a complex type"); c.def_property_readonly( @@ -432,8 +432,7 @@ class PyShapedType : public PyConcreteType { private: void requireHasRank() { if (!mlirShapedTypeHasRank(*this)) { - throw SetPyError( - PyExc_ValueError, + throw py::value_error( "calling this method requires that the type has a rank."); } } diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 79c53084e..cdbfcfbc2 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -93,7 +93,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { mlirStringRefCreate(pipeline.data(), pipeline.size()), errorMsg.getCallback(), errorMsg.getUserData()); if (mlirLogicalResultIsFailure(status)) - throw SetPyError(PyExc_ValueError, std::string(errorMsg.join())); + throw py::value_error(std::string(errorMsg.join())); return new PyPassManager(passManager); }, py::arg("pipeline"), py::arg("context") = py::none(), @@ -109,7 +109,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { mlirStringRefCreate(pipeline.data(), pipeline.size()), errorMsg.getCallback(), errorMsg.getUserData()); if (mlirLogicalResultIsFailure(status)) - throw SetPyError(PyExc_ValueError, 
std::string(errorMsg.join())); + throw py::value_error(std::string(errorMsg.join())); }, py::arg("pipeline"), "Add textual pipeline elements to the pass manager. Throws a " diff --git a/mlir/lib/Bindings/Python/PybindUtils.cpp b/mlir/lib/Bindings/Python/PybindUtils.cpp deleted file mode 100644 index d243307f1..000000000 --- a/mlir/lib/Bindings/Python/PybindUtils.cpp +++ /dev/null @@ -1,16 +0,0 @@ -//===- PybindUtils.cpp - Utilities for interop with pybind11 --------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "PybindUtils.h" - -pybind11::error_already_set -mlir::python::SetPyError(PyObject *excClass, const llvm::Twine &message) { - auto messageStr = message.str(); - PyErr_SetString(excClass, messageStr.c_str()); - return pybind11::error_already_set(); -} diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h index 2d8bbc14c..41de7e9b4 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -19,13 +19,6 @@ namespace mlir { namespace python { -// Sets a python error, ready to be thrown to return control back to the -// python runtime. -// Correct usage: -// throw SetPyError(PyExc_ValueError, "Foobar'd"); -pybind11::error_already_set SetPyError(PyObject *excClass, - const llvm::Twine &message); - /// CRTP template for special wrapper types that are allowed to be passed in as /// 'None' function arguments and can be resolved by some global mechanic if /// so. 
Such types will raise an error if this global resolution fails, and diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 0a4c2f803..b0b4ed94f 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -263,7 +263,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Core IRInterfaces.cpp IRModule.cpp IRTypes.cpp - PybindUtils.cpp Pass.cpp # Headers must be included explicitly so they are installed. From 4f2e6126ee2e96501e9b09f4da72161ca58403fc Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Tue, 23 May 2023 13:40:00 -0400 Subject: [PATCH 467/915] [mlir][python] Bump min pybind11 version to 2.9.0 2.9.0 was released on December 28, 2021, and some following changes require at least this version. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D150247 --- mlir/python/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index 16a3dd39c..7671d3329 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,4 +1,4 @@ numpy>=1.19.5, <=1.23.5 -pybind11>=2.8.0, <=2.10.3 +pybind11>=2.9.0, <=2.10.3 PyYAML>= 5.3.1, <=6.0 dataclasses>=0.6, <=0.8 From e04fcff27e39084665ef5ed77c654cfdd4d22eef Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 22 May 2023 14:36:58 +0000 Subject: [PATCH 468/915] [mlir] move PDL-related transform ops into an extension The initial bring-up of the Transform dialect relied on PDL to provide the default handle type (`!pdl.operation`) and the matching capability. Both are now provided natively by the Transform dialect removing the reason to have a hard dependency on the PDL dialect and its interpreter. Move PDL-related transform operations into a separate extension. 
This requires us to introduce a dialect state extension mechanism into the Transform dialect so it no longer needs to know about PDL constraint functions that may be injected by extensions similarly to operations and types. This mechanism will be reused to connect pattern application drivers and the Transform dialect. This completes the restructuring of the Transform dialect to remove overreliance on PDL. Note to downstreams: flows that are using `!pdl.operation` with Transform dialect operations will now require `transform::PDLExtension` to be applied to the transform dialect in order to provide the transform handle type interface for `!pdl.operation`. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D151104 --- mlir/python/CMakeLists.txt | 10 ++++ .../mlir/dialects/TransformPDLExtensionOps.td | 20 +++++++ .../mlir/dialects/_transform_ops_ext.py | 42 -------------- .../_transform_pdl_extension_ops_ext.py | 55 +++++++++++++++++++ mlir/python/mlir/dialects/transform/pdl.py | 5 ++ 5 files changed, 90 insertions(+), 42 deletions(-) create mode 100644 mlir/python/mlir/dialects/TransformPDLExtensionOps.td create mode 100644 mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py create mode 100644 mlir/python/mlir/dialects/transform/pdl.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index b0b4ed94f..39dd7b006 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -114,6 +114,16 @@ declare_mlir_dialect_python_bindings( DIALECT_NAME linalg DEPENDS LinalgOdsGen) +declare_mlir_dialect_extension_python_bindings( +ADD_TO_PARENT MLIRPythonSources.Dialects +ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/TransformPDLExtensionOps.td + SOURCES + dialects/_transform_pdl_extension_ops_ext.py + dialects/transform/pdl.py + DIALECT_NAME transform + EXTENSION_NAME transform_pdl_extension) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR
"${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/TransformPDLExtensionOps.td b/mlir/python/mlir/dialects/TransformPDLExtensionOps.td new file mode 100644 index 000000000..e3e5daf18 --- /dev/null +++ b/mlir/python/mlir/dialects/TransformPDLExtensionOps.td @@ -0,0 +1,20 @@ +//===-- TransformPDLExtensionOps.td - Binding entry point --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the generated Python bindings for the PDL extension of the +// Transform dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS +#define PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td" + +#endif // PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py index 8651c76ea..cc4428ea5 100644 --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -60,26 +60,6 @@ def __init__( ) -class PDLMatchOp: - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - pattern_name: Union[Attribute, str], - *, - loc=None, - ip=None, - ): - super().__init__( - result_type, - _get_op_result_or_value(target), - pattern_name, - loc=loc, - ip=ip, - ) - - class ReplicateOp: def __init__( @@ -152,28 +132,6 @@ def bodyExtraArgs(self) -> BlockArgumentList: return self.body.arguments[1:] -class WithPDLPatternsOp: - - def __init__(self, - target: Union[Operation, Value, Type], - *, - loc=None, - ip=None): - root = 
_get_op_result_or_value(target) if not isinstance(target, - Type) else None - root_type = target if isinstance(target, Type) else root.type - super().__init__(root=root, loc=loc, ip=ip) - self.regions[0].blocks.append(root_type) - - @property - def body(self) -> Block: - return self.regions[0].blocks[0] - - @property - def bodyTarget(self) -> Value: - return self.body.arguments[0] - - class YieldOp: def __init__( diff --git a/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py b/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py new file mode 100644 index 000000000..c4e4b4b42 --- /dev/null +++ b/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py @@ -0,0 +1,55 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * + from ._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Union + +class PDLMatchOp: + + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + pattern_name: Union[Attribute, str], + *, + loc=None, + ip=None, + ): + super().__init__( + result_type, + _get_op_result_or_value(target), + pattern_name, + loc=loc, + ip=ip, + ) + + +class WithPDLPatternsOp: + + def __init__(self, + target: Union[Operation, Value, Type], + *, + loc=None, + ip=None): + root = _get_op_result_or_value(target) if not isinstance(target, + Type) else None + root_type = target if isinstance(target, Type) else root.type + super().__init__(root=root, loc=loc, ip=ip) + self.regions[0].blocks.append(root_type) + + @property + def body(self) -> Block: + return self.regions[0].blocks[0] + + @property + def bodyTarget(self) -> Value: + return 
self.body.arguments[0] diff --git a/mlir/python/mlir/dialects/transform/pdl.py b/mlir/python/mlir/dialects/transform/pdl.py new file mode 100644 index 000000000..b1515287a --- /dev/null +++ b/mlir/python/mlir/dialects/transform/pdl.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._transform_pdl_extension_ops_gen import * From c2fd8f65078c0f5094966388c16cdb171c579a31 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Wed, 24 May 2023 17:48:54 -0500 Subject: [PATCH 469/915] [mlir][capi] Drop mlirShapedTypeGetTypeID ShapedType is a virtual type rather than a concrete one. We don't have an implementation for this API either. Reviewed By: makslevental Differential Revision: https://reviews.llvm.org/D151376 --- mlir/include/mlir-c/BuiltinTypes.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index 4348c5ba1..c8ea44cd9 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -198,9 +198,6 @@ MLIR_CAPI_EXPORTED MlirType mlirComplexTypeGetElementType(MlirType type); // Shaped type. //===----------------------------------------------------------------------===// -/// Returns the typeID of an Shaped type. -MLIR_CAPI_EXPORTED MlirTypeID mlirShapedTypeGetTypeID(void); - /// Checks whether the given type is a Shaped type. MLIR_CAPI_EXPORTED bool mlirTypeIsAShaped(MlirType type); From 84d968b1d15683a41d16551caeb823c844346636 Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Wed, 24 May 2023 21:51:36 -0400 Subject: [PATCH 470/915] [mlir][python] Allow specifying block arg locations Currently blocks are always created with UnknownLoc's for their arguments. This adds an `arg_locs` argument to all block creation APIs, which takes an optional sequence of locations to use, one per block argument.
If no locations are supplied, the current Location context is used. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D150084 --- mlir/lib/Bindings/Python/IRCore.cpp | 105 +++++++++------------ mlir/python/mlir/dialects/_func_ops_ext.py | 4 +- 2 files changed, 47 insertions(+), 62 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 27c30683d..7013cca53 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -194,6 +194,31 @@ static MlirStringRef toMlirStringRef(const std::string &s) { return mlirStringRefCreate(s.data(), s.size()); } +/// Create a block, using the current location context if no locations are +/// specified. +static MlirBlock createBlock(const py::sequence &pyArgTypes, + const std::optional &pyArgLocs) { + SmallVector argTypes; + argTypes.reserve(pyArgTypes.size()); + for (const auto &pyType : pyArgTypes) + argTypes.push_back(pyType.cast()); + + SmallVector argLocs; + if (pyArgLocs) { + argLocs.reserve(pyArgLocs->size()); + for (const auto &pyLoc : *pyArgLocs) + argLocs.push_back(pyLoc.cast()); + } else if (!argTypes.empty()) { + argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve()); + } + + if (argTypes.size() != argLocs.size()) + throw py::value_error(("Expected " + Twine(argTypes.size()) + + " locations, got: " + Twine(argLocs.size())) + .str()); + return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); +} + /// Wrapper for the global LLVM debugging flag. 
struct PyGlobalDebugFlag { static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } @@ -363,21 +388,10 @@ class PyBlockList { throw py::index_error("attempt to access out of bounds block"); } - PyBlock appendBlock(const py::args &pyArgTypes) { + PyBlock appendBlock(const py::args &pyArgTypes, + const std::optional &pyArgLocs) { operation->checkValid(); - llvm::SmallVector argTypes; - llvm::SmallVector argLocs; - argTypes.reserve(pyArgTypes.size()); - argLocs.reserve(pyArgTypes.size()); - for (auto &pyArg : pyArgTypes) { - argTypes.push_back(pyArg.cast()); - // TODO: Pass in a proper location here. - argLocs.push_back( - mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); - } - - MlirBlock block = - mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); + MlirBlock block = createBlock(pyArgTypes, pyArgLocs); mlirRegionAppendOwnedBlock(region, block); return PyBlock(operation, block); } @@ -387,7 +401,8 @@ class PyBlockList { .def("__getitem__", &PyBlockList::dunderGetItem) .def("__iter__", &PyBlockList::dunderIter) .def("__len__", &PyBlockList::dunderLen) - .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring); + .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring, + py::arg("arg_locs") = std::nullopt); } private: @@ -2978,27 +2993,17 @@ void mlir::python::populateIRCore(py::module &m) { "Returns a forward-optimized sequence of operations.") .def_static( "create_at_start", - [](PyRegion &parent, py::list pyArgTypes) { + [](PyRegion &parent, const py::list &pyArgTypes, + const std::optional &pyArgLocs) { parent.checkValid(); - llvm::SmallVector argTypes; - llvm::SmallVector argLocs; - argTypes.reserve(pyArgTypes.size()); - argLocs.reserve(pyArgTypes.size()); - for (auto &pyArg : pyArgTypes) { - argTypes.push_back(pyArg.cast()); - // TODO: Pass in a proper location here. 
- argLocs.push_back( - mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); - } - - MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), - argLocs.data()); + MlirBlock block = createBlock(pyArgTypes, pyArgLocs); mlirRegionInsertOwnedBlock(parent, 0, block); return PyBlock(parent.getParentOperation(), block); }, py::arg("parent"), py::arg("arg_types") = py::list(), + py::arg("arg_locs") = std::nullopt, "Creates and returns a new Block at the beginning of the given " - "region (with given argument types).") + "region (with given argument types and locations).") .def( "append_to", [](PyBlock &self, PyRegion ®ion) { @@ -3010,50 +3015,30 @@ void mlir::python::populateIRCore(py::module &m) { "Append this block to a region, transferring ownership if necessary") .def( "create_before", - [](PyBlock &self, py::args pyArgTypes) { + [](PyBlock &self, const py::args &pyArgTypes, + const std::optional &pyArgLocs) { self.checkValid(); - llvm::SmallVector argTypes; - llvm::SmallVector argLocs; - argTypes.reserve(pyArgTypes.size()); - argLocs.reserve(pyArgTypes.size()); - for (auto &pyArg : pyArgTypes) { - argTypes.push_back(pyArg.cast()); - // TODO: Pass in a proper location here. 
- argLocs.push_back( - mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); - } - - MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), - argLocs.data()); + MlirBlock block = createBlock(pyArgTypes, pyArgLocs); MlirRegion region = mlirBlockGetParentRegion(self.get()); mlirRegionInsertOwnedBlockBefore(region, self.get(), block); return PyBlock(self.getParentOperation(), block); }, + py::arg("arg_locs") = std::nullopt, "Creates and returns a new Block before this block " - "(with given argument types).") + "(with given argument types and locations).") .def( "create_after", - [](PyBlock &self, py::args pyArgTypes) { + [](PyBlock &self, const py::args &pyArgTypes, + const std::optional &pyArgLocs) { self.checkValid(); - llvm::SmallVector argTypes; - llvm::SmallVector argLocs; - argTypes.reserve(pyArgTypes.size()); - argLocs.reserve(pyArgTypes.size()); - for (auto &pyArg : pyArgTypes) { - argTypes.push_back(pyArg.cast()); - - // TODO: Pass in a proper location here. - argLocs.push_back( - mlirLocationUnknownGet(mlirTypeGetContext(argTypes.back()))); - } - MlirBlock block = mlirBlockCreate(argTypes.size(), argTypes.data(), - argLocs.data()); + MlirBlock block = createBlock(pyArgTypes, pyArgLocs); MlirRegion region = mlirBlockGetParentRegion(self.get()); mlirRegionInsertOwnedBlockAfter(region, self.get(), block); return PyBlock(self.getParentOperation(), block); }, + py::arg("arg_locs") = std::nullopt, "Creates and returns a new Block after this block " - "(with given argument types).") + "(with given argument types and locations).") .def( "__iter__", [](PyBlock &self) { diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py index 79577463d..56df423d3 100644 --- a/mlir/python/mlir/dialects/_func_ops_ext.py +++ b/mlir/python/mlir/dialects/_func_ops_ext.py @@ -90,7 +90,7 @@ def entry_block(self): raise IndexError('External function does not have a body') return self.regions[0].blocks[0] - def 
add_entry_block(self): + def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None): """ Add an entry block to the function body using the function signature to infer block arguments. @@ -98,7 +98,7 @@ def add_entry_block(self): """ if not self.is_external: raise IndexError('The function already has an entry block!') - self.body.blocks.append(*self.type.inputs) + self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs) return self.body.blocks[0] @property From 48783db0bb979d6ad42a9cb2c0ad517e1e8ba7b3 Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Wed, 24 May 2023 22:05:06 -0400 Subject: [PATCH 471/915] [mlir][python] Hook up PyRegionList.__iter__ to PyRegionIterator This fixes a -Wunused-member-function warning, at the moment `PyRegionIterator` is never constructed by anything (the only use was removed in D111697), and iterating over region lists is just falling back to a generic python iterator object. Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D150244 --- mlir/lib/Bindings/Python/IRCore.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 7013cca53..a6bd4d849 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -295,6 +295,11 @@ class PyRegionList { public: PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} + PyRegionIterator dunderIter() { + operation->checkValid(); + return PyRegionIterator(operation); + } + intptr_t dunderLen() { operation->checkValid(); return mlirOperationGetNumRegions(operation->get()); @@ -312,6 +317,7 @@ class PyRegionList { static void bind(py::module &m) { py::class_(m, "RegionSequence", py::module_local()) .def("__len__", &PyRegionList::dunderLen) + .def("__iter__", &PyRegionList::dunderIter) .def("__getitem__", &PyRegionList::dunderGetItem); } From 35e2cbfab12d0ac577a2ce20ab9937dc81beaeed Mon Sep 17 00:00:00 2001 From: Tobias 
Hieta Date: Wed, 17 May 2023 16:53:39 +0200 Subject: [PATCH 472/915] [NFC][Py Reformat] Reformat python files in mlir subdir This is an ongoing series of commits that are reformatting our Python code. Reformatting is done with `black`. If you end up having problems merging this commit because you have made changes to a python file, the best way to handle that is to run git checkout --ours and then reformat it with black. If you run into any problems, post to discourse about it and we will try to help. RFC Thread below: https://discourse.llvm.org/t/rfc-document-and-standardize-python-code-style Differential Revision: https://reviews.llvm.org/D150782 --- mlir/python/mlir/_mlir_libs/__init__.py | 189 +- mlir/python/mlir/dialects/_arith_ops_ext.py | 99 +- .../mlir/dialects/_bufferization_ops_ext.py | 57 +- mlir/python/mlir/dialects/_builtin_ops_ext.py | 20 +- mlir/python/mlir/dialects/_func_ops_ext.py | 575 ++--- mlir/python/mlir/dialects/_linalg_ops_ext.py | 54 +- .../mlir/dialects/_loop_transform_ops_ext.py | 213 +- mlir/python/mlir/dialects/_memref_ops_ext.py | 46 +- .../mlir/dialects/_ml_program_ops_ext.py | 199 +- mlir/python/mlir/dialects/_ods_common.py | 262 +- mlir/python/mlir/dialects/_pdl_ops_ext.py | 428 ++-- mlir/python/mlir/dialects/_scf_ops_ext.py | 187 +- .../dialects/_structured_transform_ops_ext.py | 545 +++-- mlir/python/mlir/dialects/_tensor_ops_ext.py | 62 +- .../mlir/dialects/_transform_ops_ext.py | 227 +- .../mlir/dialects/linalg/opdsl/dump_oplib.py | 97 +- .../mlir/dialects/linalg/opdsl/lang/affine.py | 373 ++- .../linalg/opdsl/lang/comprehension.py | 1202 ++++----- .../mlir/dialects/linalg/opdsl/lang/config.py | 846 ++++--- .../mlir/dialects/linalg/opdsl/lang/dsl.py | 288 ++- .../dialects/linalg/opdsl/lang/emitter.py | 1078 ++++---- .../dialects/linalg/opdsl/lang/scalar_expr.py | 200 +- .../mlir/dialects/linalg/opdsl/lang/types.py | 44 +- .../dialects/linalg/opdsl/lang/yaml_helper.py | 43 +- .../linalg/opdsl/ops/core_named_ops.py | 2175 
+++++++++-------- mlir/python/mlir/dialects/python_test.py | 6 +- .../mlir/dialects/transform/__init__.py | 18 +- mlir/python/mlir/execution_engine.py | 58 +- mlir/python/mlir/ir.py | 67 +- mlir/python/mlir/runtime/np_to_memref.py | 175 +- 30 files changed, 5091 insertions(+), 4742 deletions(-) diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index 7d3d1f6ca..03fcb1013 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -10,26 +10,26 @@ def get_lib_dirs() -> Sequence[str]: - """Gets the lib directory for linking to shared libraries. + """Gets the lib directory for linking to shared libraries. - On some platforms, the package may need to be built specially to export - development libraries. - """ - return [_this_dir] + On some platforms, the package may need to be built specially to export + development libraries. + """ + return [_this_dir] def get_include_dirs() -> Sequence[str]: - """Gets the include directory for compiling against exported C libraries. + """Gets the include directory for compiling against exported C libraries. - Depending on how the package was build, development C libraries may or may - not be present. - """ - return [os.path.join(_this_dir, "include")] + Depending on how the package was build, development C libraries may or may + not be present. + """ + return [os.path.join(_this_dir, "include")] # Perform Python level site initialization. This involves: # 1. Attempting to load initializer modules, specific to the distribution. -# 2. Defining the concrete mlir.ir.Context that does site specific +# 2. Defining the concrete mlir.ir.Context that does site specific # initialization. # # Aside from just being far more convenient to do this at the Python level, @@ -38,91 +38,106 @@ def get_include_dirs() -> Sequence[str]: # in the scope of the base class __init__). # # For #1, we: -# a. 
Probe for modules named '_mlirRegisterEverything' and -# '_site_initialize_{i}', where 'i' is a number starting at zero and +# a. Probe for modules named '_mlirRegisterEverything' and +# '_site_initialize_{i}', where 'i' is a number starting at zero and # proceeding so long as a module with the name is found. # b. If the module has a 'register_dialects' attribute, it will be called # immediately with a DialectRegistry to populate. # c. If the module has a 'context_init_hook', it will be added to a list -# of callbacks that are invoked as the last step of Context +# of callbacks that are invoked as the last step of Context # initialization (and passed the Context under construction). # # This facility allows downstreams to customize Context creation to their # needs. def _site_initialize(): - import importlib - import itertools - import logging - from ._mlir import ir - logger = logging.getLogger(__name__) - registry = ir.DialectRegistry() - post_init_hooks = [] - - def process_initializer_module(module_name): - try: - m = importlib.import_module(f".{module_name}", __name__) - except ModuleNotFoundError: - return False - except ImportError: - message = (f"Error importing mlir initializer {module_name}. This may " - "happen in unclean incremental builds but is likely a real bug if " - "encountered otherwise and the MLIR Python API may not function.") - logger.warning(message, exc_info=True) - - logger.debug("Initializing MLIR with module: %s", module_name) - if hasattr(m, "register_dialects"): - logger.debug("Registering dialects from initializer %r", m) - m.register_dialects(registry) - if hasattr(m, "context_init_hook"): - logger.debug("Adding context init hook from %r", m) - post_init_hooks.append(m.context_init_hook) - return True - - - # If _mlirRegisterEverything is built, then include it as an initializer - # module. - process_initializer_module("_mlirRegisterEverything") - - # Load all _site_initialize_{i} modules, where 'i' is a number starting - # at 0. 
- for i in itertools.count(): - module_name = f"_site_initialize_{i}" - if not process_initializer_module(module_name): - break - - class Context(ir._BaseContext): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.append_dialect_registry(registry) - for hook in post_init_hooks: - hook(self) - # TODO: There is some debate about whether we should eagerly load - # all dialects. It is being done here in order to preserve existing - # behavior. See: https://github.com/llvm/llvm-project/issues/56037 - self.load_all_available_dialects() - ir.Context = Context - - class MLIRError(Exception): - """ - An exception with diagnostic information. Has the following fields: - message: str - error_diagnostics: List[ir.DiagnosticInfo] - """ - def __init__(self, message, error_diagnostics): - self.message = message - self.error_diagnostics = error_diagnostics - super().__init__(message, error_diagnostics) - - def __str__(self): - s = self.message - if self.error_diagnostics: - s += ':' - for diag in self.error_diagnostics: - s += "\nerror: " + str(diag.location)[4:-1] + ": " + diag.message.replace('\n', '\n ') - for note in diag.notes: - s += "\n note: " + str(note.location)[4:-1] + ": " + note.message.replace('\n', '\n ') - return s - ir.MLIRError = MLIRError + import importlib + import itertools + import logging + from ._mlir import ir + + logger = logging.getLogger(__name__) + registry = ir.DialectRegistry() + post_init_hooks = [] + + def process_initializer_module(module_name): + try: + m = importlib.import_module(f".{module_name}", __name__) + except ModuleNotFoundError: + return False + except ImportError: + message = ( + f"Error importing mlir initializer {module_name}. This may " + "happen in unclean incremental builds but is likely a real bug if " + "encountered otherwise and the MLIR Python API may not function." 
+ ) + logger.warning(message, exc_info=True) + + logger.debug("Initializing MLIR with module: %s", module_name) + if hasattr(m, "register_dialects"): + logger.debug("Registering dialects from initializer %r", m) + m.register_dialects(registry) + if hasattr(m, "context_init_hook"): + logger.debug("Adding context init hook from %r", m) + post_init_hooks.append(m.context_init_hook) + return True + + # If _mlirRegisterEverything is built, then include it as an initializer + # module. + process_initializer_module("_mlirRegisterEverything") + + # Load all _site_initialize_{i} modules, where 'i' is a number starting + # at 0. + for i in itertools.count(): + module_name = f"_site_initialize_{i}" + if not process_initializer_module(module_name): + break + + class Context(ir._BaseContext): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.append_dialect_registry(registry) + for hook in post_init_hooks: + hook(self) + # TODO: There is some debate about whether we should eagerly load + # all dialects. It is being done here in order to preserve existing + # behavior. See: https://github.com/llvm/llvm-project/issues/56037 + self.load_all_available_dialects() + + ir.Context = Context + + class MLIRError(Exception): + """ + An exception with diagnostic information. 
Has the following fields: + message: str + error_diagnostics: List[ir.DiagnosticInfo] + """ + + def __init__(self, message, error_diagnostics): + self.message = message + self.error_diagnostics = error_diagnostics + super().__init__(message, error_diagnostics) + + def __str__(self): + s = self.message + if self.error_diagnostics: + s += ":" + for diag in self.error_diagnostics: + s += ( + "\nerror: " + + str(diag.location)[4:-1] + + ": " + + diag.message.replace("\n", "\n ") + ) + for note in diag.notes: + s += ( + "\n note: " + + str(note.location)[4:-1] + + ": " + + note.message.replace("\n", "\n ") + ) + return s + + ir.MLIRError = MLIRError _site_initialize() diff --git a/mlir/python/mlir/dialects/_arith_ops_ext.py b/mlir/python/mlir/dialects/_arith_ops_ext.py index 240859352..df38f8717 100644 --- a/mlir/python/mlir/dialects/_arith_ops_ext.py +++ b/mlir/python/mlir/dialects/_arith_ops_ext.py @@ -3,72 +3,67 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from ..ir import * - from ._ods_common import get_default_loc_context as _get_default_loc_context + from ..ir import * + from ._ods_common import get_default_loc_context as _get_default_loc_context - from typing import Any, List, Union + from typing import Any, List, Union except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e def _isa(obj: Any, cls: type): - try: - cls(obj) - except ValueError: - return False - return True + try: + cls(obj) + except ValueError: + return False + return True def _is_any_of(obj: Any, classes: List[type]): - return any(_isa(obj, cls) for cls in classes) + return any(_isa(obj, cls) for cls in classes) def _is_integer_like_type(type: Type): - return _is_any_of(type, [IntegerType, IndexType]) + return _is_any_of(type, [IntegerType, IndexType]) def _is_float_type(type: Type): - return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type]) + return 
_is_any_of(type, [BF16Type, F16Type, F32Type, F64Type]) class ConstantOp: - """Specialization for the constant op class.""" - - def __init__(self, - result: Type, - value: Union[int, float, Attribute], - *, - loc=None, - ip=None): - if isinstance(value, int): - super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip) - elif isinstance(value, float): - super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip) - else: - super().__init__(value, loc=loc, ip=ip) - - @classmethod - def create_index(cls, value: int, *, loc=None, ip=None): - """Create an index-typed constant.""" - return cls( - IndexType.get(context=_get_default_loc_context(loc)), - value, - loc=loc, - ip=ip) - - @property - def type(self): - return self.results[0].type - - @property - def value(self): - return Attribute(self.operation.attributes["value"]) - - @property - def literal_value(self) -> Union[int, float]: - if _is_integer_like_type(self.type): - return IntegerAttr(self.value).value - elif _is_float_type(self.type): - return FloatAttr(self.value).value - else: - raise ValueError("only integer and float constants have literal values") + """Specialization for the constant op class.""" + + def __init__( + self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None + ): + if isinstance(value, int): + super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip) + elif isinstance(value, float): + super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip) + else: + super().__init__(value, loc=loc, ip=ip) + + @classmethod + def create_index(cls, value: int, *, loc=None, ip=None): + """Create an index-typed constant.""" + return cls( + IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip + ) + + @property + def type(self): + return self.results[0].type + + @property + def value(self): + return Attribute(self.operation.attributes["value"]) + + @property + def literal_value(self) -> Union[int, float]: + if _is_integer_like_type(self.type): + 
return IntegerAttr(self.value).value + elif _is_float_type(self.type): + return FloatAttr(self.value).value + else: + raise ValueError("only integer and float constants have literal values") diff --git a/mlir/python/mlir/dialects/_bufferization_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_ops_ext.py index 6ed35f444..1066cb4c7 100644 --- a/mlir/python/mlir/dialects/_bufferization_ops_ext.py +++ b/mlir/python/mlir/dialects/_bufferization_ops_ext.py @@ -3,36 +3,39 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from typing import Sequence, Union - from ..ir import * - from ._ods_common import get_default_loc_context + from typing import Sequence, Union + from ..ir import * + from ._ods_common import get_default_loc_context - from typing import Any, List, Union + from typing import Any, List, Union except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e class AllocTensorOp: - """Extends the bufferization.alloc_tensor op.""" + """Extends the bufferization.alloc_tensor op.""" - def __init__(self, - tensor_type: Type, - dynamic_sizes: Sequence[Value], - copy: Value, - size_hint: Value, - escape: BoolAttr, - *, - loc=None, - ip=None): - """Constructs an `alloc_tensor` with static and/or dynamic sizes.""" - context = get_default_loc_context(loc) - attributes = {} - if escape: - attributes["escape"] = escape - op = self.build_generic( - results=[tensor_type], - operands=[dynamic_sizes, copy, size_hint], - attributes=attributes, - loc=loc, - ip=ip) - OpView.__init__(self, op) + def __init__( + self, + tensor_type: Type, + dynamic_sizes: Sequence[Value], + copy: Value, + size_hint: Value, + escape: BoolAttr, + *, + loc=None, + ip=None + ): + """Constructs an `alloc_tensor` with static and/or dynamic sizes.""" + context = get_default_loc_context(loc) + attributes = {} + if escape: + attributes["escape"] = escape + op = 
self.build_generic( + results=[tensor_type], + operands=[dynamic_sizes, copy, size_hint], + attributes=attributes, + loc=loc, + ip=ip, + ) + OpView.__init__(self, op) diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py index b69163fa4..27a601230 100644 --- a/mlir/python/mlir/dialects/_builtin_ops_ext.py +++ b/mlir/python/mlir/dialects/_builtin_ops_ext.py @@ -3,18 +3,18 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from ..ir import * + from ..ir import * except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e + class ModuleOp: - """Specialization for the module op class.""" + """Specialization for the module op class.""" - def __init__(self, *, loc=None, ip=None): - super().__init__(self.build_generic(results=[], operands=[], loc=loc, - ip=ip)) - body = self.regions[0].blocks.append() + def __init__(self, *, loc=None, ip=None): + super().__init__(self.build_generic(results=[], operands=[], loc=loc, ip=ip)) + body = self.regions[0].blocks.append() - @property - def body(self): - return self.regions[0].blocks[0] + @property + def body(self): + return self.regions[0].blocks[0] diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py index 56df423d3..6d264c33f 100644 --- a/mlir/python/mlir/dialects/_func_ops_ext.py +++ b/mlir/python/mlir/dialects/_func_ops_ext.py @@ -3,298 +3,317 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from ..ir import * - from ._ods_common import get_default_loc_context as _get_default_loc_context + from ..ir import * + from ._ods_common import get_default_loc_context as _get_default_loc_context - import inspect + import inspect - from typing import Any, List, Optional, Sequence, Union + from typing import Any, List, Optional, Sequence, Union except ImportError as e: - raise 
RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e ARGUMENT_ATTRIBUTE_NAME = "arg_attrs" RESULT_ATTRIBUTE_NAME = "res_attrs" + class ConstantOp: - """Specialization for the constant op class.""" + """Specialization for the constant op class.""" - def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None): - super().__init__(result, value, loc=loc, ip=ip) + def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None): + super().__init__(result, value, loc=loc, ip=ip) - @property - def type(self): - return self.results[0].type + @property + def type(self): + return self.results[0].type class FuncOp: - """Specialization for the func op class.""" - - def __init__(self, - name, - type, - *, - visibility=None, - body_builder=None, - loc=None, - ip=None): - """ - Create a FuncOp with the provided `name`, `type`, and `visibility`. - - `name` is a string representing the function name. - - `type` is either a FunctionType or a pair of list describing inputs and - results. - - `visibility` is a string matching `public`, `private`, or `nested`. None - implies private visibility. - - `body_builder` is an optional callback, when provided a new entry block - is created and the callback is invoked with the new op as argument within - an InsertionPoint context already set for the block. The callback is - expected to insert a terminator in the block. - """ - sym_name = StringAttr.get(str(name)) - - # If the type is passed as a tuple, build a FunctionType on the fly. 
- if isinstance(type, tuple): - type = FunctionType.get(inputs=type[0], results=type[1]) - - type = TypeAttr.get(type) - sym_visibility = StringAttr.get( - str(visibility)) if visibility is not None else None - super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) - if body_builder: - entry_block = self.add_entry_block() - with InsertionPoint(entry_block): - body_builder(self) - - @property - def is_external(self): - return len(self.regions[0].blocks) == 0 - - @property - def body(self): - return self.regions[0] - - @property - def type(self): - return FunctionType(TypeAttr(self.attributes["function_type"]).value) - - @property - def visibility(self): - return self.attributes["sym_visibility"] - - @property - def name(self) -> StringAttr: - return StringAttr(self.attributes["sym_name"]) - - @property - def entry_block(self): - if self.is_external: - raise IndexError('External function does not have a body') - return self.regions[0].blocks[0] - - def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None): - """ - Add an entry block to the function body using the function signature to - infer block arguments. 
- Returns the newly created block - """ - if not self.is_external: - raise IndexError('The function already has an entry block!') - self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs) - return self.body.blocks[0] - - @property - def arg_attrs(self): - return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) - - @arg_attrs.setter - def arg_attrs(self, attribute: Union[ArrayAttr, list]): - if isinstance(attribute, ArrayAttr): - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute - else: - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( - attribute, context=self.context) - - @property - def arguments(self): - return self.entry_block.arguments - - @property - def result_attrs(self): - return self.attributes[RESULT_ATTRIBUTE_NAME] - - @result_attrs.setter - def result_attrs(self, attribute: ArrayAttr): - self.attributes[RESULT_ATTRIBUTE_NAME] = attribute - - @classmethod - def from_py_func(FuncOp, - *inputs: Type, - results: Optional[Sequence[Type]] = None, - name: Optional[str] = None): - """Decorator to define an MLIR FuncOp specified as a python function. - - Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are - active for the current thread (i.e. established in a `with` block). - - When applied as a decorator to a Python function, an entry block will - be constructed for the FuncOp with types as specified in `*inputs`. The - block arguments will be passed positionally to the Python function. In - addition, if the Python function accepts keyword arguments generally or - has a corresponding keyword argument, the following will be passed: - * `func_op`: The `func` op being defined. - - By default, the function name will be the Python function `__name__`. This - can be overriden by passing the `name` argument to the decorator. - - If `results` is not specified, then the decorator will implicitly - insert a `ReturnOp` with the `Value`'s returned from the decorated - function. 
It will also set the `FuncOp` type with the actual return - value types. If `results` is specified, then the decorated function - must return `None` and no implicit `ReturnOp` is added (nor are the result - types updated). The implicit behavior is intended for simple, single-block - cases, and users should specify result types explicitly for any complicated - cases. - - The decorated function can further be called from Python and will insert - a `CallOp` at the then-current insertion point, returning either None ( - if no return values), a unary Value (for one result), or a list of Values). - This mechanism cannot be used to emit recursive calls (by construction). - """ - - def decorator(f): - from . import func - # Introspect the callable for optional features. - sig = inspect.signature(f) - has_arg_func_op = False - for param in sig.parameters.values(): - if param.kind == param.VAR_KEYWORD: - has_arg_func_op = True - if param.name == "func_op" and (param.kind - == param.POSITIONAL_OR_KEYWORD or - param.kind == param.KEYWORD_ONLY): - has_arg_func_op = True - - # Emit the FuncOp. - implicit_return = results is None - symbol_name = name or f.__name__ - function_type = FunctionType.get( - inputs=inputs, results=[] if implicit_return else results) - func_op = FuncOp(name=symbol_name, type=function_type) - with InsertionPoint(func_op.add_entry_block()): - func_args = func_op.entry_block.arguments - func_kwargs = {} - if has_arg_func_op: - func_kwargs["func_op"] = func_op - return_values = f(*func_args, **func_kwargs) - if not implicit_return: - return_types = list(results) - assert return_values is None, ( - "Capturing a python function with explicit `results=` " - "requires that the wrapped function returns None.") - else: - # Coerce return values, add ReturnOp and rewrite func type. 
- if return_values is None: - return_values = [] - elif isinstance(return_values, tuple): - return_values = list(return_values) - elif isinstance(return_values, Value): - # Returning a single value is fine, coerce it into a list. - return_values = [return_values] - elif isinstance(return_values, OpView): - # Returning a single operation is fine, coerce its results a list. - return_values = return_values.operation.results - elif isinstance(return_values, Operation): - # Returning a single operation is fine, coerce its results a list. - return_values = return_values.results - else: - return_values = list(return_values) - func.ReturnOp(return_values) - # Recompute the function type. - return_types = [v.type for v in return_values] - function_type = FunctionType.get(inputs=inputs, results=return_types) - func_op.attributes["function_type"] = TypeAttr.get(function_type) - - def emit_call_op(*call_args): - call_op = func.CallOp(return_types, FlatSymbolRefAttr.get(symbol_name), - call_args) - if return_types is None: - return None - elif len(return_types) == 1: - return call_op.result + """Specialization for the func op class.""" + + def __init__( + self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None + ): + """ + Create a FuncOp with the provided `name`, `type`, and `visibility`. + - `name` is a string representing the function name. + - `type` is either a FunctionType or a pair of list describing inputs and + results. + - `visibility` is a string matching `public`, `private`, or `nested`. None + implies private visibility. + - `body_builder` is an optional callback, when provided a new entry block + is created and the callback is invoked with the new op as argument within + an InsertionPoint context already set for the block. The callback is + expected to insert a terminator in the block. + """ + sym_name = StringAttr.get(str(name)) + + # If the type is passed as a tuple, build a FunctionType on the fly. 
+ if isinstance(type, tuple): + type = FunctionType.get(inputs=type[0], results=type[1]) + + type = TypeAttr.get(type) + sym_visibility = ( + StringAttr.get(str(visibility)) if visibility is not None else None + ) + super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) + if body_builder: + entry_block = self.add_entry_block() + with InsertionPoint(entry_block): + body_builder(self) + + @property + def is_external(self): + return len(self.regions[0].blocks) == 0 + + @property + def body(self): + return self.regions[0] + + @property + def type(self): + return FunctionType(TypeAttr(self.attributes["function_type"]).value) + + @property + def visibility(self): + return self.attributes["sym_visibility"] + + @property + def name(self) -> StringAttr: + return StringAttr(self.attributes["sym_name"]) + + @property + def entry_block(self): + if self.is_external: + raise IndexError("External function does not have a body") + return self.regions[0].blocks[0] + + def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None): + """ + Add an entry block to the function body using the function signature to + infer block arguments. 
+ Returns the newly created block + """ + if not self.is_external: + raise IndexError("The function already has an entry block!") + self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs) + return self.body.blocks[0] + + @property + def arg_attrs(self): + return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) + + @arg_attrs.setter + def arg_attrs(self, attribute: Union[ArrayAttr, list]): + if isinstance(attribute, ArrayAttr): + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute else: - return call_op.results - - wrapped = emit_call_op - wrapped.__name__ = f.__name__ - wrapped.func_op = func_op - return wrapped + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( + attribute, context=self.context + ) + + @property + def arguments(self): + return self.entry_block.arguments + + @property + def result_attrs(self): + return self.attributes[RESULT_ATTRIBUTE_NAME] + + @result_attrs.setter + def result_attrs(self, attribute: ArrayAttr): + self.attributes[RESULT_ATTRIBUTE_NAME] = attribute + + @classmethod + def from_py_func( + FuncOp, + *inputs: Type, + results: Optional[Sequence[Type]] = None, + name: Optional[str] = None, + ): + """Decorator to define an MLIR FuncOp specified as a python function. + + Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are + active for the current thread (i.e. established in a `with` block). + + When applied as a decorator to a Python function, an entry block will + be constructed for the FuncOp with types as specified in `*inputs`. The + block arguments will be passed positionally to the Python function. In + addition, if the Python function accepts keyword arguments generally or + has a corresponding keyword argument, the following will be passed: + * `func_op`: The `func` op being defined. + + By default, the function name will be the Python function `__name__`. This + can be overriden by passing the `name` argument to the decorator. 
+ + If `results` is not specified, then the decorator will implicitly + insert a `ReturnOp` with the `Value`'s returned from the decorated + function. It will also set the `FuncOp` type with the actual return + value types. If `results` is specified, then the decorated function + must return `None` and no implicit `ReturnOp` is added (nor are the result + types updated). The implicit behavior is intended for simple, single-block + cases, and users should specify result types explicitly for any complicated + cases. + + The decorated function can further be called from Python and will insert + a `CallOp` at the then-current insertion point, returning either None ( + if no return values), a unary Value (for one result), or a list of Values). + This mechanism cannot be used to emit recursive calls (by construction). + """ + + def decorator(f): + from . import func + + # Introspect the callable for optional features. + sig = inspect.signature(f) + has_arg_func_op = False + for param in sig.parameters.values(): + if param.kind == param.VAR_KEYWORD: + has_arg_func_op = True + if param.name == "func_op" and ( + param.kind == param.POSITIONAL_OR_KEYWORD + or param.kind == param.KEYWORD_ONLY + ): + has_arg_func_op = True + + # Emit the FuncOp. + implicit_return = results is None + symbol_name = name or f.__name__ + function_type = FunctionType.get( + inputs=inputs, results=[] if implicit_return else results + ) + func_op = FuncOp(name=symbol_name, type=function_type) + with InsertionPoint(func_op.add_entry_block()): + func_args = func_op.entry_block.arguments + func_kwargs = {} + if has_arg_func_op: + func_kwargs["func_op"] = func_op + return_values = f(*func_args, **func_kwargs) + if not implicit_return: + return_types = list(results) + assert return_values is None, ( + "Capturing a python function with explicit `results=` " + "requires that the wrapped function returns None." + ) + else: + # Coerce return values, add ReturnOp and rewrite func type. 
+ if return_values is None: + return_values = [] + elif isinstance(return_values, tuple): + return_values = list(return_values) + elif isinstance(return_values, Value): + # Returning a single value is fine, coerce it into a list. + return_values = [return_values] + elif isinstance(return_values, OpView): + # Returning a single operation is fine, coerce its results a list. + return_values = return_values.operation.results + elif isinstance(return_values, Operation): + # Returning a single operation is fine, coerce its results a list. + return_values = return_values.results + else: + return_values = list(return_values) + func.ReturnOp(return_values) + # Recompute the function type. + return_types = [v.type for v in return_values] + function_type = FunctionType.get( + inputs=inputs, results=return_types + ) + func_op.attributes["function_type"] = TypeAttr.get(function_type) + + def emit_call_op(*call_args): + call_op = func.CallOp( + return_types, FlatSymbolRefAttr.get(symbol_name), call_args + ) + if return_types is None: + return None + elif len(return_types) == 1: + return call_op.result + else: + return call_op.results + + wrapped = emit_call_op + wrapped.__name__ = f.__name__ + wrapped.func_op = func_op + return wrapped + + return decorator - return decorator class CallOp: - """Specialization for the call op class.""" - - def __init__(self, - calleeOrResults: Union[FuncOp, List[Type]], - argumentsOrCallee: Union[List, FlatSymbolRefAttr, str], - arguments: Optional[List] = None, - *, - loc=None, - ip=None): - """Creates an call operation. - - The constructor accepts three different forms: - - 1. A function op to be called followed by a list of arguments. - 2. A list of result types, followed by the name of the function to be - called as string, following by a list of arguments. - 3. A list of result types, followed by the name of the function to be - called as symbol reference attribute, followed by a list of arguments. 
- - For example - - f = func.FuncOp("foo", ...) - func.CallOp(f, [args]) - func.CallOp([result_types], "foo", [args]) - - In all cases, the location and insertion point may be specified as keyword - arguments if not provided by the surrounding context managers. - """ - - # TODO: consider supporting constructor "overloads", e.g., through a custom - # or pybind-provided metaclass. - if isinstance(calleeOrResults, FuncOp): - if not isinstance(argumentsOrCallee, list): - raise ValueError( - "when constructing a call to a function, expected " + - "the second argument to be a list of call arguments, " + - f"got {type(argumentsOrCallee)}") - if arguments is not None: - raise ValueError("unexpected third argument when constructing a call" + - "to a function") - - super().__init__( - calleeOrResults.type.results, - FlatSymbolRefAttr.get( - calleeOrResults.name.value, - context=_get_default_loc_context(loc)), - argumentsOrCallee, - loc=loc, - ip=ip) - return - - if isinstance(argumentsOrCallee, list): - raise ValueError("when constructing a call to a function by name, " + - "expected the second argument to be a string or a " + - f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}") - - if isinstance(argumentsOrCallee, FlatSymbolRefAttr): - super().__init__( - calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip) - elif isinstance(argumentsOrCallee, str): - super().__init__( - calleeOrResults, - FlatSymbolRefAttr.get( - argumentsOrCallee, context=_get_default_loc_context(loc)), - arguments, - loc=loc, - ip=ip) + """Specialization for the call op class.""" + + def __init__( + self, + calleeOrResults: Union[FuncOp, List[Type]], + argumentsOrCallee: Union[List, FlatSymbolRefAttr, str], + arguments: Optional[List] = None, + *, + loc=None, + ip=None, + ): + """Creates an call operation. + + The constructor accepts three different forms: + + 1. A function op to be called followed by a list of arguments. + 2. 
A list of result types, followed by the name of the function to be + called as string, following by a list of arguments. + 3. A list of result types, followed by the name of the function to be + called as symbol reference attribute, followed by a list of arguments. + + For example + + f = func.FuncOp("foo", ...) + func.CallOp(f, [args]) + func.CallOp([result_types], "foo", [args]) + + In all cases, the location and insertion point may be specified as keyword + arguments if not provided by the surrounding context managers. + """ + + # TODO: consider supporting constructor "overloads", e.g., through a custom + # or pybind-provided metaclass. + if isinstance(calleeOrResults, FuncOp): + if not isinstance(argumentsOrCallee, list): + raise ValueError( + "when constructing a call to a function, expected " + + "the second argument to be a list of call arguments, " + + f"got {type(argumentsOrCallee)}" + ) + if arguments is not None: + raise ValueError( + "unexpected third argument when constructing a call" + + "to a function" + ) + + super().__init__( + calleeOrResults.type.results, + FlatSymbolRefAttr.get( + calleeOrResults.name.value, context=_get_default_loc_context(loc) + ), + argumentsOrCallee, + loc=loc, + ip=ip, + ) + return + + if isinstance(argumentsOrCallee, list): + raise ValueError( + "when constructing a call to a function by name, " + + "expected the second argument to be a string or a " + + f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}" + ) + + if isinstance(argumentsOrCallee, FlatSymbolRefAttr): + super().__init__( + calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip + ) + elif isinstance(argumentsOrCallee, str): + super().__init__( + calleeOrResults, + FlatSymbolRefAttr.get( + argumentsOrCallee, context=_get_default_loc_context(loc) + ), + arguments, + loc=loc, + ip=ip, + ) diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py index eb9e969f3..3f6d854ca 100644 --- 
a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ b/mlir/python/mlir/dialects/_linalg_ops_ext.py @@ -3,39 +3,45 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from typing import Optional, Sequence, Union - from ..ir import * - from ._ods_common import get_default_loc_context - from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region + from typing import Optional, Sequence, Union + from ..ir import * + from ._ods_common import get_default_loc_context + from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e from ._ods_common import get_op_result_or_value as _get_op_result_or_value + def isa(cls: Type, ty: Type): - try: - cls(ty) - return True - except ValueError: - return False + try: + cls(ty) + return True + except ValueError: + return False class StructuredOpMixin: - """All structured ops use the same mixin class.""" + """All structured ops use the same mixin class.""" - def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None): - super().__init__( - self.build_generic(results=list(results), - operands=[list(inputs), list(outputs)], - loc=loc, - ip=ip)) + def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None): + super().__init__( + self.build_generic( + results=list(results), + operands=[list(inputs), list(outputs)], + loc=loc, + ip=ip, + ) + ) def select_opview_mixin(parent_opview_cls): - # TODO: This shouldn't be a heuristic: we should have a way to annotate - # the OpView to note that it is a structured op. 
- if ("__init__" not in parent_opview_cls.__dict__ and - hasattr(parent_opview_cls, "inputs") and - hasattr(parent_opview_cls, "outputs") and - hasattr(parent_opview_cls, "result_tensors")): - return StructuredOpMixin + # TODO: This shouldn't be a heuristic: we should have a way to annotate + # the OpView to note that it is a structured op. + if ( + "__init__" not in parent_opview_cls.__dict__ + and hasattr(parent_opview_cls, "inputs") + and hasattr(parent_opview_cls, "outputs") + and hasattr(parent_opview_cls, "result_tensors") + ): + return StructuredOpMixin diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py index 10079d32f..3536d45ab 100644 --- a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py @@ -3,125 +3,130 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from ..ir import * - from ._ods_common import get_op_result_or_value as _get_op_result_or_value + from ..ir import * + from ._ods_common import get_op_result_or_value as _get_op_result_or_value except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e from typing import Optional, Union class GetParentForOp: - """Extension for GetParentForOp.""" - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - num_loops: Optional[int] = None, - ip=None, - loc=None, - ): - if num_loops is None: - num_loops = 1 - super().__init__( - result_type, - _get_op_result_or_value(target), - num_loops=num_loops, - ip=ip, - loc=loc, - ) + """Extension for GetParentForOp.""" + + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + num_loops: Optional[int] = None, + ip=None, + loc=None, + ): + if num_loops is None: + num_loops = 1 + super().__init__( + result_type, + _get_op_result_or_value(target), + 
num_loops=num_loops, + ip=ip, + loc=loc, + ) class LoopOutlineOp: - """Extension for LoopOutlineOp.""" - - def __init__( - self, - function_type: Type, - call_type: Type, - target: Union[Operation, Value], - *, - func_name: Union[str, StringAttr], - ip=None, - loc=None, - ): - super().__init__( - function_type, - call_type, - _get_op_result_or_value(target), - func_name=(func_name if isinstance(func_name, StringAttr) else - StringAttr.get(func_name)), - ip=ip, - loc=loc, - ) + """Extension for LoopOutlineOp.""" + + def __init__( + self, + function_type: Type, + call_type: Type, + target: Union[Operation, Value], + *, + func_name: Union[str, StringAttr], + ip=None, + loc=None, + ): + super().__init__( + function_type, + call_type, + _get_op_result_or_value(target), + func_name=( + func_name + if isinstance(func_name, StringAttr) + else StringAttr.get(func_name) + ), + ip=ip, + loc=loc, + ) class LoopPeelOp: - """Extension for LoopPeelOp.""" - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - fail_if_already_divisible: Union[bool, BoolAttr] = False, - ip=None, - loc=None, - ): - super().__init__( - result_type, - _get_op_result_or_value(target), - fail_if_already_divisible=(fail_if_already_divisible if isinstance( - fail_if_already_divisible, BoolAttr) else - BoolAttr.get(fail_if_already_divisible)), - ip=ip, - loc=loc, - ) + """Extension for LoopPeelOp.""" + + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + fail_if_already_divisible: Union[bool, BoolAttr] = False, + ip=None, + loc=None, + ): + super().__init__( + result_type, + _get_op_result_or_value(target), + fail_if_already_divisible=( + fail_if_already_divisible + if isinstance(fail_if_already_divisible, BoolAttr) + else BoolAttr.get(fail_if_already_divisible) + ), + ip=ip, + loc=loc, + ) class LoopPipelineOp: - """Extension for LoopPipelineOp.""" - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - 
iteration_interval: Optional[Union[int, IntegerAttr]] = None, - read_latency: Optional[Union[int, IntegerAttr]] = None, - ip=None, - loc=None, - ): - if iteration_interval is None: - iteration_interval = 1 - if read_latency is None: - read_latency = 10 - super().__init__( - result_type, - _get_op_result_or_value(target), - iteration_interval=iteration_interval, - read_latency=read_latency, - ip=ip, - loc=loc, - ) + """Extension for LoopPipelineOp.""" + + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + iteration_interval: Optional[Union[int, IntegerAttr]] = None, + read_latency: Optional[Union[int, IntegerAttr]] = None, + ip=None, + loc=None, + ): + if iteration_interval is None: + iteration_interval = 1 + if read_latency is None: + read_latency = 10 + super().__init__( + result_type, + _get_op_result_or_value(target), + iteration_interval=iteration_interval, + read_latency=read_latency, + ip=ip, + loc=loc, + ) class LoopUnrollOp: - """Extension for LoopUnrollOp.""" - - def __init__( - self, - target: Union[Operation, Value], - *, - factor: Union[int, IntegerAttr], - ip=None, - loc=None, - ): - super().__init__( - _get_op_result_or_value(target), - factor=factor, - ip=ip, - loc=loc, - ) + """Extension for LoopUnrollOp.""" + + def __init__( + self, + target: Union[Operation, Value], + *, + factor: Union[int, IntegerAttr], + ip=None, + loc=None, + ): + super().__init__( + _get_op_result_or_value(target), + factor=factor, + ip=ip, + loc=loc, + ) diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py index a00a087be..825f1a0a7 100644 --- a/mlir/python/mlir/dialects/_memref_ops_ext.py +++ b/mlir/python/mlir/dialects/_memref_ops_ext.py @@ -3,34 +3,34 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from ..ir import * - from ._ods_common import get_op_result_or_value as _get_op_result_or_value - from ._ods_common import get_op_results_or_values as 
_get_op_results_or_values + from ..ir import * + from ._ods_common import get_op_result_or_value as _get_op_result_or_value + from ._ods_common import get_op_results_or_values as _get_op_results_or_values except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e from typing import Optional, Sequence, Union class LoadOp: - """Specialization for the MemRef load operation.""" + """Specialization for the MemRef load operation.""" - def __init__(self, - memref: Union[Operation, OpView, Value], - indices: Optional[Union[Operation, OpView, - Sequence[Value]]] = None, - *, - loc=None, - ip=None): - """Creates a memref load operation. + def __init__( + self, + memref: Union[Operation, OpView, Value], + indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None, + *, + loc=None, + ip=None + ): + """Creates a memref load operation. - Args: - memref: the buffer to load from. - indices: the list of subscripts, may be empty for zero-dimensional - buffers. - loc: user-visible location of the operation. - ip: insertion point. - """ - indices_resolved = [] if indices is None else _get_op_results_or_values( - indices) - super().__init__(memref, indices_resolved, loc=loc, ip=ip) + Args: + memref: the buffer to load from. + indices: the list of subscripts, may be empty for zero-dimensional + buffers. + loc: user-visible location of the operation. + ip: insertion point. 
+ """ + indices_resolved = [] if indices is None else _get_op_results_or_values(indices) + super().__init__(memref, indices_resolved, loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/_ml_program_ops_ext.py b/mlir/python/mlir/dialects/_ml_program_ops_ext.py index 8db82cf81..c84d23c16 100644 --- a/mlir/python/mlir/dialects/_ml_program_ops_ext.py +++ b/mlir/python/mlir/dialects/_ml_program_ops_ext.py @@ -3,11 +3,11 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from typing import Union - from ..ir import * - from ._ods_common import get_default_loc_context as _get_default_loc_context + from typing import Union + from ..ir import * + from ._ods_common import get_default_loc_context as _get_default_loc_context except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e from ._ml_program_ops_gen import * @@ -17,100 +17,97 @@ class FuncOp: - """Specialization for the func op class.""" - - def __init__(self, - name, - type, - *, - visibility=None, - body_builder=None, - loc=None, - ip=None): - """ - Create a FuncOp with the provided `name`, `type`, and `visibility`. - - `name` is a string representing the function name. - - `type` is either a FunctionType or a pair of list describing inputs and - results. - - `visibility` is a string matching `public`, `private`, or `nested`. None - implies private visibility. - - `body_builder` is an optional callback, when provided a new entry block - is created and the callback is invoked with the new op as argument within - an InsertionPoint context already set for the block. The callback is - expected to insert a terminator in the block. - """ - sym_name = StringAttr.get(str(name)) - - # If the type is passed as a tuple, build a FunctionType on the fly. 
- if isinstance(type, tuple): - type = FunctionType.get(inputs=type[0], results=type[1]) - - type = TypeAttr.get(type) - sym_visibility = StringAttr.get( - str(visibility)) if visibility is not None else None - super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) - if body_builder: - entry_block = self.add_entry_block() - with InsertionPoint(entry_block): - body_builder(self) - - @property - def is_external(self): - return len(self.regions[0].blocks) == 0 - - @property - def body(self): - return self.regions[0] - - @property - def type(self): - return FunctionType(TypeAttr(self.attributes["function_type"]).value) - - @property - def visibility(self): - return self.attributes["sym_visibility"] - - @property - def name(self) -> StringAttr: - return StringAttr(self.attributes["sym_name"]) - - @property - def entry_block(self): - if self.is_external: - raise IndexError('External function does not have a body') - return self.regions[0].blocks[0] - - def add_entry_block(self): - """ - Add an entry block to the function body using the function signature to - infer block arguments. 
- Returns the newly created block - """ - if not self.is_external: - raise IndexError('The function already has an entry block!') - self.body.blocks.append(*self.type.inputs) - return self.body.blocks[0] - - @property - def arg_attrs(self): - return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) - - @arg_attrs.setter - def arg_attrs(self, attribute: Union[ArrayAttr, list]): - if isinstance(attribute, ArrayAttr): - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute - else: - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( - attribute, context=self.context) - - @property - def arguments(self): - return self.entry_block.arguments - - @property - def result_attrs(self): - return self.attributes[RESULT_ATTRIBUTE_NAME] - - @result_attrs.setter - def result_attrs(self, attribute: ArrayAttr): - self.attributes[RESULT_ATTRIBUTE_NAME] = attribute + """Specialization for the func op class.""" + + def __init__( + self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None + ): + """ + Create a FuncOp with the provided `name`, `type`, and `visibility`. + - `name` is a string representing the function name. + - `type` is either a FunctionType or a pair of list describing inputs and + results. + - `visibility` is a string matching `public`, `private`, or `nested`. None + implies private visibility. + - `body_builder` is an optional callback, when provided a new entry block + is created and the callback is invoked with the new op as argument within + an InsertionPoint context already set for the block. The callback is + expected to insert a terminator in the block. + """ + sym_name = StringAttr.get(str(name)) + + # If the type is passed as a tuple, build a FunctionType on the fly. 
+ if isinstance(type, tuple): + type = FunctionType.get(inputs=type[0], results=type[1]) + + type = TypeAttr.get(type) + sym_visibility = ( + StringAttr.get(str(visibility)) if visibility is not None else None + ) + super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) + if body_builder: + entry_block = self.add_entry_block() + with InsertionPoint(entry_block): + body_builder(self) + + @property + def is_external(self): + return len(self.regions[0].blocks) == 0 + + @property + def body(self): + return self.regions[0] + + @property + def type(self): + return FunctionType(TypeAttr(self.attributes["function_type"]).value) + + @property + def visibility(self): + return self.attributes["sym_visibility"] + + @property + def name(self) -> StringAttr: + return StringAttr(self.attributes["sym_name"]) + + @property + def entry_block(self): + if self.is_external: + raise IndexError("External function does not have a body") + return self.regions[0].blocks[0] + + def add_entry_block(self): + """ + Add an entry block to the function body using the function signature to + infer block arguments. 
+ Returns the newly created block + """ + if not self.is_external: + raise IndexError("The function already has an entry block!") + self.body.blocks.append(*self.type.inputs) + return self.body.blocks[0] + + @property + def arg_attrs(self): + return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) + + @arg_attrs.setter + def arg_attrs(self, attribute: Union[ArrayAttr, list]): + if isinstance(attribute, ArrayAttr): + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute + else: + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( + attribute, context=self.context + ) + + @property + def arguments(self): + return self.entry_block.arguments + + @property + def result_attrs(self): + return self.attributes[RESULT_ATTRIBUTE_NAME] + + @result_attrs.setter + def result_attrs(self, attribute: ArrayAttr): + self.attributes[RESULT_ATTRIBUTE_NAME] = attribute diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 51b900819..7655629a5 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -18,144 +18,152 @@ def extend_opview_class(ext_module): - """Decorator to extend an OpView class from an extension module. - - Extension modules can expose various entry-points: - Stand-alone class with the same name as a parent OpView class (i.e. - "ReturnOp"). A name-based match is attempted first before falling back - to a below mechanism. - - def select_opview_mixin(parent_opview_cls): - If defined, allows an appropriate mixin class to be selected dynamically - based on the parent OpView class. Should return NotImplemented if a - decision is not made. - - Args: - ext_module: A module from which to locate extensions. Can be None if not - available. - - Returns: - A decorator that takes an OpView subclass and further extends it as - needed. 
- """ - - def class_decorator(parent_opview_cls: type): - if ext_module is None: - return parent_opview_cls - mixin_cls = NotImplemented - # First try to resolve by name. - try: - mixin_cls = getattr(ext_module, parent_opview_cls.__name__) - except AttributeError: - # Fall back to a select_opview_mixin hook. - try: - select_mixin = getattr(ext_module, "select_opview_mixin") - except AttributeError: - pass - else: - mixin_cls = select_mixin(parent_opview_cls) - - if mixin_cls is NotImplemented or mixin_cls is None: - return parent_opview_cls - - # Have a mixin_cls. Create an appropriate subclass. - try: - - class LocalOpView(mixin_cls, parent_opview_cls): - pass - except TypeError as e: - raise TypeError( - f"Could not mixin {mixin_cls} into {parent_opview_cls}") from e - LocalOpView.__name__ = parent_opview_cls.__name__ - LocalOpView.__qualname__ = parent_opview_cls.__qualname__ - return LocalOpView - - return class_decorator + """Decorator to extend an OpView class from an extension module. + + Extension modules can expose various entry-points: + Stand-alone class with the same name as a parent OpView class (i.e. + "ReturnOp"). A name-based match is attempted first before falling back + to a below mechanism. + + def select_opview_mixin(parent_opview_cls): + If defined, allows an appropriate mixin class to be selected dynamically + based on the parent OpView class. Should return NotImplemented if a + decision is not made. + + Args: + ext_module: A module from which to locate extensions. Can be None if not + available. + + Returns: + A decorator that takes an OpView subclass and further extends it as + needed. + """ + + def class_decorator(parent_opview_cls: type): + if ext_module is None: + return parent_opview_cls + mixin_cls = NotImplemented + # First try to resolve by name. + try: + mixin_cls = getattr(ext_module, parent_opview_cls.__name__) + except AttributeError: + # Fall back to a select_opview_mixin hook. 
+ try: + select_mixin = getattr(ext_module, "select_opview_mixin") + except AttributeError: + pass + else: + mixin_cls = select_mixin(parent_opview_cls) + + if mixin_cls is NotImplemented or mixin_cls is None: + return parent_opview_cls + + # Have a mixin_cls. Create an appropriate subclass. + try: + + class LocalOpView(mixin_cls, parent_opview_cls): + pass + + except TypeError as e: + raise TypeError( + f"Could not mixin {mixin_cls} into {parent_opview_cls}" + ) from e + LocalOpView.__name__ = parent_opview_cls.__name__ + LocalOpView.__qualname__ = parent_opview_cls.__qualname__ + return LocalOpView + + return class_decorator def segmented_accessor(elements, raw_segments, idx): - """ - Returns a slice of elements corresponding to the idx-th segment. - - elements: a sliceable container (operands or results). - raw_segments: an mlir.ir.Attribute, of DenseI32Array subclass containing - sizes of the segments. - idx: index of the segment. - """ - segments = _cext.ir.DenseI32ArrayAttr(raw_segments) - start = sum(segments[i] for i in range(idx)) - end = start + segments[idx] - return elements[start:end] - - -def equally_sized_accessor(elements, n_variadic, n_preceding_simple, - n_preceding_variadic): - """ - Returns a starting position and a number of elements per variadic group - assuming equally-sized groups and the given numbers of preceding groups. - - elements: a sequential container. - n_variadic: the number of variadic groups in the container. - n_preceding_simple: the number of non-variadic groups preceding the current - group. - n_preceding_variadic: the number of variadic groups preceding the current - group. - """ - - total_variadic_length = len(elements) - n_variadic + 1 - # This should be enforced by the C++-side trait verifier. 
- assert total_variadic_length % n_variadic == 0 - - elements_per_group = total_variadic_length // n_variadic - start = n_preceding_simple + n_preceding_variadic * elements_per_group - return start, elements_per_group + """ + Returns a slice of elements corresponding to the idx-th segment. + + elements: a sliceable container (operands or results). + raw_segments: an mlir.ir.Attribute, of DenseI32Array subclass containing + sizes of the segments. + idx: index of the segment. + """ + segments = _cext.ir.DenseI32ArrayAttr(raw_segments) + start = sum(segments[i] for i in range(idx)) + end = start + segments[idx] + return elements[start:end] + + +def equally_sized_accessor( + elements, n_variadic, n_preceding_simple, n_preceding_variadic +): + """ + Returns a starting position and a number of elements per variadic group + assuming equally-sized groups and the given numbers of preceding groups. + + elements: a sequential container. + n_variadic: the number of variadic groups in the container. + n_preceding_simple: the number of non-variadic groups preceding the current + group. + n_preceding_variadic: the number of variadic groups preceding the current + group. + """ + + total_variadic_length = len(elements) - n_variadic + 1 + # This should be enforced by the C++-side trait verifier. + assert total_variadic_length % n_variadic == 0 + + elements_per_group = total_variadic_length // n_variadic + start = n_preceding_simple + n_preceding_variadic * elements_per_group + return start, elements_per_group def get_default_loc_context(location=None): - """ - Returns a context in which the defaulted location is created. If the location - is None, takes the current location from the stack, raises ValueError if there - is no location on the stack. - """ - if location is None: - # Location.current raises ValueError if there is no current location. - return _cext.ir.Location.current.context - return location.context + """ + Returns a context in which the defaulted location is created. 
If the location + is None, takes the current location from the stack, raises ValueError if there + is no location on the stack. + """ + if location is None: + # Location.current raises ValueError if there is no current location. + return _cext.ir.Location.current.context + return location.context def get_op_result_or_value( - arg: _Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList] + arg: _Union[ + _cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value, _cext.ir.OpResultList + ] ) -> _cext.ir.Value: - """Returns the given value or the single result of the given op. - - This is useful to implement op constructors so that they can take other ops as - arguments instead of requiring the caller to extract results for every op. - Raises ValueError if provided with an op that doesn't have a single result. - """ - if isinstance(arg, _cext.ir.OpView): - return arg.operation.result - elif isinstance(arg, _cext.ir.Operation): - return arg.result - elif isinstance(arg, _cext.ir.OpResultList): - return arg[0] - else: - assert isinstance(arg, _cext.ir.Value) - return arg + """Returns the given value or the single result of the given op. + + This is useful to implement op constructors so that they can take other ops as + arguments instead of requiring the caller to extract results for every op. + Raises ValueError if provided with an op that doesn't have a single result. 
+ """ + if isinstance(arg, _cext.ir.OpView): + return arg.operation.result + elif isinstance(arg, _cext.ir.Operation): + return arg.result + elif isinstance(arg, _cext.ir.OpResultList): + return arg[0] + else: + assert isinstance(arg, _cext.ir.Value) + return arg def get_op_results_or_values( - arg: _Union[_cext.ir.OpView, _cext.ir.Operation, - _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]]] + arg: _Union[ + _cext.ir.OpView, + _cext.ir.Operation, + _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]], + ] ) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]: - """Returns the given sequence of values or the results of the given op. - - This is useful to implement op constructors so that they can take other ops as - lists of arguments instead of requiring the caller to extract results for - every op. - """ - if isinstance(arg, _cext.ir.OpView): - return arg.operation.results - elif isinstance(arg, _cext.ir.Operation): - return arg.results - else: - return [get_op_result_or_value(element) for element in arg] + """Returns the given sequence of values or the results of the given op. + + This is useful to implement op constructors so that they can take other ops as + lists of arguments instead of requiring the caller to extract results for + every op. 
+ """ + if isinstance(arg, _cext.ir.OpView): + return arg.operation.results + elif isinstance(arg, _cext.ir.Operation): + return arg.results + else: + return [get_op_result_or_value(element) for element in arg] diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py index 40ccbef63..fc9de0b7f 100644 --- a/mlir/python/mlir/dialects/_pdl_ops_ext.py +++ b/mlir/python/mlir/dialects/_pdl_ops_ext.py @@ -3,10 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from ..ir import * - from ..dialects import pdl + from ..ir import * + from ..dialects import pdl except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e from typing import Union, Optional, Sequence, Mapping from ._ods_common import ( @@ -16,264 +16,256 @@ class ApplyNativeConstraintOp: - """Specialization for PDL apply native constraint op class.""" - - def __init__( - self, - name: Union[str, StringAttr], - args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - *, - loc=None, - ip=None, - ): - if args is None: - args = [] - args = _get_values(args) - super().__init__(name, args, loc=loc, ip=ip) + """Specialization for PDL apply native constraint op class.""" + + def __init__( + self, + name: Union[str, StringAttr], + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if args is None: + args = [] + args = _get_values(args) + super().__init__(name, args, loc=loc, ip=ip) class ApplyNativeRewriteOp: - """Specialization for PDL apply native rewrite op class.""" - - def __init__( - self, - results: Sequence[Type], - name: Union[str, StringAttr], - args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - *, - loc=None, - ip=None, - ): - if args is None: - args = [] - args = _get_values(args) - super().__init__(results, name, args, loc=loc, ip=ip) + """Specialization 
for PDL apply native rewrite op class.""" + + def __init__( + self, + results: Sequence[Type], + name: Union[str, StringAttr], + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if args is None: + args = [] + args = _get_values(args) + super().__init__(results, name, args, loc=loc, ip=ip) class AttributeOp: - """Specialization for PDL attribute op class.""" + """Specialization for PDL attribute op class.""" - def __init__( - self, - valueType: Optional[Union[OpView, Operation, Value]] = None, - value: Optional[Attribute] = None, - *, - loc=None, - ip=None, - ): - valueType = valueType if valueType is None else _get_value(valueType) - result = pdl.AttributeType.get() - super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip) + def __init__( + self, + valueType: Optional[Union[OpView, Operation, Value]] = None, + value: Optional[Attribute] = None, + *, + loc=None, + ip=None, + ): + valueType = valueType if valueType is None else _get_value(valueType) + result = pdl.AttributeType.get() + super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip) class EraseOp: - """Specialization for PDL erase op class.""" + """Specialization for PDL erase op class.""" - def __init__( - self, - operation: Optional[Union[OpView, Operation, Value]] = None, - *, - loc=None, - ip=None, - ): - operation = _get_value(operation) - super().__init__(operation, loc=loc, ip=ip) + def __init__( + self, + operation: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None, + ): + operation = _get_value(operation) + super().__init__(operation, loc=loc, ip=ip) class OperandOp: - """Specialization for PDL operand op class.""" + """Specialization for PDL operand op class.""" - def __init__( - self, - type: Optional[Union[OpView, Operation, Value]] = None, - *, - loc=None, - ip=None, - ): - type = type if type is None else _get_value(type) - result = pdl.ValueType.get() - super().__init__(result, 
valueType=type, loc=loc, ip=ip) + def __init__( + self, + type: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None, + ): + type = type if type is None else _get_value(type) + result = pdl.ValueType.get() + super().__init__(result, valueType=type, loc=loc, ip=ip) class OperandsOp: - """Specialization for PDL operands op class.""" + """Specialization for PDL operands op class.""" - def __init__( - self, - types: Optional[Union[OpView, Operation, Value]] = None, - *, - loc=None, - ip=None, - ): - types = types if types is None else _get_value(types) - result = pdl.RangeType.get(pdl.ValueType.get()) - super().__init__(result, valueType=types, loc=loc, ip=ip) + def __init__( + self, + types: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None, + ): + types = types if types is None else _get_value(types) + result = pdl.RangeType.get(pdl.ValueType.get()) + super().__init__(result, valueType=types, loc=loc, ip=ip) class OperationOp: - """Specialization for PDL operand op class.""" - - def __init__( - self, - name: Optional[Union[str, StringAttr]] = None, - args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - attributes: Optional[Mapping[str, Union[OpView, Operation, - Value]]] = None, - types: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - *, - loc=None, - ip=None, - ): - if types is None: - types = [] - if attributes is None: - attributes = {} - if args is None: - args = [] - args = _get_values(args) - attrNames = [] - attrValues = [] - for attrName, attrValue in attributes.items(): - attrNames.append(StringAttr.get(attrName)) - attrValues.append(_get_value(attrValue)) - attrNames = ArrayAttr.get(attrNames) - types = _get_values(types) - result = pdl.OperationType.get() - super().__init__(result, - args, - attrValues, - attrNames, - types, - opName=name, - loc=loc, - ip=ip) + """Specialization for PDL operand op class.""" + + def __init__( + self, + name: Optional[Union[str, StringAttr]] = 
None, + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None, + types: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if types is None: + types = [] + if attributes is None: + attributes = {} + if args is None: + args = [] + args = _get_values(args) + attrNames = [] + attrValues = [] + for attrName, attrValue in attributes.items(): + attrNames.append(StringAttr.get(attrName)) + attrValues.append(_get_value(attrValue)) + attrNames = ArrayAttr.get(attrNames) + types = _get_values(types) + result = pdl.OperationType.get() + super().__init__( + result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip + ) class PatternOp: - """Specialization for PDL pattern op class.""" - - def __init__( - self, - benefit: Union[IntegerAttr, int], - name: Optional[Union[StringAttr, str]] = None, - *, - loc=None, - ip=None, - ): - """Creates an PDL `pattern` operation.""" - super().__init__(benefit, sym_name=name, loc=loc, ip=ip) - self.regions[0].blocks.append() - - @property - def body(self): - """Return the body (block) of the pattern.""" - return self.regions[0].blocks[0] + """Specialization for PDL pattern op class.""" + + def __init__( + self, + benefit: Union[IntegerAttr, int], + name: Optional[Union[StringAttr, str]] = None, + *, + loc=None, + ip=None, + ): + """Creates an PDL `pattern` operation.""" + super().__init__(benefit, sym_name=name, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def body(self): + """Return the body (block) of the pattern.""" + return self.regions[0].blocks[0] class ReplaceOp: - """Specialization for PDL replace op class.""" - - def __init__( - self, - op: Union[OpView, Operation, Value], - *, - with_op: Optional[Union[OpView, Operation, Value]] = None, - with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - loc=None, - ip=None, - ): - if with_values is None: - 
with_values = [] - op = _get_value(op) - with_op = with_op if with_op is None else _get_value(with_op) - with_values = _get_values(with_values) - super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip) + """Specialization for PDL replace op class.""" + + def __init__( + self, + op: Union[OpView, Operation, Value], + *, + with_op: Optional[Union[OpView, Operation, Value]] = None, + with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + loc=None, + ip=None, + ): + if with_values is None: + with_values = [] + op = _get_value(op) + with_op = with_op if with_op is None else _get_value(with_op) + with_values = _get_values(with_values) + super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip) class ResultOp: - """Specialization for PDL result op class.""" + """Specialization for PDL result op class.""" - def __init__( - self, - parent: Union[OpView, Operation, Value], - index: Union[IntegerAttr, int], - *, - loc=None, - ip=None, - ): - parent = _get_value(parent) - result = pdl.ValueType.get() - super().__init__(result, parent, index, loc=loc, ip=ip) + def __init__( + self, + parent: Union[OpView, Operation, Value], + index: Union[IntegerAttr, int], + *, + loc=None, + ip=None, + ): + parent = _get_value(parent) + result = pdl.ValueType.get() + super().__init__(result, parent, index, loc=loc, ip=ip) class ResultsOp: - """Specialization for PDL results op class.""" + """Specialization for PDL results op class.""" - def __init__( - self, - result: Type, - parent: Union[OpView, Operation, Value], - index: Optional[Union[IntegerAttr, int]] = None, - *, - loc=None, - ip=None, - ): - parent = _get_value(parent) - super().__init__(result, parent, index=index, loc=loc, ip=ip) + def __init__( + self, + result: Type, + parent: Union[OpView, Operation, Value], + index: Optional[Union[IntegerAttr, int]] = None, + *, + loc=None, + ip=None, + ): + parent = _get_value(parent) + super().__init__(result, parent, index=index, loc=loc, 
ip=ip) class RewriteOp: - """Specialization for PDL rewrite op class.""" - - def __init__( - self, - root: Optional[Union[OpView, Operation, Value]] = None, - name: Optional[Union[StringAttr, str]] = None, - args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - *, - loc=None, - ip=None, - ): - if args is None: - args = [] - root = root if root is None else _get_value(root) - args = _get_values(args) - super().__init__(args, root=root, name=name, loc=loc, ip=ip) - - def add_body(self): - """Add body (block) to the rewrite.""" - self.regions[0].blocks.append() - return self.body - - @property - def body(self): - """Return the body (block) of the rewrite.""" - return self.regions[0].blocks[0] + """Specialization for PDL rewrite op class.""" + + def __init__( + self, + root: Optional[Union[OpView, Operation, Value]] = None, + name: Optional[Union[StringAttr, str]] = None, + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if args is None: + args = [] + root = root if root is None else _get_value(root) + args = _get_values(args) + super().__init__(args, root=root, name=name, loc=loc, ip=ip) + + def add_body(self): + """Add body (block) to the rewrite.""" + self.regions[0].blocks.append() + return self.body + + @property + def body(self): + """Return the body (block) of the rewrite.""" + return self.regions[0].blocks[0] class TypeOp: - """Specialization for PDL type op class.""" + """Specialization for PDL type op class.""" - def __init__(self, - constantType: Optional[Union[TypeAttr, Type]] = None, - *, - loc=None, - ip=None): - result = pdl.TypeType.get() - super().__init__(result, constantType=constantType, loc=loc, ip=ip) + def __init__( + self, constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None + ): + result = pdl.TypeType.get() + super().__init__(result, constantType=constantType, loc=loc, ip=ip) class TypesOp: - """Specialization for PDL types op class.""" - - def __init__( - 
self, - constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None, - *, - loc=None, - ip=None, - ): - if constantTypes is None: - constantTypes = [] - result = pdl.RangeType.get(pdl.TypeType.get()) - super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip) + """Specialization for PDL types op class.""" + + def __init__( + self, + constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None, + *, + loc=None, + ip=None, + ): + if constantTypes is None: + constantTypes = [] + result = pdl.RangeType.get(pdl.TypeType.get()) + super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py index 3c3e67302..4b2519ef3 100644 --- a/mlir/python/mlir/dialects/_scf_ops_ext.py +++ b/mlir/python/mlir/dialects/_scf_ops_ext.py @@ -3,105 +3,104 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from ..ir import * + from ..ir import * except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e from typing import Any, Optional, Sequence, Union -from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values +from ._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, +) + class ForOp: - """Specialization for the SCF for op class.""" - - def __init__(self, - lower_bound, - upper_bound, - step, - iter_args: Optional[Union[Operation, OpView, - Sequence[Value]]] = None, - *, - loc=None, - ip=None): - """Creates an SCF `for` operation. - - - `lower_bound` is the value to use as lower bound of the loop. - - `upper_bound` is the value to use as upper bound of the loop. - - `step` is the value to use as loop step. 
- - `iter_args` is a list of additional loop-carried arguments or an operation - producing them as results. - """ - if iter_args is None: - iter_args = [] - iter_args = _get_op_results_or_values(iter_args) - - results = [arg.type for arg in iter_args] - super().__init__( - self.build_generic( - regions=1, - results=results, - operands=[ - _get_op_result_or_value(o) - for o in [lower_bound, upper_bound, step] - ] + list(iter_args), - loc=loc, - ip=ip)) - self.regions[0].blocks.append(IndexType.get(), *results) - - @property - def body(self): - """Returns the body (block) of the loop.""" - return self.regions[0].blocks[0] - - @property - def induction_variable(self): - """Returns the induction variable of the loop.""" - return self.body.arguments[0] - - @property - def inner_iter_args(self): - """Returns the loop-carried arguments usable within the loop. - - To obtain the loop-carried operands, use `iter_args`. - """ - return self.body.arguments[1:] + """Specialization for the SCF for op class.""" + + def __init__( + self, + lower_bound, + upper_bound, + step, + iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None, + *, + loc=None, + ip=None + ): + """Creates an SCF `for` operation. + + - `lower_bound` is the value to use as lower bound of the loop. + - `upper_bound` is the value to use as upper bound of the loop. + - `step` is the value to use as loop step. + - `iter_args` is a list of additional loop-carried arguments or an operation + producing them as results. 
+ """ + if iter_args is None: + iter_args = [] + iter_args = _get_op_results_or_values(iter_args) + + results = [arg.type for arg in iter_args] + super().__init__( + self.build_generic( + regions=1, + results=results, + operands=[ + _get_op_result_or_value(o) for o in [lower_bound, upper_bound, step] + ] + + list(iter_args), + loc=loc, + ip=ip, + ) + ) + self.regions[0].blocks.append(IndexType.get(), *results) + + @property + def body(self): + """Returns the body (block) of the loop.""" + return self.regions[0].blocks[0] + + @property + def induction_variable(self): + """Returns the induction variable of the loop.""" + return self.body.arguments[0] + + @property + def inner_iter_args(self): + """Returns the loop-carried arguments usable within the loop. + + To obtain the loop-carried operands, use `iter_args`. + """ + return self.body.arguments[1:] class IfOp: - """Specialization for the SCF if op class.""" - - def __init__(self, - cond, - results_=[], - *, - hasElse=False, - loc=None, - ip=None): - """Creates an SCF `if` operation. - - - `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed. - - `hasElse` determines whether the if operation has the else branch. - """ - operands = [] - operands.append(cond) - results = [] - results.extend(results_) - super().__init__( - self.build_generic( - regions=2, - results=results, - operands=operands, - loc=loc, - ip=ip)) - self.regions[0].blocks.append(*[]) - if hasElse: - self.regions[1].blocks.append(*[]) - - @property - def then_block(self): - """Returns the then block of the if operation.""" - return self.regions[0].blocks[0] - - @property - def else_block(self): - """Returns the else block of the if operation.""" - return self.regions[1].blocks[0] + """Specialization for the SCF if op class.""" + + def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None): + """Creates an SCF `if` operation. 
+ + - `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed. + - `hasElse` determines whether the if operation has the else branch. + """ + operands = [] + operands.append(cond) + results = [] + results.extend(results_) + super().__init__( + self.build_generic( + regions=2, results=results, operands=operands, loc=loc, ip=ip + ) + ) + self.regions[0].blocks.append(*[]) + if hasElse: + self.regions[1].blocks.append(*[]) + + @property + def then_block(self): + """Returns the then block of the if operation.""" + return self.regions[0].blocks[0] + + @property + def else_block(self): + """Returns the else block of the if operation.""" + return self.regions[1].blocks[0] diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 9c051cd3d..30dafff6a 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -3,11 +3,11 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from ..ir import * - from ._ods_common import get_op_result_or_value as _get_op_result_or_value - from ..dialects import pdl, transform + from ..ir import * + from ._ods_common import get_op_result_or_value as _get_op_result_or_value + from ..dialects import pdl, transform except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e from typing import List, Optional, Sequence, Union, overload @@ -16,312 +16,315 @@ def _get_int_int_array_attr( - values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, - IntOrAttrList]]]] + values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]] ) -> ArrayAttr: - """Creates an array attribute containing array attributes of integers. + """Creates an array attribute containing array attributes of integers. 
If the operand is already an array attribute, forwards it. Otherwise treats the operand as a list of attributes or integers, potentially interpserced, to create a new array-of-array attribute. Expects the thread-local MLIR context to have been set by the context manager. """ - if values is None: - return ArrayAttr.get([]) - if isinstance(values, ArrayAttr): - return values - if isinstance(values, list): - values = [ - ArrayAttr.get( - [IntegerAttr.get(IntegerType.get_signless(64), v) - for v in value]) - for value in values - ] + if values is None: + return ArrayAttr.get([]) + if isinstance(values, ArrayAttr): + return values + if isinstance(values, list): + values = [ + ArrayAttr.get( + [IntegerAttr.get(IntegerType.get_signless(64), v) for v in value] + ) + for value in values + ] - return ArrayAttr.get(values) + return ArrayAttr.get(values) class DecomposeOp: - """Specialization for DecomposeOp class.""" + """Specialization for DecomposeOp class.""" - def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - super().__init__(pdl.OperationType.get(), - _get_op_result_or_value(target), - loc=loc, - ip=ip) + def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): + super().__init__( + pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip + ) class GeneralizeOp: - """Specialization for GeneralizeOp class.""" + """Specialization for GeneralizeOp class.""" - def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - super().__init__(pdl.OperationType.get(), - _get_op_result_or_value(target), - loc=loc, - ip=ip) + def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): + super().__init__( + pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip + ) class InterchangeOp: - """Specialization for InterchangeOp class.""" - - def __init__( - self, - target: Union[Operation, Value], - *, - iterator_interchange: OptionalIntList = None, - loc=None, - ip=None, - ): - 
pdl_operation_type = pdl.OperationType.get() - super().__init__( - pdl_operation_type, - _get_op_result_or_value(target), - iterator_interchange=iterator_interchange, - loc=loc, - ip=ip, - ) + """Specialization for InterchangeOp class.""" + + def __init__( + self, + target: Union[Operation, Value], + *, + iterator_interchange: OptionalIntList = None, + loc=None, + ip=None, + ): + pdl_operation_type = pdl.OperationType.get() + super().__init__( + pdl_operation_type, + _get_op_result_or_value(target), + iterator_interchange=iterator_interchange, + loc=loc, + ip=ip, + ) class MatchOp: - """Specialization for MatchOp class.""" - - @classmethod - def match_op_names( - MatchOp, - target: Union[Operation, Value], - names: Sequence[str], - loc=None, - ip=None, - ): - pdl_operation_type = pdl.OperationType.get() - return MatchOp( - pdl_operation_type, - _get_op_result_or_value(target), - ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))), - loc=loc, - ip=ip, - ) + """Specialization for MatchOp class.""" + + @classmethod + def match_op_names( + MatchOp, + target: Union[Operation, Value], + names: Sequence[str], + loc=None, + ip=None, + ): + pdl_operation_type = pdl.OperationType.get() + return MatchOp( + pdl_operation_type, + _get_op_result_or_value(target), + ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))), + loc=loc, + ip=ip, + ) class MultiTileSizesOp: - """Specialization for MultitileSizesOp class.""" - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - dimension: Union[int, IntegerAttr], - target_size: Union[int, IntegerAttr], - divisor: Optional[Optional[Union[int, IntegerAttr]]] = None, - loc=None, - ip=None, - ): - if divisor is None: - divisor = 1 - super().__init__( - result_type, - result_type, - result_type, - _get_op_result_or_value(target), - dimension=dimension, - target_size=target_size, - divisor=divisor, - loc=loc, - ip=ip, - ) + """Specialization for MultitileSizesOp class.""" + + def 
__init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + dimension: Union[int, IntegerAttr], + target_size: Union[int, IntegerAttr], + divisor: Optional[Optional[Union[int, IntegerAttr]]] = None, + loc=None, + ip=None, + ): + if divisor is None: + divisor = 1 + super().__init__( + result_type, + result_type, + result_type, + _get_op_result_or_value(target), + dimension=dimension, + target_size=target_size, + divisor=divisor, + loc=loc, + ip=ip, + ) class PadOp: - """Specialization for PadOp class.""" - - def __init__( - self, - target: Union[Operation, Value], - *, - padding_values: Optional[Optional[Union[ArrayAttr, - Sequence[Attribute]]]] = None, - padding_dimensions: OptionalIntList = None, - pack_paddings: OptionalIntList = None, - transpose_paddings: Optional[Union[ArrayAttr, Sequence[Union[ - ArrayAttr, IntOrAttrList]]]] = None, - loc=None, - ip=None, - ): - if transpose_paddings is None: - transpose_paddings = [] - if pack_paddings is None: - pack_paddings = [] - if padding_dimensions is None: - padding_dimensions = [] - if padding_values is None: - padding_values = [] - pdl_operation_type = pdl.OperationType.get() - transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings) - super().__init__( - pdl_operation_type, - _get_op_result_or_value(target), - padding_values=padding_values, - padding_dimensions=padding_dimensions, - pack_paddings=pack_paddings, - transpose_paddings=transpose_paddings_attr, - loc=loc, - ip=ip, - ) + """Specialization for PadOp class.""" + + def __init__( + self, + target: Union[Operation, Value], + *, + padding_values: Optional[ + Optional[Union[ArrayAttr, Sequence[Attribute]]] + ] = None, + padding_dimensions: OptionalIntList = None, + pack_paddings: OptionalIntList = None, + transpose_paddings: Optional[ + Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]] + ] = None, + loc=None, + ip=None, + ): + if transpose_paddings is None: + transpose_paddings = [] + if pack_paddings is None: + 
pack_paddings = [] + if padding_dimensions is None: + padding_dimensions = [] + if padding_values is None: + padding_values = [] + pdl_operation_type = pdl.OperationType.get() + transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings) + super().__init__( + pdl_operation_type, + _get_op_result_or_value(target), + padding_values=padding_values, + padding_dimensions=padding_dimensions, + pack_paddings=pack_paddings, + transpose_paddings=transpose_paddings_attr, + loc=loc, + ip=ip, + ) class ScalarizeOp: - """Specialization for ScalarizeOp class.""" + """Specialization for ScalarizeOp class.""" - def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - pdl_operation_type = pdl.OperationType.get() - super().__init__(pdl_operation_type, - _get_op_result_or_value(target), - loc=loc, - ip=ip) + def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): + pdl_operation_type = pdl.OperationType.get() + super().__init__( + pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip + ) class SplitOp: - """Specialization for SplitOp class.""" - - def __init__( - self, - target: Union[Operation, Value], - dimension: Union[int, Attribute], - split_point: Union[int, Operation, Value, Attribute], - *, - loc=None, - ip=None, - ): - if isinstance(split_point, int): - static_split_point = split_point - dynamic_split_point = None - else: - static_split_point = ShapedType.get_dynamic_size() - dynamic_split_point = _get_op_result_or_value(split_point) - - target = _get_op_result_or_value(target) - - super().__init__( - target.type, - target.type, - target, - dimension=dimension, - static_split_point=static_split_point, - dynamic_split_point=dynamic_split_point, - loc=loc, - ip=ip, - ) + """Specialization for SplitOp class.""" + + def __init__( + self, + target: Union[Operation, Value], + dimension: Union[int, Attribute], + split_point: Union[int, Operation, Value, Attribute], + *, + loc=None, + ip=None, + ): + if 
isinstance(split_point, int): + static_split_point = split_point + dynamic_split_point = None + else: + static_split_point = ShapedType.get_dynamic_size() + dynamic_split_point = _get_op_result_or_value(split_point) + + target = _get_op_result_or_value(target) + + super().__init__( + target.type, + target.type, + target, + dimension=dimension, + static_split_point=static_split_point, + dynamic_split_point=dynamic_split_point, + loc=loc, + ip=ip, + ) class TileOp: - """Specialization for TileOp class.""" - - @overload - def __init__( - self, - loop_types: Union[Type, List[Type]], - target: Union[Operation, Value], - *, - sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]], - ArrayAttr]] = None, - interchange: OptionalIntList = None, - loc=None, - ip=None, - ): - ... - - @overload - def __init__( - self, - target: Union[Operation, Value, OpView], - *, - sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]], - ArrayAttr]] = None, - interchange: OptionalIntList = None, - loc=None, - ip=None, - ): - ... - - def __init__( - self, - loop_types_or_target: Union[Type, List[Type], Operation, Value], - target_or_none: Optional[Union[Operation, Value, OpView]] = None, - *, - sizes: Optional[Union[Sequence[Union[int, IntegerAttr, Operation, Value]], - ArrayAttr]] = None, - interchange: OptionalIntList = None, - loc=None, - ip=None, - ): - if interchange is None: - interchange = [] - if sizes is None: - sizes = [] - - static_sizes = [] - dynamic_sizes = [] - if isinstance(sizes, ArrayAttr): - sizes_attr = sizes - else: - for size in sizes: - if isinstance(size, int): - static_sizes.append(size) + """Specialization for TileOp class.""" + + @overload + def __init__( + self, + loop_types: Union[Type, List[Type]], + target: Union[Operation, Value], + *, + sizes: Optional[ + Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr] + ] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None, + ): + ... 
+ + @overload + def __init__( + self, + target: Union[Operation, Value, OpView], + *, + sizes: Optional[ + Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr] + ] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None, + ): + ... + + def __init__( + self, + loop_types_or_target: Union[Type, List[Type], Operation, Value], + target_or_none: Optional[Union[Operation, Value, OpView]] = None, + *, + sizes: Optional[ + Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr] + ] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None, + ): + if interchange is None: + interchange = [] + if sizes is None: + sizes = [] + + static_sizes = [] + dynamic_sizes = [] + if isinstance(sizes, ArrayAttr): + sizes_attr = sizes + else: + for size in sizes: + if isinstance(size, int): + static_sizes.append(size) + else: + static_sizes.append(ShapedType.get_dynamic_size()) + dynamic_sizes.append(_get_op_result_or_value(size)) + sizes_attr = DenseI64ArrayAttr.get(static_sizes) + + num_loops = sum(v if v == 0 else 1 for v in self.__extract_values(sizes_attr)) + + if isinstance(loop_types_or_target, (Operation, Value, OpView)): + loop_types = [transform.AnyOpType.get()] * num_loops + target = loop_types_or_target + assert target_or_none is None, "Cannot construct TileOp with two targets." else: - static_sizes.append(ShapedType.get_dynamic_size()) - dynamic_sizes.append(_get_op_result_or_value(size)) - sizes_attr = DenseI64ArrayAttr.get(static_sizes) - - num_loops = sum( - v if v == 0 else 1 for v in self.__extract_values(sizes_attr)) - - if isinstance(loop_types_or_target, (Operation, Value, OpView)): - loop_types = [transform.AnyOpType.get()] * num_loops - target = loop_types_or_target - assert target_or_none is None, "Cannot construct TileOp with two targets." 
- else: - loop_types = (([loop_types_or_target] * num_loops) if isinstance( - loop_types_or_target, Type) else loop_types_or_target) - target = target_or_none - - target = _get_op_result_or_value(target) - - super().__init__( - target.type, - loop_types, - target, - dynamic_sizes=dynamic_sizes, - static_sizes=sizes_attr, - interchange=interchange, - loc=loc, - ip=ip, - ) - - def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]: - if not attr: - return [] - return [element for element in attr] + loop_types = ( + ([loop_types_or_target] * num_loops) + if isinstance(loop_types_or_target, Type) + else loop_types_or_target + ) + target = target_or_none + + target = _get_op_result_or_value(target) + + super().__init__( + target.type, + loop_types, + target, + dynamic_sizes=dynamic_sizes, + static_sizes=sizes_attr, + interchange=interchange, + loc=loc, + ip=ip, + ) + + def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]: + if not attr: + return [] + return [element for element in attr] class VectorizeOp: - """Specialization for VectorizeOp class.""" - - def __init__( - self, - target: Union[Operation, Value], - *, - vectorize_padding: Union[bool, BoolAttr] = False, - loc=None, - ip=None, - ): - pdl_operation_type = pdl.OperationType.get() - if isinstance(vectorize_padding, bool): - vectorize_padding = UnitAttr.get() - super().__init__( - pdl_operation_type, - _get_op_result_or_value(target), - vectorize_padding=vectorize_padding, - loc=loc, - ip=ip, - ) + """Specialization for VectorizeOp class.""" + + def __init__( + self, + target: Union[Operation, Value], + *, + vectorize_padding: Union[bool, BoolAttr] = False, + loc=None, + ip=None, + ): + pdl_operation_type = pdl.OperationType.get() + if isinstance(vectorize_padding, bool): + vectorize_padding = UnitAttr.get() + super().__init__( + pdl_operation_type, + _get_op_result_or_value(target), + vectorize_padding=vectorize_padding, + loc=loc, + ip=ip, + ) diff --git 
a/mlir/python/mlir/dialects/_tensor_ops_ext.py b/mlir/python/mlir/dialects/_tensor_ops_ext.py index 51d998b6e..09b9ec68d 100644 --- a/mlir/python/mlir/dialects/_tensor_ops_ext.py +++ b/mlir/python/mlir/dialects/_tensor_ops_ext.py @@ -3,40 +3,42 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from ..ir import * + from ..ir import * except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading imports from extension module") from e from typing import Any, Optional, Sequence, Union -from ._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values +from ._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, +) class EmptyOp: - """Extends the tensor.empty op.""" + """Extends the tensor.empty op.""" - def __init__(self, - sizes: Sequence[Union[int, Value]], - element_type: Type, - *, - loc=None, - ip=None): - """Constructs an `empty` with mixed static/dynamic sizes.""" - # TODO: Refactor the EmptyOp to take an element type attribute and - # then use normal result type inference, unifying the Python and C++ side - # with a standard mechanism (versus stashing that in builders). 
- dynamic_sizes = [] - static_sizes = [] - for s in sizes: - if isinstance(s, int): - static_sizes.append(s) - else: - static_sizes.append(ShapedType.get_dynamic_size()) - dynamic_sizes.append(s) - result_type = RankedTensorType.get(static_sizes, element_type) - op = self.build_generic( - results=[result_type], - operands=dynamic_sizes, - attributes={}, - loc=loc, - ip=ip) - OpView.__init__(self, op) + def __init__( + self, + sizes: Sequence[Union[int, Value]], + element_type: Type, + *, + loc=None, + ip=None + ): + """Constructs an `empty` with mixed static/dynamic sizes.""" + # TODO: Refactor the EmptyOp to take an element type attribute and + # then use normal result type inference, unifying the Python and C++ side + # with a standard mechanism (versus stashing that in builders). + dynamic_sizes = [] + static_sizes = [] + for s in sizes: + if isinstance(s, int): + static_sizes.append(s) + else: + static_sizes.append(ShapedType.get_dynamic_size()) + dynamic_sizes.append(s) + result_type = RankedTensorType.get(static_sizes, element_type) + op = self.build_generic( + results=[result_type], operands=dynamic_sizes, attributes={}, loc=loc, ip=ip + ) + OpView.__init__(self, op) diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py index cc4428ea5..425ec6585 100644 --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -3,144 +3,131 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception try: - from ..ir import * - from ._ods_common import ( - get_op_result_or_value as _get_op_result_or_value, - get_op_results_or_values as _get_op_results_or_values, - ) + from ..ir import * + from ._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + ) except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e + raise RuntimeError("Error loading 
imports from extension module") from e from typing import Optional, Sequence, Union class CastOp: - - def __init__(self, - result_type: Type, - target: Union[Operation, Value], - *, - loc=None, - ip=None): - super().__init__(result_type, - _get_op_result_or_value(target), - loc=loc, - ip=ip) + def __init__( + self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None + ): + super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip) class GetClosestIsolatedParentOp: - - def __init__(self, - result_type: Type, - target: Union[Operation, Value], - *, - loc=None, - ip=None): - super().__init__(result_type, - _get_op_result_or_value(target), - loc=loc, - ip=ip) + def __init__( + self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None + ): + super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip) class MergeHandlesOp: - - def __init__( - self, - handles: Sequence[Union[Operation, Value]], - *, - deduplicate: bool = False, - loc=None, - ip=None, - ): - super().__init__( - [_get_op_result_or_value(h) for h in handles], - deduplicate=deduplicate, - loc=loc, - ip=ip, - ) + def __init__( + self, + handles: Sequence[Union[Operation, Value]], + *, + deduplicate: bool = False, + loc=None, + ip=None, + ): + super().__init__( + [_get_op_result_or_value(h) for h in handles], + deduplicate=deduplicate, + loc=loc, + ip=ip, + ) class ReplicateOp: - - def __init__( - self, - pattern: Union[Operation, Value], - handles: Sequence[Union[Operation, Value]], - *, - loc=None, - ip=None, - ): - super().__init__( - [_get_op_result_or_value(h).type for h in handles], - _get_op_result_or_value(pattern), - [_get_op_result_or_value(h) for h in handles], - loc=loc, - ip=ip, - ) + def __init__( + self, + pattern: Union[Operation, Value], + handles: Sequence[Union[Operation, Value]], + *, + loc=None, + ip=None, + ): + super().__init__( + [_get_op_result_or_value(h).type for h in handles], + 
_get_op_result_or_value(pattern), + [_get_op_result_or_value(h) for h in handles], + loc=loc, + ip=ip, + ) class SequenceOp: - - def __init__( - self, - failure_propagation_mode, - results: Sequence[Type], - target: Union[Operation, Value, Type], - extra_bindings: Optional[Union[Sequence[Value], Sequence[Type], Operation, - OpView]] = None, - ): - root = (_get_op_result_or_value(target) if isinstance( - target, (Operation, Value)) else None) - root_type = root.type if not isinstance(target, Type) else target - if not isinstance(failure_propagation_mode, Attribute): - failure_propagation_mode_attr = IntegerAttr.get( - IntegerType.get_signless(32), failure_propagation_mode._as_int()) - else: - failure_propagation_mode_attr = failure_propagation_mode - - if extra_bindings is None: - extra_bindings = [] - if isinstance(extra_bindings, (Operation, OpView)): - extra_bindings = _get_op_results_or_values(extra_bindings) - - extra_binding_types = [] - if len(extra_bindings) != 0: - if isinstance(extra_bindings[0], Type): - extra_binding_types = extra_bindings - extra_bindings = [] - else: - extra_binding_types = [v.type for v in extra_bindings] - - super().__init__( - results_=results, - failure_propagation_mode=failure_propagation_mode_attr, - root=root, - extra_bindings=extra_bindings, - ) - self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types)) - - @property - def body(self) -> Block: - return self.regions[0].blocks[0] - - @property - def bodyTarget(self) -> Value: - return self.body.arguments[0] - - @property - def bodyExtraArgs(self) -> BlockArgumentList: - return self.body.arguments[1:] + def __init__( + self, + failure_propagation_mode, + results: Sequence[Type], + target: Union[Operation, Value, Type], + extra_bindings: Optional[ + Union[Sequence[Value], Sequence[Type], Operation, OpView] + ] = None, + ): + root = ( + _get_op_result_or_value(target) + if isinstance(target, (Operation, Value)) + else None + ) + root_type = root.type if not 
isinstance(target, Type) else target + if not isinstance(failure_propagation_mode, Attribute): + failure_propagation_mode_attr = IntegerAttr.get( + IntegerType.get_signless(32), failure_propagation_mode._as_int() + ) + else: + failure_propagation_mode_attr = failure_propagation_mode + + if extra_bindings is None: + extra_bindings = [] + if isinstance(extra_bindings, (Operation, OpView)): + extra_bindings = _get_op_results_or_values(extra_bindings) + + extra_binding_types = [] + if len(extra_bindings) != 0: + if isinstance(extra_bindings[0], Type): + extra_binding_types = extra_bindings + extra_bindings = [] + else: + extra_binding_types = [v.type for v in extra_bindings] + + super().__init__( + results_=results, + failure_propagation_mode=failure_propagation_mode_attr, + root=root, + extra_bindings=extra_bindings, + ) + self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types)) + + @property + def body(self) -> Block: + return self.regions[0].blocks[0] + + @property + def bodyTarget(self) -> Value: + return self.body.arguments[0] + + @property + def bodyExtraArgs(self) -> BlockArgumentList: + return self.body.arguments[1:] class YieldOp: - - def __init__( - self, - operands: Optional[Union[Operation, Sequence[Value]]] = None, - *, - loc=None, - ip=None, - ): - if operands is None: - operands = [] - super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip) + def __init__( + self, + operands: Optional[Union[Operation, Sequence[Value]]] = None, + *, + loc=None, + ip=None, + ): + if operands is None: + operands = [] + super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py b/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py index 5a695d621..2f6513199 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/dump_oplib.py @@ -31,61 +31,60 @@ def create_arg_parser() -> argparse.ArgumentParser: - p = 
argparse.ArgumentParser(description="Dump an oplib in various formats") - p.add_argument("modules", - metavar="M", - type=str, - nargs="*", - help="Op module to dump") - p.add_argument("--file", - metavar="F", - type=str, - nargs="*", - help="Python op file to dump") - p.add_argument("--format", - type=str, - dest="format", - default="yaml", - choices=("yaml", "repr"), - help="Format in which to dump") - return p + p = argparse.ArgumentParser(description="Dump an oplib in various formats") + p.add_argument( + "modules", metavar="M", type=str, nargs="*", help="Op module to dump" + ) + p.add_argument( + "--file", metavar="F", type=str, nargs="*", help="Python op file to dump" + ) + p.add_argument( + "--format", + type=str, + dest="format", + default="yaml", + choices=("yaml", "repr"), + help="Format in which to dump", + ) + return p def load_module_from_file(module_name, file_path): - spec = importlib.util.spec_from_file_location(module_name, file_path) - m = importlib.util.module_from_spec(spec) - spec.loader.exec_module(m) - return m + spec = importlib.util.spec_from_file_location(module_name, file_path) + m = importlib.util.module_from_spec(spec) + spec.loader.exec_module(m) + return m def main(args): - # Load all configs. - configs = [] - modules = [] - for module_name in args.modules: - modules.append( - importlib.import_module(module_name, - package="mlir.dialects.linalg.opdsl")) - for i, file_path in enumerate(args.file or []): - modules.append(load_module_from_file(f"_mlir_eval_oplib{i}", file_path)) - for m in modules: - for attr_name, value in m.__dict__.items(): - # TODO: This class layering is awkward. - if isinstance(value, DefinedOpCallable): - try: - linalg_config = LinalgOpConfig.from_linalg_op_def(value.op_def) - except Exception as e: - raise ValueError( - f"Could not create LinalgOpConfig from {value.op_def}") from e - configs.extend(linalg_config) - - # Print. 
- if args.format == "yaml": - print(yaml_dump_all(configs)) - elif args.format == "repr": - for config in configs: - print(repr(config)) + # Load all configs. + configs = [] + modules = [] + for module_name in args.modules: + modules.append( + importlib.import_module(module_name, package="mlir.dialects.linalg.opdsl") + ) + for i, file_path in enumerate(args.file or []): + modules.append(load_module_from_file(f"_mlir_eval_oplib{i}", file_path)) + for m in modules: + for attr_name, value in m.__dict__.items(): + # TODO: This class layering is awkward. + if isinstance(value, DefinedOpCallable): + try: + linalg_config = LinalgOpConfig.from_linalg_op_def(value.op_def) + except Exception as e: + raise ValueError( + f"Could not create LinalgOpConfig from {value.op_def}" + ) from e + configs.extend(linalg_config) + + # Print. + if args.format == "yaml": + print(yaml_dump_all(configs)) + elif args.format == "repr": + for config in configs: + print(repr(config)) if __name__ == "__main__": - main(create_arg_parser().parse_args()) + main(create_arg_parser().parse_args()) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py index 038f06834..9fa626dfa 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/affine.py @@ -66,201 +66,201 @@ class AffineBuildState: - """Internal state for the AffineExprDef._create impls. - - Note that a "local" AffineBuildState can be created relative to a "global" - AffineBuildState. In that case, any affine expressions built will inherit - symbol and dim bindings from the global state and will update both as new - ones are discovered. This allows for building expressions across contexts - which share a common symbol and dim space. 
- """ - - def __init__(self, - *, - global_state: "AffineBuildState" = None, - allow_new_symbols: bool = True, - allow_new_dims: bool = True): - if not global_state: - self.all_symbols = dict() # type: Dict[str, int] - self.all_dims = dict() # type: Dict[str, int] - else: - # Alias the global dict. - self.all_symbols = global_state.all_symbols - self.all_dims = global_state.all_dims - - # Map of symbols and dims in the current build. - self.local_symbols = dict() # type: Dict[str, int] - self.local_dims = dict() # type: Dict[str, int] - self.allow_new_symbols = allow_new_symbols - self.allow_new_dims = allow_new_dims - - def get_dim(self, dimname: str) -> int: - """Gets the dim position given a name.""" - pos = self.all_dims.get(dimname) - if pos is None: - if not self.allow_new_dims: - raise ValueError( - f"New dimensions not allowed in the current affine expression: " - f"Requested '{dimname}', Availble: {self.all_dims}") - pos = len(self.all_dims) - self.all_dims[dimname] = pos - self.local_dims[dimname] = pos - return pos - - def get_symbol(self, symname: str) -> int: - """Geta a symbol position given a name.""" - pos = self.all_symbols.get(symname) - if pos is None: - if not self.allow_new_symbols: - raise ValueError( - f"New symbols not allowed in the current affine expression: " - f"Requested '{symname}', Availble: {self.all_symbols}") - pos = len(self.all_symbols) - self.all_symbols[symname] = pos - self.local_symbols[symname] = pos - return pos - - @property - def local_dim_count(self) -> int: - return len(self.local_dims) - - @property - def local_symbol_count(self) -> int: - return len(self.local_symbols) - - @property - def dim_count(self) -> int: - return len(self.all_dims) - - @property - def symbol_count(self) -> int: - return len(self.all_symbols) - - def __repr__(self): - lines = [f"AffineBuildState<"] - lines.append(f" symbols={self.local_symbols}") - lines.append(f" dims={self.local_dims}>") - return "\n".join(lines) + """Internal state for the 
AffineExprDef._create impls. + + Note that a "local" AffineBuildState can be created relative to a "global" + AffineBuildState. In that case, any affine expressions built will inherit + symbol and dim bindings from the global state and will update both as new + ones are discovered. This allows for building expressions across contexts + which share a common symbol and dim space. + """ + + def __init__( + self, + *, + global_state: "AffineBuildState" = None, + allow_new_symbols: bool = True, + allow_new_dims: bool = True, + ): + if not global_state: + self.all_symbols = dict() # type: Dict[str, int] + self.all_dims = dict() # type: Dict[str, int] + else: + # Alias the global dict. + self.all_symbols = global_state.all_symbols + self.all_dims = global_state.all_dims + + # Map of symbols and dims in the current build. + self.local_symbols = dict() # type: Dict[str, int] + self.local_dims = dict() # type: Dict[str, int] + self.allow_new_symbols = allow_new_symbols + self.allow_new_dims = allow_new_dims + + def get_dim(self, dimname: str) -> int: + """Gets the dim position given a name.""" + pos = self.all_dims.get(dimname) + if pos is None: + if not self.allow_new_dims: + raise ValueError( + f"New dimensions not allowed in the current affine expression: " + f"Requested '{dimname}', Availble: {self.all_dims}" + ) + pos = len(self.all_dims) + self.all_dims[dimname] = pos + self.local_dims[dimname] = pos + return pos + + def get_symbol(self, symname: str) -> int: + """Geta a symbol position given a name.""" + pos = self.all_symbols.get(symname) + if pos is None: + if not self.allow_new_symbols: + raise ValueError( + f"New symbols not allowed in the current affine expression: " + f"Requested '{symname}', Availble: {self.all_symbols}" + ) + pos = len(self.all_symbols) + self.all_symbols[symname] = pos + self.local_symbols[symname] = pos + return pos + + @property + def local_dim_count(self) -> int: + return len(self.local_dims) + + @property + def local_symbol_count(self) -> 
int: + return len(self.local_symbols) + + @property + def dim_count(self) -> int: + return len(self.all_dims) + + @property + def symbol_count(self) -> int: + return len(self.all_symbols) + + def __repr__(self): + lines = [f"AffineBuildState<"] + lines.append(f" symbols={self.local_symbols}") + lines.append(f" dims={self.local_dims}>") + return "\n".join(lines) class AffineExprDef: - """Base class for an affine expression being defined.""" + """Base class for an affine expression being defined.""" - def build(self, state: Optional[AffineBuildState] = None) -> _ir.AffineExpr: - """Builds the corresponding _ir.AffineExpr from the definitions. - """ - state = AffineBuildState() if state is None else state - expr = self._create(state) - return expr + def build(self, state: Optional[AffineBuildState] = None) -> _ir.AffineExpr: + """Builds the corresponding _ir.AffineExpr from the definitions.""" + state = AffineBuildState() if state is None else state + expr = self._create(state) + return expr - def _create(self, state: AffineBuildState) -> _ir.AffineExpr: - raise NotImplementedError() + def _create(self, state: AffineBuildState) -> _ir.AffineExpr: + raise NotImplementedError() - @staticmethod - def coerce_from(py_value): - if isinstance(py_value, int): - return AffineConstantExpr(py_value) - assert isinstance(py_value, AffineExprDef) - return py_value + @staticmethod + def coerce_from(py_value): + if isinstance(py_value, int): + return AffineConstantExpr(py_value) + assert isinstance(py_value, AffineExprDef) + return py_value - def visit_affine_exprs(self, callback): - """Visits all AffineExprDefs including self.""" - callback(self) + def visit_affine_exprs(self, callback): + """Visits all AffineExprDefs including self.""" + callback(self) - def __add__(lhs, rhs): - rhs = AffineExprDef.coerce_from(rhs) - return AffineBinaryExprDef(_ir.AffineAddExpr, lhs, rhs) + def __add__(lhs, rhs): + rhs = AffineExprDef.coerce_from(rhs) + return AffineBinaryExprDef(_ir.AffineAddExpr, 
lhs, rhs) - def __mul__(lhs, rhs): - rhs = AffineExprDef.coerce_from(rhs) - return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs) + def __mul__(lhs, rhs): + rhs = AffineExprDef.coerce_from(rhs) + return AffineBinaryExprDef(_ir.AffineMulExpr, lhs, rhs) - def __mod__(lhs, rhs): - rhs = AffineExprDef.coerce_from(rhs) - return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs) + def __mod__(lhs, rhs): + rhs = AffineExprDef.coerce_from(rhs) + return AffineBinaryExprDef(_ir.AffineModExpr, lhs, rhs) - def __floordiv__(lhs, rhs): - rhs = AffineExprDef.coerce_from(rhs) - return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs) + def __floordiv__(lhs, rhs): + rhs = AffineExprDef.coerce_from(rhs) + return AffineBinaryExprDef(_ir.AffineFloorDivExpr, lhs, rhs) - def __truediv__(lhs, rhs): - # TODO: Not really a ceil div - taking liberties for the DSL. - rhs = AffineExprDef.coerce_from(rhs) - return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs) + def __truediv__(lhs, rhs): + # TODO: Not really a ceil div - taking liberties for the DSL. 
+ rhs = AffineExprDef.coerce_from(rhs) + return AffineBinaryExprDef(_ir.AffineCeilDivExpr, lhs, rhs) class AffineConstantExpr(AffineExprDef): - """An affine constant being defined.""" + """An affine constant being defined.""" - def __init__(self, value: int): - assert isinstance(value, int) - self.value = value + def __init__(self, value: int): + assert isinstance(value, int) + self.value = value - def _create(self, state: AffineBuildState) -> _ir.AffineExpr: - return _ir.AffineConstantExpr.get(self.value) + def _create(self, state: AffineBuildState) -> _ir.AffineExpr: + return _ir.AffineConstantExpr.get(self.value) - def __repr__(self): - return f"Const({self.value})" + def __repr__(self): + return f"Const({self.value})" class AffineBinaryExprDef(AffineExprDef): - """An affine binary expression being defined.""" + """An affine binary expression being defined.""" - def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef): - self.ir_ctor = ir_ctor - self.lhs = lhs - self.rhs = rhs + def __init__(self, ir_ctor, lhs: AffineExprDef, rhs: AffineExprDef): + self.ir_ctor = ir_ctor + self.lhs = lhs + self.rhs = rhs - def _create(self, state: AffineBuildState) -> _ir.AffineExpr: - return self.ir_ctor.get(self.lhs._create(state), self.rhs._create(state)) + def _create(self, state: AffineBuildState) -> _ir.AffineExpr: + return self.ir_ctor.get(self.lhs._create(state), self.rhs._create(state)) - def visit_affine_exprs(self, callback): - """Visits all AffineExprDefs including self.""" - super().visit_affine_exprs(callback) - self.lhs.visit_affine_exprs(callback) - self.rhs.visit_affine_exprs(callback) + def visit_affine_exprs(self, callback): + """Visits all AffineExprDefs including self.""" + super().visit_affine_exprs(callback) + self.lhs.visit_affine_exprs(callback) + self.rhs.visit_affine_exprs(callback) - def __repr__(self): - return f"{self.ir_ctor.__name__}({repr(self.lhs)}, {repr(self.rhs)})" + def __repr__(self): + return 
f"{self.ir_ctor.__name__}({repr(self.lhs)}, {repr(self.rhs)})" class DimDef(AffineExprDef): - """Represents a named dimension. - - """ - ALL_DIMS = dict() # type: Dict[str, "DimDef"] - - def __new__(cls, dimname: str): - existing = cls.ALL_DIMS.get(dimname) - if existing is not None: - return existing - new = super().__new__(cls) - new.dimname = dimname - cls.ALL_DIMS[dimname] = new - return new - - def __repr__(self): - return f"Dim({self.dimname})" - - def _create(self, state: AffineBuildState) -> _ir.AffineExpr: - pos = state.get_dim(self.dimname) - return _ir.AffineDimExpr.get(position=pos) - - @classmethod - def create_expando(cls): - """Create an expando class that creates unique symbols based on attr access. - """ + """Represents a named dimension.""" + + ALL_DIMS = dict() # type: Dict[str, "DimDef"] + + def __new__(cls, dimname: str): + existing = cls.ALL_DIMS.get(dimname) + if existing is not None: + return existing + new = super().__new__(cls) + new.dimname = dimname + cls.ALL_DIMS[dimname] = new + return new - class ExpandoDims: + def __repr__(self): + return f"Dim({self.dimname})" - def __getattr__(self, n): - return cls(n) + def _create(self, state: AffineBuildState) -> _ir.AffineExpr: + pos = state.get_dim(self.dimname) + return _ir.AffineDimExpr.get(position=pos) - return ExpandoDims() + @classmethod + def create_expando(cls): + """Create an expando class that creates unique symbols based on attr access.""" + + class ExpandoDims: + def __getattr__(self, n): + return cls(n) + + return ExpandoDims() class SymbolDef(AffineExprDef): - """Represents a named symbol. + """Represents a named symbol. 
>>> s1 = SymbolDef("s1") >>> s1 @@ -270,36 +270,35 @@ class SymbolDef(AffineExprDef): False >>> s1 is SymbolDef("s1") True - """ - ALL_SYMBOLS = dict() # type: Dict[str, "SymbolDef"] - - def __new__(cls, symname: str): - existing = cls.ALL_SYMBOLS.get(symname) - if existing is not None: - return existing - new = super().__new__(cls) - new.symname = symname - cls.ALL_SYMBOLS[symname] = new - return new - - def __repr__(self): - return f"Symbol({self.symname})" - - def _create(self, state: AffineBuildState) -> _ir.AffineExpr: - pos = state.get_symbol(self.symname) - return _ir.AffineSymbolExpr.get(position=pos) - - @classmethod - def create_expando(cls): - """Create an expando class that creates unique symbols based on attr access. """ - class ExpandoSymbols: + ALL_SYMBOLS = dict() # type: Dict[str, "SymbolDef"] + + def __new__(cls, symname: str): + existing = cls.ALL_SYMBOLS.get(symname) + if existing is not None: + return existing + new = super().__new__(cls) + new.symname = symname + cls.ALL_SYMBOLS[symname] = new + return new + + def __repr__(self): + return f"Symbol({self.symname})" + + def _create(self, state: AffineBuildState) -> _ir.AffineExpr: + pos = state.get_symbol(self.symname) + return _ir.AffineSymbolExpr.get(position=pos) + + @classmethod + def create_expando(cls): + """Create an expando class that creates unique symbols based on attr access.""" - def __getattr__(self, n): - return cls(n) + class ExpandoSymbols: + def __getattr__(self, n): + return cls(n) - return ExpandoSymbols() + return ExpandoSymbols() # Global accessor for on-demand dims and symbols. 
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 135f55ea5..5d5866fde 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -23,223 +23,232 @@ class TensorExpression: - """An expression that can appear on the RHS of a comprehension.""" + """An expression that can appear on the RHS of a comprehension.""" - def to_scalar_expression(self) -> ScalarExpression: - raise NotImplementedError() + def to_scalar_expression(self) -> ScalarExpression: + raise NotImplementedError() - def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): - """Visits all tensor expression reachable by the expression.""" - callback(self) + def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): + """Visits all tensor expression reachable by the expression.""" + callback(self) - def collect_dim_uses(self, uses: Set["DimDef"]): - """Collects all DimDefs reachable through this expression.""" + def collect_dim_uses(self, uses: Set["DimDef"]): + """Collects all DimDefs reachable through this expression.""" - def visit_dim_def(dim_def: AffineExprDef): - if isinstance(dim_def, DimDef): - uses.add(dim_def) + def visit_dim_def(dim_def: AffineExprDef): + if isinstance(dim_def, DimDef): + uses.add(dim_def) - def visit_affine_exprs(expr: "TensorExpression"): - if isinstance(expr, TensorUse): - for ind in expr.indices: - ind.visit_affine_exprs(visit_dim_def) - if isinstance(expr, TensorReduceFn): - for ind in expr.reduce_fn.reduce_dims: - ind.visit_affine_exprs(visit_dim_def) + def visit_affine_exprs(expr: "TensorExpression"): + if isinstance(expr, TensorUse): + for ind in expr.indices: + ind.visit_affine_exprs(visit_dim_def) + if isinstance(expr, TensorReduceFn): + for ind in expr.reduce_fn.reduce_dims: + ind.visit_affine_exprs(visit_dim_def) - 
self.visit_tensor_exprs(visit_affine_exprs) + self.visit_tensor_exprs(visit_affine_exprs) - def collect_tensor_uses(self, uses: Set["TensorUse"]): - """Collects all TensorUses reachable through this expression.""" + def collect_tensor_uses(self, uses: Set["TensorUse"]): + """Collects all TensorUses reachable through this expression.""" - def visit_tensor_use(expr: "TensorExpression"): - if isinstance(expr, TensorUse): - uses.add(expr) + def visit_tensor_use(expr: "TensorExpression"): + if isinstance(expr, TensorUse): + uses.add(expr) - self.visit_tensor_exprs(visit_tensor_use) + self.visit_tensor_exprs(visit_tensor_use) - def collect_indices(self, indices: Set["index"]): - """Collects all index accesses reachable through this expression.""" + def collect_indices(self, indices: Set["index"]): + """Collects all index accesses reachable through this expression.""" - def visit_index(expr: "TensorExpression"): - if isinstance(expr, index): - indices.add(expr) + def visit_index(expr: "TensorExpression"): + if isinstance(expr, index): + indices.add(expr) - self.visit_tensor_exprs(visit_index) + self.visit_tensor_exprs(visit_index) - def collect_scalar_uses(self, uses: Set["ScalarDef"]): - """Collects all ScalarDefs reachable through this expression.""" + def collect_scalar_uses(self, uses: Set["ScalarDef"]): + """Collects all ScalarDefs reachable through this expression.""" - def visit_scalar_def(expr: "TensorExpression"): - if isinstance(expr, ScalarDef): - uses.add(expr) + def visit_scalar_def(expr: "TensorExpression"): + if isinstance(expr, ScalarDef): + uses.add(expr) - self.visit_tensor_exprs(visit_scalar_def) + self.visit_tensor_exprs(visit_scalar_def) - def __add__(self, rhs: "TensorExpression") -> "TensorExpression": - return BinaryFn.add(self, rhs) + def __add__(self, rhs: "TensorExpression") -> "TensorExpression": + return BinaryFn.add(self, rhs) - def __mul__(self, rhs) -> "TensorExpression": - return BinaryFn.mul(self, rhs) + def __mul__(self, rhs) -> 
"TensorExpression": + return BinaryFn.mul(self, rhs) - def __sub__(self, rhs) -> "TensorExpression": - return BinaryFn.sub(self, rhs) + def __sub__(self, rhs) -> "TensorExpression": + return BinaryFn.sub(self, rhs) - def __hash__(self): - return hash(id(self)) + def __hash__(self): + return hash(id(self)) class TensorUse(TensorExpression): - """A used tensor represented by its (tensor_name, indices). - - Note that forming a comprehension via direct assignment is performed through - __setitem__ on the TensorDef level. However, performing a reduction with - compound ops (+=, *=, etc) is done by doing a: - TensorDef.__getitem__ - TensorUse.__iadd__ - TensorDef.__setitem__ - """ - - def __init__(self, operand_def: "OperandDef", - indices: Sequence[AffineExprDef]): - self.operand_def = operand_def - self.indices = tuple(indices) - - def to_scalar_expression(self) -> ScalarExpression: - return ScalarArg(self.tensor_name).expr() - - @property - def tensor_name(self) -> str: - name = self.operand_def.name - assert name is not None, "TensorDef not registered with an op" - return name - - def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]: - # Computes the reduction dims for implicit reductions. Assumes that the rhs - # is the expression being reduced and self is being reduced into. Any - # indices referenced on the rhs and not in self are considered reduction - # dims and will be ordered as encountered on the rhs. - rhs_dims = set() - lhs_dims = set() - rhs.collect_dim_uses(rhs_dims) - self.collect_dim_uses(lhs_dims) - return rhs_dims - lhs_dims - - def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn": - return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs) - - def __repr__(self): - return (f"{self.operand_def.name}" - f"[{', '.join([repr(i) for i in self.indices])}]") + """A used tensor represented by its (tensor_name, indices). 
+ + Note that forming a comprehension via direct assignment is performed through + __setitem__ on the TensorDef level. However, performing a reduction with + compound ops (+=, *=, etc) is done by doing a: + TensorDef.__getitem__ + TensorUse.__iadd__ + TensorDef.__setitem__ + """ + + def __init__(self, operand_def: "OperandDef", indices: Sequence[AffineExprDef]): + self.operand_def = operand_def + self.indices = tuple(indices) + + def to_scalar_expression(self) -> ScalarExpression: + return ScalarArg(self.tensor_name).expr() + + @property + def tensor_name(self) -> str: + name = self.operand_def.name + assert name is not None, "TensorDef not registered with an op" + return name + + def _compute_reduce_dims(self, rhs: TensorExpression) -> Set[DimDef]: + # Computes the reduction dims for implicit reductions. Assumes that the rhs + # is the expression being reduced and self is being reduced into. Any + # indices referenced on the rhs and not in self are considered reduction + # dims and will be ordered as encountered on the rhs. 
+ rhs_dims = set() + lhs_dims = set() + rhs.collect_dim_uses(rhs_dims) + self.collect_dim_uses(lhs_dims) + return rhs_dims - lhs_dims + + def __iadd__(self, rhs: TensorExpression) -> "TensorReduceFn": + return ReduceFnUse(BinaryFn.add, None, *self._compute_reduce_dims(rhs))(rhs) + + def __repr__(self): + return ( + f"{self.operand_def.name}" f"[{', '.join([repr(i) for i in self.indices])}]" + ) class TensorFn(TensorExpression): - """Application of a tensor function.""" - - def __init__(self, kind: "FunctionKind", name: Optional[str], - operand_def: Optional["OperandDef"], type_var: Optional[TypeVar], - args: Sequence[TensorExpression]): - if bool(name) + bool(operand_def) != 1: - raise ValueError("One of 'name', 'operand_def' must be specified") - self.name = name - self.kind = kind - self.operand_def = operand_def - self.type_var = type_var - self.args = args - - def to_scalar_expression(self) -> ScalarExpression: - if self.operand_def: - assert self.operand_def.name, "TensorFn not registered with an op" - attr_name = self.operand_def.name if self.operand_def else None - args = [arg.to_scalar_expression() for arg in self.args] - return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr() - - def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): - super().visit_tensor_exprs(callback) - for arg in self.args: - arg.visit_tensor_exprs(callback) - - def __repr__(self): - name = self.operand_def.name if self.operand_def else self.name - return (f"{self.kind.name}.{name}(type_var={self.type_var}, " - f"args={', '.join(repr(a) for a in self.args)})") + """Application of a tensor function.""" + + def __init__( + self, + kind: "FunctionKind", + name: Optional[str], + operand_def: Optional["OperandDef"], + type_var: Optional[TypeVar], + args: Sequence[TensorExpression], + ): + if bool(name) + bool(operand_def) != 1: + raise ValueError("One of 'name', 'operand_def' must be specified") + self.name = name + self.kind = kind + 
self.operand_def = operand_def + self.type_var = type_var + self.args = args + + def to_scalar_expression(self) -> ScalarExpression: + if self.operand_def: + assert self.operand_def.name, "TensorFn not registered with an op" + attr_name = self.operand_def.name if self.operand_def else None + args = [arg.to_scalar_expression() for arg in self.args] + return ScalarFn(self.kind, self.name, attr_name, self.type_var, args).expr() + + def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): + super().visit_tensor_exprs(callback) + for arg in self.args: + arg.visit_tensor_exprs(callback) + + def __repr__(self): + name = self.operand_def.name if self.operand_def else self.name + return ( + f"{self.kind.name}.{name}(type_var={self.type_var}, " + f"args={', '.join(repr(a) for a in self.args)})" + ) class TensorReduceFn(TensorExpression): - """Application of a reduction function. - - This captures the lhs (initial value) separately from the rhs. - """ - - def __init__(self, reduce_use: "ReduceFnUse", - args: Sequence[TensorExpression]): - self.reduce_use = reduce_use - self.lhs = None # type: Optional[TensorUse] - self.args = args - - def to_scalar_expression(self) -> ScalarExpression: - if self.lhs is None: - raise ValueError(f"Cannot scalarize a TensorReduceFn that has not been " - f"bound to its lhs: {self}") - full_args = [self.lhs.to_scalar_expression() - ] + [arg.to_scalar_expression() for arg in self.args] - fn_name = None - attr_name = None - if self.reduce_use.binary_fn: - fn_name = self.reduce_use.binary_fn.fn_name - if self.reduce_use.binary_attr: - attr_name = self.reduce_use.binary_attr.operand_def.name - return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None, - full_args).expr() - - def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): - for arg in self.args: - arg.visit_tensor_exprs(callback) - - def __repr__(self): - return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})" + """Application 
of a reduction function. + + This captures the lhs (initial value) separately from the rhs. + """ + + def __init__(self, reduce_use: "ReduceFnUse", args: Sequence[TensorExpression]): + self.reduce_use = reduce_use + self.lhs = None # type: Optional[TensorUse] + self.args = args + + def to_scalar_expression(self) -> ScalarExpression: + if self.lhs is None: + raise ValueError( + f"Cannot scalarize a TensorReduceFn that has not been " + f"bound to its lhs: {self}" + ) + full_args = [self.lhs.to_scalar_expression()] + [ + arg.to_scalar_expression() for arg in self.args + ] + fn_name = None + attr_name = None + if self.reduce_use.binary_fn: + fn_name = self.reduce_use.binary_fn.fn_name + if self.reduce_use.binary_attr: + attr_name = self.reduce_use.binary_attr.operand_def.name + return ScalarFn(FunctionKind.BINARY, fn_name, attr_name, None, full_args).expr() + + def visit_tensor_exprs(self, callback: Callable[["TensorExpression"], None]): + for arg in self.args: + arg.visit_tensor_exprs(callback) + + def __repr__(self): + return f"{repr(self.reduce_use)}({', '.join(repr(a) for a in self.args)})" class const(TensorExpression): - """Returns the given constant floating point or integer value.""" + """Returns the given constant floating point or integer value.""" - def __init__(self, value: Any): - with _ir.Context(): - if isinstance(value, float): - self.value = str(_ir.FloatAttr.get_f64(float(value))) - elif isinstance(value, int): - self.value = str( - _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value))) - else: - raise ValueError(f"const requires int or float but got {type(value)}") + def __init__(self, value: Any): + with _ir.Context(): + if isinstance(value, float): + self.value = str(_ir.FloatAttr.get_f64(float(value))) + elif isinstance(value, int): + self.value = str( + _ir.IntegerAttr.get(_ir.IntegerType.get_signless(64), int(value)) + ) + else: + raise ValueError(f"const requires int or float but got {type(value)}") - def to_scalar_expression(self) 
-> ScalarExpression: - return ScalarConst(self.value).expr() + def to_scalar_expression(self) -> ScalarExpression: + return ScalarConst(self.value).expr() - def __repr__(self): - return f"const({self.value})" + def __repr__(self): + return f"const({self.value})" class index(TensorExpression): - """Returns the iteration index for a given dimension name. + """Returns the iteration index for a given dimension name. - Resolves the given dimension name to obtain its position in the iteration - domain of the operation. - """ + Resolves the given dimension name to obtain its position in the iteration + domain of the operation. + """ - def __init__(self, dim: DimDef): - self.dim_def = dim - self.dim = -1 + def __init__(self, dim: DimDef): + self.dim_def = dim + self.dim = -1 - def resolve_dimension_name(self, affine_state: AffineBuildState): - self.dim = affine_state.get_dim(self.dim_def.dimname) + def resolve_dimension_name(self, affine_state: AffineBuildState): + self.dim = affine_state.get_dim(self.dim_def.dimname) - def to_scalar_expression(self) -> ScalarExpression: - assert self.dim != -1, "Dimension name not resolved" - return ScalarIndex(self.dim).expr() + def to_scalar_expression(self) -> ScalarExpression: + assert self.dim != -1, "Dimension name not resolved" + return ScalarIndex(self.dim).expr() - def __repr__(self): - return f"index({repr(self.dim)})" + def __repr__(self): + return f"index({repr(self.dim)})" ############################################################################### @@ -248,155 +257,160 @@ def __repr__(self): class FunctionKind(Enum): - UNARY = 0 - BINARY = 1 - TYPE = 2 + UNARY = 0 + BINARY = 1 + TYPE = 2 class UnaryFnType: - """Unary function. + """Unary function. - A unary function takes one tensor expression and returns the - function evaluation result. - """ + A unary function takes one tensor expression and returns the + function evaluation result. 
+ """ - def __init__(self, fn_name: str): - self.fn_name = fn_name + def __init__(self, fn_name: str): + self.fn_name = fn_name - def __call__(self, arg: TensorExpression) -> "TensorFn": - return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg]) + def __call__(self, arg: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.UNARY, self.fn_name, None, None, [arg]) - def __repr__(self): - return f"{self.fn_name}" + def __repr__(self): + return f"{self.fn_name}" class UnaryFn: - """Unary function namespace.""" - exp = UnaryFnType("exp") - log = UnaryFnType("log") - abs = UnaryFnType("abs") - ceil = UnaryFnType("ceil") - floor = UnaryFnType("floor") - negf = UnaryFnType("negf") + """Unary function namespace.""" + + exp = UnaryFnType("exp") + log = UnaryFnType("log") + abs = UnaryFnType("abs") + ceil = UnaryFnType("ceil") + floor = UnaryFnType("floor") + negf = UnaryFnType("negf") class BinaryFnType: - """Binary function. + """Binary function. - A binary function takes two tensor expressions and returns the - function evaluation result. - """ + A binary function takes two tensor expressions and returns the + function evaluation result. + """ - def __init__(self, fn_name: str): - self.fn_name = fn_name + def __init__(self, fn_name: str): + self.fn_name = fn_name - def __call__(self, arg0: TensorExpression, - arg1: TensorExpression) -> "TensorFn": - return TensorFn(FunctionKind.BINARY, self.fn_name, None, None, [arg0, arg1]) + def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.BINARY, self.fn_name, None, None, [arg0, arg1]) - def __repr__(self): - return f"{self.fn_name}" + def __repr__(self): + return f"{self.fn_name}" class BinaryFn: - """Binary function namespace. + """Binary function namespace. - As the integer types are signless, signedness is implement by different - functions that treat integers as signed or unsigned values. 
+ As the integer types are signless, signedness is implement by different + functions that treat integers as signed or unsigned values. + + Examples: + - max -> `arith.MaxSIOp` + - max_unsinged -> `arith.MaxUIOp` + """ - Examples: - - max -> `arith.MaxSIOp` - - max_unsinged -> `arith.MaxUIOp` - """ - add = BinaryFnType("add") - sub = BinaryFnType("sub") - mul = BinaryFnType("mul") - max_signed = BinaryFnType("max_signed") - min_signed = BinaryFnType("min_signed") - max_unsigned = BinaryFnType("max_unsigned") - min_unsigned = BinaryFnType("min_unsigned") + add = BinaryFnType("add") + sub = BinaryFnType("sub") + mul = BinaryFnType("mul") + max_signed = BinaryFnType("max_signed") + min_signed = BinaryFnType("min_signed") + max_unsigned = BinaryFnType("max_unsigned") + min_unsigned = BinaryFnType("min_unsigned") class TypeFnType: - """Type conversion function. + """Type conversion function. - A type conversion function takes a target type and a tensor expression and - returns the casted tensor expression. - """ + A type conversion function takes a target type and a tensor expression and + returns the casted tensor expression. + """ - def __init__(self, fn_name: str): - self.fn_name = fn_name + def __init__(self, fn_name: str): + self.fn_name = fn_name - def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn": - return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg]) + def __call__(self, type_var: TypeVar, arg: TensorExpression) -> "TensorFn": + return TensorFn(FunctionKind.TYPE, self.fn_name, None, type_var, [arg]) - def __repr__(self): - return f"{self.fn_name}" + def __repr__(self): + return f"{self.fn_name}" class TypeFn: - """Type conversion function namespace. + """Type conversion function namespace. + + As the integer types are signless, signedness is implement by different cast + functions that treat integers as signed (`cast_signed`) or unsigned + (`cast_unsigned`) values. 
- As the integer types are signless, signedness is implement by different cast - functions that treat integers as signed (`cast_signed`) or unsigned - (`cast_unsigned`) values. + Examples: + - cast_signed(I32 -> I64) -> `arith.ExtSIOp` + - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp` + """ - Examples: - - cast_signed(I32 -> I64) -> `arith.ExtSIOp` - - cast_unsigned(I32 -> I64) -> `arith.ExtUIOp` - """ - cast_signed = TypeFnType("cast_signed") - cast_unsigned = TypeFnType("cast_unsigned") + cast_signed = TypeFnType("cast_signed") + cast_unsigned = TypeFnType("cast_unsigned") class ReduceFnUse: - """Reduction function use. + """Reduction function use. - A reduction use specifies the reduction function and dimensions. - """ + A reduction use specifies the reduction function and dimensions. + """ - def __init__(self, binary_fn: Optional[BinaryFnType], - binary_attr: Optional["BinaryFnAttrDef"], *reduce_dims: DimDef): - if bool(binary_fn) + bool(binary_attr) != 1: - raise ValueError("One of 'binary_fn', 'binary_attr' must be specified") - self.binary_fn = binary_fn - self.binary_attr = binary_attr - self.reduce_dims = reduce_dims + def __init__( + self, + binary_fn: Optional[BinaryFnType], + binary_attr: Optional["BinaryFnAttrDef"], + *reduce_dims: DimDef, + ): + if bool(binary_fn) + bool(binary_attr) != 1: + raise ValueError("One of 'binary_fn', 'binary_attr' must be specified") + self.binary_fn = binary_fn + self.binary_attr = binary_attr + self.reduce_dims = reduce_dims - def __call__(self, *args: TensorExpression) -> "TensorReduceFn": - return TensorReduceFn(self, args) + def __call__(self, *args: TensorExpression) -> "TensorReduceFn": + return TensorReduceFn(self, args) - def __repr__(self): - fn = self.binary_fn if self.binary_fn else self.binary_attr - return ( - f"reduce_{repr(fn)}({', '.join(repr(d) for d in self.reduce_dims)})") + def __repr__(self): + fn = self.binary_fn if self.binary_fn else self.binary_attr + return f"reduce_{repr(fn)}({', '.join(repr(d) 
for d in self.reduce_dims)})" class ReduceFnType: - """Reduction function. + """Reduction function. - A binary function that reduces its RHS into its LHS. - """ + A binary function that reduces its RHS into its LHS. + """ - def __init__(self, binary_fn: BinaryFnType): - if not isinstance(binary_fn, BinaryFnType): - raise ValueError(f"Reduce expected a BinaryFnType but got {binary_fn}") - self.binary_fn = binary_fn + def __init__(self, binary_fn: BinaryFnType): + if not isinstance(binary_fn, BinaryFnType): + raise ValueError(f"Reduce expected a BinaryFnType but got {binary_fn}") + self.binary_fn = binary_fn - def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: - return ReduceFnUse(self.binary_fn, None, *reduce_dims) + def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: + return ReduceFnUse(self.binary_fn, None, *reduce_dims) - def __repr__(self): - return f"reduce_{repr(self.binary_fn)}" + def __repr__(self): + return f"reduce_{repr(self.binary_fn)}" class ReduceFn: - add = ReduceFnType(BinaryFn.add) - mul = ReduceFnType(BinaryFn.mul) - max_signed = ReduceFnType(BinaryFn.max_signed) - min_signed = ReduceFnType(BinaryFn.min_signed) - max_unsigned = ReduceFnType(BinaryFn.max_unsigned) - min_unsigned = ReduceFnType(BinaryFn.min_unsigned) + add = ReduceFnType(BinaryFn.add) + mul = ReduceFnType(BinaryFn.mul) + max_signed = ReduceFnType(BinaryFn.max_signed) + min_signed = ReduceFnType(BinaryFn.min_signed) + max_unsigned = ReduceFnType(BinaryFn.max_unsigned) + min_unsigned = ReduceFnType(BinaryFn.min_unsigned) ############################################################################### @@ -405,237 +419,265 @@ class ReduceFn: class OperandKind(Enum): - INPUT_TENSOR = 0 - SCALAR = 1 - OUTPUT_TENSOR = 2 - INDEX_ATTR = 3 - UNARY_FN_ATTR = 4 - BINARY_FN_ATTR = 5 - TYPE_FN_ATTR = 6 + INPUT_TENSOR = 0 + SCALAR = 1 + OUTPUT_TENSOR = 2 + INDEX_ATTR = 3 + UNARY_FN_ATTR = 4 + BINARY_FN_ATTR = 5 + TYPE_FN_ATTR = 6 class OperandDef: - """Definition of an 
operand passed to an operation. - - Keep the meta information of Tensor, Scalar, and Attribute operands and - provide the shared registration functionality. - """ - - def __init__(self, - kind: OperandKind, - type_var: Optional[TypeVar] = None, - size_exprs: Optional[Sequence[AffineExprDef]] = None, - index_dims: Optional[Sequence[DimDef]] = None, - default_indices: Optional[Sequence[int]] = None, - default_fn: Optional[str] = None): - if type_var and not isinstance(type_var, TypeVar): - raise ValueError( - f"OperandDef requires a TypeVar but got {repr(type_var)}") - self.owner = None # type: Optional["LinalgOpDef"] - self.type_var = type_var - self.size_exprs = size_exprs - self.index_dims = index_dims - self.default_indices = default_indices - self.default_fn = default_fn - self.kind = kind - self.name = None # type: Optional[str] - self.registered_index = -1 # type: int - - def attach(self, index: int, name: str, owner: "LinalgOpDef"): - if self.owner: - raise ValueError(f"OperandDef already registered with an op: {self}") - self.registered_index = index - self.name = name - self.owner = owner - - def is_input(self) -> bool: - return (self.kind == OperandKind.SCALAR or - self.kind == OperandKind.INPUT_TENSOR) - - def is_tensor(self) -> bool: - return (self.kind == OperandKind.INPUT_TENSOR or - self.kind == OperandKind.OUTPUT_TENSOR) - - def is_attribute(self) -> bool: - return (self.kind == OperandKind.INDEX_ATTR or - self.kind == OperandKind.UNARY_FN_ATTR or - self.kind == OperandKind.BINARY_FN_ATTR or - self.kind == OperandKind.TYPE_FN_ATTR) - - def __hash__(self): - return hash(id(self)) - - def __repr__(self): - return (f"{self.name}:OperandDef(kind={self.kind.name}, " + """Definition of an operand passed to an operation. + + Keep the meta information of Tensor, Scalar, and Attribute operands and + provide the shared registration functionality. 
+ """ + + def __init__( + self, + kind: OperandKind, + type_var: Optional[TypeVar] = None, + size_exprs: Optional[Sequence[AffineExprDef]] = None, + index_dims: Optional[Sequence[DimDef]] = None, + default_indices: Optional[Sequence[int]] = None, + default_fn: Optional[str] = None, + ): + if type_var and not isinstance(type_var, TypeVar): + raise ValueError(f"OperandDef requires a TypeVar but got {repr(type_var)}") + self.owner = None # type: Optional["LinalgOpDef"] + self.type_var = type_var + self.size_exprs = size_exprs + self.index_dims = index_dims + self.default_indices = default_indices + self.default_fn = default_fn + self.kind = kind + self.name = None # type: Optional[str] + self.registered_index = -1 # type: int + + def attach(self, index: int, name: str, owner: "LinalgOpDef"): + if self.owner: + raise ValueError(f"OperandDef already registered with an op: {self}") + self.registered_index = index + self.name = name + self.owner = owner + + def is_input(self) -> bool: + return self.kind == OperandKind.SCALAR or self.kind == OperandKind.INPUT_TENSOR + + def is_tensor(self) -> bool: + return ( + self.kind == OperandKind.INPUT_TENSOR + or self.kind == OperandKind.OUTPUT_TENSOR + ) + + def is_attribute(self) -> bool: + return ( + self.kind == OperandKind.INDEX_ATTR + or self.kind == OperandKind.UNARY_FN_ATTR + or self.kind == OperandKind.BINARY_FN_ATTR + or self.kind == OperandKind.TYPE_FN_ATTR + ) + + def __hash__(self): + return hash(id(self)) + + def __repr__(self): + return ( + f"{self.name}:OperandDef(kind={self.kind.name}, " f"type={repr(self.type_var)}, size_exprs={self.size_exprs}, " f"index_dims={self.index_dims}, " f"default_indices={self.default_indices}, " - f"default_fn={self.default_fn})") + f"default_fn={self.default_fn})" + ) class TensorDef: - """Tensor operand definition. - - Tensor operands are indexed using the associated indexing_map when forwarded - to the body of the structured op. 
A unique name identifies the tensor operands - and an index determines their position in the operation's parameter list. A - tensor definition takes type, a shape, and an optional flag to mark output - tensors. Additionally, a tuple of index dimensions may be used to map the - tensor to the loop dimensions of the operation. This mapping is needed to - compute the indexing map of shape-only tensors that have no uses. - """ - - def __init__(self, - type_var: TypeVar, - *shape: AffineExprDef, - index_dims: Optional[Sequence[DimDef]] = None, - output: bool = False): - if index_dims and len(shape) != len(index_dims): - raise ValueError(f"Expected the shape rank {len(shape)} to match the " - f"number of index_dims {len(index_dims)}") - if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims): - raise ValueError(f"TensorDef requires index dims of type DimDef but " - f"got {index_dims}") - kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR - self.operand_def = OperandDef( - kind, type_var=type_var, size_exprs=shape, index_dims=index_dims) - - def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse: - assert self.operand_def.owner, "TensorDef is not registered with an op" - state = AffineBuildState( - global_state=self.operand_def.owner._affine_state, - allow_new_symbols=False) - if not isinstance(dims, tuple): - dims = (dims,) # Handle single subscript case. - # Special case: (None) is a 0d-scalar use. - if dims == (None,): - dims = () - - exprs = [] - for expr_def in dims: - if not isinstance(expr_def, AffineExprDef): - raise KeyError( - "A TensorDef can only be subscripted by a tuple of affine dims") - exprs.append(expr_def) - return TensorUse(self.operand_def, exprs) - - def __setitem__(self, dims: Sequence[AffineExprDef], value: TensorExpression): - """Creates a new 1:1 comprehension by binding this tensor to an expression. 
- - Note that due to the way assignment works in Python, we have to capture - direct assignment as a setitem on the TensorDef. + """Tensor operand definition. + + Tensor operands are indexed using the associated indexing_map when forwarded + to the body of the structured op. A unique name identifies the tensor operands + and an index determines their position in the operation's parameter list. A + tensor definition takes type, a shape, and an optional flag to mark output + tensors. Additionally, a tuple of index dimensions may be used to map the + tensor to the loop dimensions of the operation. This mapping is needed to + compute the indexing map of shape-only tensors that have no uses. """ - if not isinstance(value, TensorExpression): - raise ValueError(f"Only TensorExpressions can be assigned to TensorDefs. " - f"Got: {repr(value)}") - use = self[dims] - comp = Comprehension((use, value)) - self.operand_def.owner.comprehensions.append(comp) + + def __init__( + self, + type_var: TypeVar, + *shape: AffineExprDef, + index_dims: Optional[Sequence[DimDef]] = None, + output: bool = False, + ): + if index_dims and len(shape) != len(index_dims): + raise ValueError( + f"Expected the shape rank {len(shape)} to match the " + f"number of index_dims {len(index_dims)}" + ) + if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims): + raise ValueError( + f"TensorDef requires index dims of type DimDef but " f"got {index_dims}" + ) + kind = OperandKind.OUTPUT_TENSOR if output else OperandKind.INPUT_TENSOR + self.operand_def = OperandDef( + kind, type_var=type_var, size_exprs=shape, index_dims=index_dims + ) + + def __getitem__(self, dims: Sequence[AffineExprDef]) -> TensorUse: + assert self.operand_def.owner, "TensorDef is not registered with an op" + state = AffineBuildState( + global_state=self.operand_def.owner._affine_state, allow_new_symbols=False + ) + if not isinstance(dims, tuple): + dims = (dims,) # Handle single subscript case. 
+ # Special case: (None) is a 0d-scalar use. + if dims == (None,): + dims = () + + exprs = [] + for expr_def in dims: + if not isinstance(expr_def, AffineExprDef): + raise KeyError( + "A TensorDef can only be subscripted by a tuple of affine dims" + ) + exprs.append(expr_def) + return TensorUse(self.operand_def, exprs) + + def __setitem__(self, dims: Sequence[AffineExprDef], value: TensorExpression): + """Creates a new 1:1 comprehension by binding this tensor to an expression. + + Note that due to the way assignment works in Python, we have to capture + direct assignment as a setitem on the TensorDef. + """ + if not isinstance(value, TensorExpression): + raise ValueError( + f"Only TensorExpressions can be assigned to TensorDefs. " + f"Got: {repr(value)}" + ) + use = self[dims] + comp = Comprehension((use, value)) + self.operand_def.owner.comprehensions.append(comp) class ScalarDef(TensorExpression): - """Scalar operand definition. + """Scalar operand definition. - Scalar operands are forwarded to the body of the structured op as they are. - A unique name identifies the scalars and an index determines their position in - the operation's parameter list. - """ + Scalar operands are forwarded to the body of the structured op as they are. + A unique name identifies the scalars and an index determines their position in + the operation's parameter list. 
+ """ - def __init__(self, type_var: TypeVar): - self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var) + def __init__(self, type_var: TypeVar): + self.operand_def = OperandDef(OperandKind.SCALAR, type_var=type_var) - @property - def scalar_name(self) -> str: - name = self.operand_def.name - assert name is not None, "ScalarDef not registered with an op" - return name + @property + def scalar_name(self) -> str: + name = self.operand_def.name + assert name is not None, "ScalarDef not registered with an op" + return name - def to_scalar_expression(self) -> ScalarExpression: - return ScalarArg(self.scalar_name).expr() + def to_scalar_expression(self) -> ScalarExpression: + return ScalarArg(self.scalar_name).expr() class IndexAttrDef: - """Index attribute definition. - - Index attributes provide a way to define and set symbols that can be used in - indexing expressions. Every attribute specifies a tuple of symbols that at - compile-time are replaced by integer values as well as their default values. - """ - - def __init__(self, *sizes: SymbolDef, default: Sequence[int]): - if any(not isinstance(size, SymbolDef) for size in sizes): - raise ValueError(f"IndexAttrDef requires sizes of type SymbolDef " - f"but got {sizes}") - if any(not isinstance(default_val, int) for default_val in default): - raise ValueError(f"IndexAttrDef requires default values of type int " - f"but got {default}") - if len(sizes) != len(default): - raise ValueError(f"IndexAttrDef expects {len(sizes)} default values " - f"but got {len(default)}") - self.operand_def = OperandDef( - OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default) + """Index attribute definition. + + Index attributes provide a way to define and set symbols that can be used in + indexing expressions. Every attribute specifies a tuple of symbols that at + compile-time are replaced by integer values as well as their default values. 
+ """ + + def __init__(self, *sizes: SymbolDef, default: Sequence[int]): + if any(not isinstance(size, SymbolDef) for size in sizes): + raise ValueError( + f"IndexAttrDef requires sizes of type SymbolDef " f"but got {sizes}" + ) + if any(not isinstance(default_val, int) for default_val in default): + raise ValueError( + f"IndexAttrDef requires default values of type int " + f"but got {default}" + ) + if len(sizes) != len(default): + raise ValueError( + f"IndexAttrDef expects {len(sizes)} default values " + f"but got {len(default)}" + ) + self.operand_def = OperandDef( + OperandKind.INDEX_ATTR, size_exprs=sizes, default_indices=default + ) class UnaryFnAttrDef: - """Unary function attribute definition. + """Unary function attribute definition. - Unary function attributes provide a way to make the arithmetic computation - parametrizable. Every attribute specifies a default unary function - that may be overwritten at operation instantiation time. - """ + Unary function attributes provide a way to make the arithmetic computation + parametrizable. Every attribute specifies a default unary function + that may be overwritten at operation instantiation time. 
+ """ - def __init__(self, default: "UnaryFnType"): - if not isinstance(default, UnaryFnType): - raise ValueError(f"UnaryFnAttrDef requires default of type UnaryFnType " - f"but got {default}") - self.operand_def = OperandDef( - OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name) + def __init__(self, default: "UnaryFnType"): + if not isinstance(default, UnaryFnType): + raise ValueError( + f"UnaryFnAttrDef requires default of type UnaryFnType " + f"but got {default}" + ) + self.operand_def = OperandDef( + OperandKind.UNARY_FN_ATTR, default_fn=default.fn_name + ) - def __call__(self, arg: TensorExpression) -> TensorFn: - return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg]) + def __call__(self, arg: TensorExpression) -> TensorFn: + return TensorFn(FunctionKind.UNARY, None, self.operand_def, None, [arg]) class BinaryFnAttrDef: - """Binary function attribute definition. + """Binary function attribute definition. - Binary function attributes provide a way to make the arithmetic computation - parametrizable. Every attribute specifies a default binary function - that may be overwritten at operation instantiation time. - """ + Binary function attributes provide a way to make the arithmetic computation + parametrizable. Every attribute specifies a default binary function + that may be overwritten at operation instantiation time. 
+ """ - def __init__(self, default: "BinaryFnType"): - if not isinstance(default, BinaryFnType): - raise ValueError(f"BinaryFnAttrDef requires default of type BinaryFnType " - f"but got {default}") - self.operand_def = OperandDef( - OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name) + def __init__(self, default: "BinaryFnType"): + if not isinstance(default, BinaryFnType): + raise ValueError( + f"BinaryFnAttrDef requires default of type BinaryFnType " + f"but got {default}" + ) + self.operand_def = OperandDef( + OperandKind.BINARY_FN_ATTR, default_fn=default.fn_name + ) - def __call__(self, arg0: TensorExpression, - arg1: TensorExpression) -> TensorFn: - return TensorFn(FunctionKind.BINARY, None, self.operand_def, None, - [arg0, arg1]) + def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> TensorFn: + return TensorFn(FunctionKind.BINARY, None, self.operand_def, None, [arg0, arg1]) - def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: - return ReduceFnUse(None, self, *reduce_dims) + def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: + return ReduceFnUse(None, self, *reduce_dims) class TypeFnAttrDef: - """Type conversion function attribute definition. + """Type conversion function attribute definition. - Type conversion function attributes provide a way to make type conversions - parameterizable. Every attribute specifies a default type conversion function - that may be overwritten at operation instantiation time. - """ + Type conversion function attributes provide a way to make type conversions + parameterizable. Every attribute specifies a default type conversion function + that may be overwritten at operation instantiation time. 
+ """ - def __init__(self, default: "TypeFnType"): - if not isinstance(default, TypeFnType): - raise ValueError(f"TypeFnAttrDef requires default of type TypeFnType " - f"but got {default}") - self.operand_def = OperandDef( - OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name) + def __init__(self, default: "TypeFnType"): + if not isinstance(default, TypeFnType): + raise ValueError( + f"TypeFnAttrDef requires default of type TypeFnType " + f"but got {default}" + ) + self.operand_def = OperandDef( + OperandKind.TYPE_FN_ATTR, default_fn=default.fn_name + ) - def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn: - return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg]) + def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn: + return TensorFn(FunctionKind.TYPE, None, self.operand_def, type_var, [arg]) ############################################################################### @@ -644,48 +686,48 @@ def __call__(self, type_var: TypeVar, arg: TensorExpression) -> TensorFn: class Comprehension: - """Represents a single comprehension.""" - - def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]): - self.definitions = list() # List[TensorUse] - self.values = list() # List[TensorExpression] - - # Find the lhs to reduction rhs. 
- for assign, value in bindings: - if isinstance(value, TensorReduceFn): - if value.lhs: - raise ValueError(f"Reduction expression already assigns: {value}") - value.lhs = assign - self.definitions.append(assign) - self.values.append(value) - - @property - def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]: - """Gets the reduction dims for the comprehension or None.""" - result = set() - for use in self.values: - if isinstance(use, TensorReduceFn): - result.add(use.reduce_use.reduce_dims) - else: - result.add(tuple()) - return result - - def __repr__(self): - if len(self.definitions) > 1: - defs_repr = f"({', '.join(repr(d) for d in self.definitions)})" - values_repr = f"({', '.join(repr(v) for v in self.values)})" - else: - defs_repr = f"{repr(self.definitions[0])}" - values_repr = f"{repr(self.values[0])}" - - return f"{defs_repr} = {values_repr}" + """Represents a single comprehension.""" + + def __init__(self, *bindings: Tuple[TensorUse, TensorExpression]): + self.definitions = list() # List[TensorUse] + self.values = list() # List[TensorExpression] + + # Find the lhs to reduction rhs. 
+ for assign, value in bindings: + if isinstance(value, TensorReduceFn): + if value.lhs: + raise ValueError(f"Reduction expression already assigns: {value}") + value.lhs = assign + self.definitions.append(assign) + self.values.append(value) + + @property + def all_reduction_dims(self) -> Set[Tuple[DimDef, ...]]: + """Gets the reduction dims for the comprehension or None.""" + result = set() + for use in self.values: + if isinstance(use, TensorReduceFn): + result.add(use.reduce_use.reduce_dims) + else: + result.add(tuple()) + return result + + def __repr__(self): + if len(self.definitions) > 1: + defs_repr = f"({', '.join(repr(d) for d in self.definitions)})" + values_repr = f"({', '.join(repr(v) for v in self.values)})" + else: + defs_repr = f"{repr(self.definitions[0])}" + values_repr = f"{repr(self.values[0])}" + + return f"{defs_repr} = {values_repr}" class OpInterfaceDef: - """An interface that an op implements.""" + """An interface that an op implements.""" - def __init__(self, cpp_name: str): - self.cpp_name = cpp_name + def __init__(self, cpp_name: str): + self.cpp_name = cpp_name ContractionOpInterface = OpInterfaceDef("LinalgContractionOpInterface") @@ -694,86 +736,94 @@ def __init__(self, cpp_name: str): class OpDefinitionDef: - """A method that an op implements.""" + """A method that an op implements.""" - def __init__(self, def_name: str): - self.def_name = def_name + def __init__(self, def_name: str): + self.def_name = def_name Canonicalizer = OpDefinitionDef("hasCanonicalizer") class OpMetadataDef(YAMLObject): - """Metadata about the op (generally not behavior impacting).""" - yaml_tag = "!LinalgOpMetadata" - - def __init__(self, name: str, cpp_class_name: Optional[str], - doc: Optional[str]): - self.name = name - self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name - self.doc = doc - self.implements = [] # type: List[OpInterfaceDef] - self.defines = [] # type: List[OpDefinitionsDef] - - def to_yaml_custom_dict(self): - d = 
dict( - name=self.name, - cpp_class_name=self.cpp_class_name, - doc=self.doc, - ) - if self.implements: - d["implements"] = [intr.cpp_name for intr in self.implements] - if self.defines: - d["defines"] = [defi.def_name for defi in self.defines] - return d + """Metadata about the op (generally not behavior impacting).""" + + yaml_tag = "!LinalgOpMetadata" + + def __init__(self, name: str, cpp_class_name: Optional[str], doc: Optional[str]): + self.name = name + self.cpp_class_name = cpp_class_name if cpp_class_name is not None else name + self.doc = doc + self.implements = [] # type: List[OpInterfaceDef] + self.defines = [] # type: List[OpDefinitionsDef] + + def to_yaml_custom_dict(self): + d = dict( + name=self.name, + cpp_class_name=self.cpp_class_name, + doc=self.doc, + ) + if self.implements: + d["implements"] = [intr.cpp_name for intr in self.implements] + if self.defines: + d["defines"] = [defi.def_name for defi in self.defines] + return d class LinalgOpDef: - """Definition of a linalg op.""" - - def __init__(self, - name: str, - cpp_class_name: Optional[str] = None, - doc: Optional[str] = None): - self.metadata = OpMetadataDef( - name=name, cpp_class_name=cpp_class_name, doc=doc) - self.registered_operands = dict() # type: Dict[str, OperandDef] - self.domain = list() # type: List[DimDef] - self.comprehensions = list() # type: List[Comprehension] - self._affine_state = AffineBuildState() - - def add_operand(self, name: str, operand: OperandDef): - """Registers an operand.""" - if name in self.registered_operands: - raise ValueError(f"The operand {name} is already registered " - f"to {self.registered_operands['name']}") - structured_op_methods = [ - "inputs", "outputs", "result_tensors", "region", "iterator_types", - "indexing_maps", "getRegionBuilder", "getLibraryCallName" - ] - if operand.is_attribute() and name in structured_op_methods: - raise ValueError(f"The attribute name {name} conflicts with a structured " - f"op method name") - # Ensure output tensors 
are registered after input tensors and scalars and - # attributes are registered after all other operand types. - if operand.is_input() and any( - not op_def.is_input() for op_def in self.registered_operands.values()): - raise ValueError(f"Input {name} registered after an output or attribute") - if operand.kind == OperandKind.OUTPUT_TENSOR and any( - op_def.is_attribute() for op_def in self.registered_operands.values()): - raise ValueError(f"Output {name} registered after an attribute") - operand.attach(len(self.registered_operands), name, self) - self.registered_operands[name] = operand - - def __repr__(self): - lines = [ - f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name}," - ] - for name, operand in self.registered_operands.items(): - lines.append(f" {operand}") - if self.comprehensions: - lines[-1] += " {" - for comprehension in self.comprehensions: - lines.append(f" {comprehension}") - lines.append("}") - return "\n".join(lines) + """Definition of a linalg op.""" + + def __init__( + self, name: str, cpp_class_name: Optional[str] = None, doc: Optional[str] = None + ): + self.metadata = OpMetadataDef(name=name, cpp_class_name=cpp_class_name, doc=doc) + self.registered_operands = dict() # type: Dict[str, OperandDef] + self.domain = list() # type: List[DimDef] + self.comprehensions = list() # type: List[Comprehension] + self._affine_state = AffineBuildState() + + def add_operand(self, name: str, operand: OperandDef): + """Registers an operand.""" + if name in self.registered_operands: + raise ValueError( + f"The operand {name} is already registered " + f"to {self.registered_operands['name']}" + ) + structured_op_methods = [ + "inputs", + "outputs", + "result_tensors", + "region", + "iterator_types", + "indexing_maps", + "getRegionBuilder", + "getLibraryCallName", + ] + if operand.is_attribute() and name in structured_op_methods: + raise ValueError( + f"The attribute name {name} conflicts with a structured " + f"op method name" + ) + # Ensure 
output tensors are registered after input tensors and scalars and + # attributes are registered after all other operand types. + if operand.is_input() and any( + not op_def.is_input() for op_def in self.registered_operands.values() + ): + raise ValueError(f"Input {name} registered after an output or attribute") + if operand.kind == OperandKind.OUTPUT_TENSOR and any( + op_def.is_attribute() for op_def in self.registered_operands.values() + ): + raise ValueError(f"Output {name} registered after an attribute") + operand.attach(len(self.registered_operands), name, self) + self.registered_operands[name] = operand + + def __repr__(self): + lines = [f"LinalgOpDef({self.metadata.name} -> {self.metadata.cpp_class_name},"] + for name, operand in self.registered_operands.items(): + lines.append(f" {operand}") + if self.comprehensions: + lines[-1] += " {" + for comprehension in self.comprehensions: + lines.append(f" {comprehension}") + lines.append("}") + return "\n".join(lines) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py index 2a0da6829..d522d5712 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/config.py @@ -21,422 +21,468 @@ def _serialize_affine_map(affine_map: _ir.AffineMap) -> str: - with affine_map.context: - # Affine map printing/parsing is via an AffineMap attr. - attr = _ir.AffineMapAttr.get(affine_map) - return str(attr) + with affine_map.context: + # Affine map printing/parsing is via an AffineMap attr. 
+ attr = _ir.AffineMapAttr.get(affine_map) + return str(attr) class TensorUseConfig: - """Wrapper around a TensorUse with additional context-bound state.""" + """Wrapper around a TensorUse with additional context-bound state.""" - def __init__(self, tensor_use: TensorUse, indexing_map: _ir.AffineMap): - self.tensor_use = tensor_use - self.indexing_map = indexing_map + def __init__(self, tensor_use: TensorUse, indexing_map: _ir.AffineMap): + self.tensor_use = tensor_use + self.indexing_map = indexing_map - def __repr__(self): - return f"Use({self.tensor_use}, indexing_map={self.indexing_map})" + def __repr__(self): + return f"Use({self.tensor_use}, indexing_map={self.indexing_map})" class OperandDefConfig(YAMLObject): - """Wrapper containing an operand definition with additional state.""" - yaml_tag = "!LinalgOperandDefConfig" - - def __init__(self, - operand_def: OperandDef, - shape_map: Optional[_ir.AffineMap] = None, - index_attr_map: Optional[_ir.AffineMap] = None): - self.operand_def = operand_def - self.shape_map = shape_map # type: Optional[_ir.AffineMap] - self.index_attr_map = index_attr_map # type: Optional[_ir.AffineMap] - self.indexing_map = None # type: Optional[_ir.AffineMap] - - @property - def name(self) -> str: - return self.operand_def.name - - @property - def kind(self) -> OperandKind: - return self.operand_def.kind - - @property - def type_var(self) -> TypeVar: - return self.operand_def.type_var - - def to_yaml_custom_dict(self): - self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower()) - if self.type_var: - self_dict["type_var"] = self.type_var.name - if self.shape_map: - self_dict["shape_map"] = _serialize_affine_map(self.shape_map) - if self.index_attr_map: - self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map) - if self.operand_def.default_indices: - self_dict["default_indices"] = self.operand_def.default_indices - if self.operand_def.default_fn: - self_dict["default_fn"] = self.operand_def.default_fn - 
return self_dict - - def __repr__(self): - return (f"OperandDefConfig({self.operand_def}, " + """Wrapper containing an operand definition with additional state.""" + + yaml_tag = "!LinalgOperandDefConfig" + + def __init__( + self, + operand_def: OperandDef, + shape_map: Optional[_ir.AffineMap] = None, + index_attr_map: Optional[_ir.AffineMap] = None, + ): + self.operand_def = operand_def + self.shape_map = shape_map # type: Optional[_ir.AffineMap] + self.index_attr_map = index_attr_map # type: Optional[_ir.AffineMap] + self.indexing_map = None # type: Optional[_ir.AffineMap] + + @property + def name(self) -> str: + return self.operand_def.name + + @property + def kind(self) -> OperandKind: + return self.operand_def.kind + + @property + def type_var(self) -> TypeVar: + return self.operand_def.type_var + + def to_yaml_custom_dict(self): + self_dict = dict(name=self.name, kind=self.operand_def.kind.name.lower()) + if self.type_var: + self_dict["type_var"] = self.type_var.name + if self.shape_map: + self_dict["shape_map"] = _serialize_affine_map(self.shape_map) + if self.index_attr_map: + self_dict["index_attr_map"] = _serialize_affine_map(self.index_attr_map) + if self.operand_def.default_indices: + self_dict["default_indices"] = self.operand_def.default_indices + if self.operand_def.default_fn: + self_dict["default_fn"] = self.operand_def.default_fn + return self_dict + + def __repr__(self): + return ( + f"OperandDefConfig({self.operand_def}, " f"shape_map={self.shape_map}, " f"index_attr_map={self.index_attr_map}, " - f"indexing_map={self.indexing_map})") + f"indexing_map={self.indexing_map})" + ) class LinalgIndexingMapsConfig(YAMLObject): - """Abstracts the style of indexing maps that the op exports. - - Presently only static (tied to the op name) indexing maps are supported. 
In - the future, it is expected that we will have additional variants: - - Dynamic based on attributes - - Dynamic based on operands - Each is expected to require a different variant of specification. - """ - yaml_tag = "!LinalgIndexingMapsConfig" - - def __init__(self, - static_indexing_maps: Optional[Sequence[_ir.AffineMap]] = None): - self.static_indexing_maps = static_indexing_maps - - def to_yaml_custom_dict(self): - if self.static_indexing_maps is not None: - return dict(static_indexing_maps=[ - _serialize_affine_map(m) for m in self.static_indexing_maps - ]) - raise ValueError( - f"LinalgIndexingMapsConfig must have one type of indexing map" - f"(got none)") + """Abstracts the style of indexing maps that the op exports. + Presently only static (tied to the op name) indexing maps are supported. In + the future, it is expected that we will have additional variants: + - Dynamic based on attributes + - Dynamic based on operands + Each is expected to require a different variant of specification. + """ -class LinalgStructuredOpConfig(YAMLObject): - """Configuration for metadata sufficient to construct a linalg named op.""" - - yaml_tag = "!LinalgStructuredOpConfig" - - def __init__(self, - comprehension: Comprehension, - domain: Sequence[DimDef], - registered_operands: Sequence[OperandDef], - context: Optional[_ir.Context] = None): - self.context = context if context is not None else _ir.Context() - self.affine_state = AffineBuildState() - self.writes = list() # type: List[Tuple[TensorUse, TensorExpression]] - self.operands = dict() # type: Dict[OperandDef, OperandDefConfig] - self.uses = dict() # type: Dict[TensorUse, TensorUseConfig] - - # Compute the ordered set of writes and collect the tensor, capture, dims, - # and index uses. 
- collected_tensor_uses = set() - collected_scalar_uses = set() - collected_dim_uses = set() - collected_indices = set() - for write_use, read_use in zip(comprehension.definitions, - comprehension.values): - self.writes.append((write_use, read_use)) - - for write_use, read_use in self.writes: - collected_tensor_uses.add(write_use) - read_use.collect_tensor_uses(collected_tensor_uses) - read_use.collect_scalar_uses(collected_scalar_uses) - read_use.collect_dim_uses(collected_dim_uses) - write_use.collect_dim_uses(collected_dim_uses) - read_use.collect_indices(collected_indices) - - # Set domain to the sorted list of uses if no domain annotation is given. - if not domain: - domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname) - - # Verify the domain dimensions match the used dimensions. - if (len(domain) != len(collected_dim_uses) or - any(dim not in collected_dim_uses for dim in domain)): - raise ValueError(f"Expected the annotated domain dimensions {domain} to " - f"match the set of dimension used by the tensor " - f"comprehension {collected_dim_uses}") - - # Instantiate the dimensions in the given order. - with self.context: - local_state = AffineBuildState( - global_state=self.affine_state, allow_new_symbols=False) - for dim in domain: - dim.build(state=local_state) - - # Collect all attribute definitions. - collected_attr_defs = list() - for operand in registered_operands: - if operand.is_attribute(): - collected_attr_defs.append(operand) - - # Collect all tensors with manual indexing annotation. - collected_index_defs = list() - for operand in registered_operands: - if operand.index_dims: - if any(dim not in collected_dim_uses for dim in operand.index_dims): - raise ValueError(f"Expected all index dims {operand.index_dims} of " - f"operand {operand.name} to have uses.") - collected_index_defs.append(operand) - - # Collect the operand definitions of all tensor/scalar uses, attributes, and - # shape-only tensors. 
- all_operand_defs = list() - for use in collected_tensor_uses: - all_operand_defs.append(use.operand_def) - for use in collected_scalar_uses: - all_operand_defs.append(use.operand_def) - for definition in collected_attr_defs: - all_operand_defs.append(definition) - for definition in collected_index_defs: - all_operand_defs.append(definition) - - # Add all operands in registration order to ensure the symbols are - # registered in the order they appear. - all_operand_defs = sorted( - all_operand_defs, key=lambda operand_def: operand_def.registered_index) - for operand_def in all_operand_defs: - self.add_operand(operand_def) - - # Add all shape-only tensor index_dim annotations and all tensor uses. - for definition in collected_index_defs: - self.add_indexed_operand(definition) - for use in collected_tensor_uses: - self.add_tensor_use(use) - - # Normalize all shape and indexing maps now that full count of dims and - # symbols are known. - for cuse in self.uses.values(): - cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map) - for definition in collected_index_defs: - self.operands[definition].indexing_map = self._normalize_affine_map( - self.operands[definition].indexing_map) - for operand_config in self.operands.values(): - if operand_config.shape_map: - operand_config.shape_map = self._normalize_affine_map( - operand_config.shape_map, with_dims=False) - if operand_config.index_attr_map: - operand_config.index_attr_map = self._normalize_affine_map( - operand_config.index_attr_map, with_dims=False) - - # Now for each write use, propagate the indexing maps from the use to the - # tensor, ensuring that there are not conflicts. 
- for write_use, _ in self.writes: - write_tensor_config = self.operands[write_use.operand_def] - if write_tensor_config.indexing_map: - raise ValueError( - f"Unexpected multi-write to a single tensor: {write_tensor_config}") - write_tensor_config.indexing_map = self.uses[write_use].indexing_map - - # For each read use, propagate the indexing maps from the use to the - # tensor, ensuring that there are not conflicts. - for _, read_expr in self.writes: - read_uses = set() # type: Set[TensorUse] - read_expr.collect_tensor_uses(read_uses) - for read_use in read_uses: - read_operand_config = self.operands[read_use.operand_def] - if (read_operand_config.indexing_map and - read_operand_config.indexing_map != - self.uses[read_use].indexing_map): - raise ValueError( - f"Unexpected multi-read of a tensor with different accesses:" - f"{read_operand_config} vs {read_use}") - read_operand_config.indexing_map = self.uses[read_use].indexing_map - - # Set the indexing map of all scalar uses to the empty map. - for operand_config in self.operands.values(): - if operand_config.operand_def.kind == OperandKind.SCALAR: - operand_config.indexing_map = self._get_scalar_map() - - # Check all registered tensor and scalar operands have an indexing map. - for operand in registered_operands: - if operand.is_attribute(): - continue - if not (operand in self.operands and self.operands[operand].indexing_map): - raise ValueError(f"Failed to compute an indexing map for operand " - f"{operand.name}") - - # Collect reduction dims and ensure all the same. - all_reduction_dims = set(comprehension.all_reduction_dims) - if len(all_reduction_dims) != 1: - raise ValueError( - f"All writes within a generic must have the same reduction " - f"dims. Got: {all_reduction_dims}") - self.reduction_dims = next(iter(all_reduction_dims)) - - # Check the index dimension exists and resolve. 
- for index in collected_indices: - if index.dim_def.dimname not in self.affine_state.all_dims: + yaml_tag = "!LinalgIndexingMapsConfig" + + def __init__(self, static_indexing_maps: Optional[Sequence[_ir.AffineMap]] = None): + self.static_indexing_maps = static_indexing_maps + + def to_yaml_custom_dict(self): + if self.static_indexing_maps is not None: + return dict( + static_indexing_maps=[ + _serialize_affine_map(m) for m in self.static_indexing_maps + ] + ) raise ValueError( - f"The dimension {index.dim_def.dimname} is not part of the " - f"iteration domain {self.affine_state.all_dims}") - index.resolve_dimension_name(self.affine_state) - - # Generate the scalar assignments (used to build a body). - self.assignments = [ - ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression()) - for write_use, read_expr in self.writes - ] - - @property - def ordered_operands(self) -> Sequence[OperandDefConfig]: - return sorted( - self.operands.values(), - key=lambda operand: operand.operand_def.registered_index) - - @property - def ordered_dims(self) -> Sequence[Tuple[str, int]]: - """Gets the ordered list of dim bindings (symbolic name, position). - - TODO: The original parser relies on parse ordering to arrive at the - iterator types, but that ordering is not defined on the Python side, so - this may be ambiguous. 
- """ - return list(self.affine_state.all_dims.items()) - - @property - def indexing_maps(self) -> Sequence[_ir.AffineMap]: - return [o.indexing_map for o in self.ordered_operands if o.indexing_map] - - @property - def iterator_types(self) -> Sequence[str]: - - def get_type(symbolic_name, position): - for reduction_dim_expr in self.reduction_dims: - if reduction_dim_expr.dimname == symbolic_name: - return "reduction" - return "parallel" - - return [get_type(*dim) for dim in self.ordered_dims] - - def add_operand(self, operand_def: OperandDef): - if operand_def in self.operands: - return - if not (operand_def.is_tensor() or - operand_def.kind == OperandKind.INDEX_ATTR): - self.operands[operand_def] = OperandDefConfig(operand_def) - return - with self.context: - local_state = AffineBuildState( - global_state=self.affine_state, allow_new_dims=False) - exprs = [] - for expr in operand_def.size_exprs: - exprs.append(expr.build(state=local_state)) - assert local_state.local_dim_count == 0 - affine_map = _ir.AffineMap.get( - dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs) - if operand_def.kind == OperandKind.INDEX_ATTR: - self.operands[operand_def] = OperandDefConfig( - operand_def, index_attr_map=affine_map) - else: - self.operands[operand_def] = OperandDefConfig( - operand_def, shape_map=affine_map) - - def add_indexed_operand(self, operand_def: OperandDef): - with self.context: - local_state = AffineBuildState( - global_state=self.affine_state, allow_new_symbols=False) - exprs = [] - for expr in operand_def.index_dims: - exprs.append(expr.build(state=local_state)) - self.operands[operand_def].indexing_map = _ir.AffineMap.get( - dim_count=local_state.dim_count, - symbol_count=local_state.symbol_count, - exprs=exprs) - - def add_tensor_use(self, tensor_use: TensorUse): - if tensor_use in self.uses: - return - with self.context: - local_state = AffineBuildState( - global_state=self.affine_state, allow_new_symbols=False) - exprs = [] - for expr in 
tensor_use.indices: - exprs.append(expr.build(state=local_state)) - indexing_map = _ir.AffineMap.get( - dim_count=local_state.dim_count, - symbol_count=local_state.symbol_count, - exprs=exprs) - - use_config = TensorUseConfig(tensor_use, indexing_map) - self.uses[tensor_use] = use_config - - def _get_scalar_map(self) -> _ir.AffineMap: - """Create an empty affine map used to index a scalar.""" - with self.context: - return _ir.AffineMap.get( - dim_count=self.affine_state.dim_count, - symbol_count=self.affine_state.symbol_count, - exprs=list()) - - def _normalize_affine_map(self, - affine_map: _ir.AffineMap, - with_dims: bool = True) -> _ir.AffineMap: - """Normalizes an indexing map to have the max known symbols and dims.""" - with self.context: - return _ir.AffineMap.get( - dim_count=self.affine_state.dim_count if with_dims else 0, - symbol_count=self.affine_state.symbol_count, - exprs=list(affine_map.results)) - - def to_yaml_custom_dict(self): - self_dict = dict(args=self.ordered_operands) - # TODO: Refactor the hierarchy internally when supporting more - # than static (preserving this serialized form). 
- self_dict["indexing_maps"] = LinalgIndexingMapsConfig( - static_indexing_maps=self.indexing_maps) - self_dict["iterator_types"] = self.iterator_types - self_dict["assignments"] = self.assignments - return self_dict - - def __repr__(self): - lines = [f"LinalgGenericOpConfig(reduction_dims={self.reduction_dims},"] - lines.append("operands=[") - for def_config in self.ordered_operands: - lines.append(f" {repr(def_config)}") - lines.append("], indexing_maps=[") - for m in self.indexing_maps: - lines.append(f" {repr(m)}") - lines.append(f"], iterator_types=[") - for t in self.iterator_types: - lines.append(f" {t}") - lines.append("])") - return "\n".join(lines) + f"LinalgIndexingMapsConfig must have one type of indexing map" f"(got none)" + ) + + +class LinalgStructuredOpConfig(YAMLObject): + """Configuration for metadata sufficient to construct a linalg named op.""" + + yaml_tag = "!LinalgStructuredOpConfig" + + def __init__( + self, + comprehension: Comprehension, + domain: Sequence[DimDef], + registered_operands: Sequence[OperandDef], + context: Optional[_ir.Context] = None, + ): + self.context = context if context is not None else _ir.Context() + self.affine_state = AffineBuildState() + self.writes = list() # type: List[Tuple[TensorUse, TensorExpression]] + self.operands = dict() # type: Dict[OperandDef, OperandDefConfig] + self.uses = dict() # type: Dict[TensorUse, TensorUseConfig] + + # Compute the ordered set of writes and collect the tensor, capture, dims, + # and index uses. 
+ collected_tensor_uses = set() + collected_scalar_uses = set() + collected_dim_uses = set() + collected_indices = set() + for write_use, read_use in zip(comprehension.definitions, comprehension.values): + self.writes.append((write_use, read_use)) + + for write_use, read_use in self.writes: + collected_tensor_uses.add(write_use) + read_use.collect_tensor_uses(collected_tensor_uses) + read_use.collect_scalar_uses(collected_scalar_uses) + read_use.collect_dim_uses(collected_dim_uses) + write_use.collect_dim_uses(collected_dim_uses) + read_use.collect_indices(collected_indices) + + # Set domain to the sorted list of uses if no domain annotation is given. + if not domain: + domain = sorted(collected_dim_uses, key=lambda dim: dim.dimname) + + # Verify the domain dimensions match the used dimensions. + if len(domain) != len(collected_dim_uses) or any( + dim not in collected_dim_uses for dim in domain + ): + raise ValueError( + f"Expected the annotated domain dimensions {domain} to " + f"match the set of dimension used by the tensor " + f"comprehension {collected_dim_uses}" + ) + + # Instantiate the dimensions in the given order. + with self.context: + local_state = AffineBuildState( + global_state=self.affine_state, allow_new_symbols=False + ) + for dim in domain: + dim.build(state=local_state) + + # Collect all attribute definitions. + collected_attr_defs = list() + for operand in registered_operands: + if operand.is_attribute(): + collected_attr_defs.append(operand) + + # Collect all tensors with manual indexing annotation. + collected_index_defs = list() + for operand in registered_operands: + if operand.index_dims: + if any(dim not in collected_dim_uses for dim in operand.index_dims): + raise ValueError( + f"Expected all index dims {operand.index_dims} of " + f"operand {operand.name} to have uses." + ) + collected_index_defs.append(operand) + + # Collect the operand definitions of all tensor/scalar uses, attributes, and + # shape-only tensors. 
+ all_operand_defs = list() + for use in collected_tensor_uses: + all_operand_defs.append(use.operand_def) + for use in collected_scalar_uses: + all_operand_defs.append(use.operand_def) + for definition in collected_attr_defs: + all_operand_defs.append(definition) + for definition in collected_index_defs: + all_operand_defs.append(definition) + + # Add all operands in registration order to ensure the symbols are + # registered in the order they appear. + all_operand_defs = sorted( + all_operand_defs, key=lambda operand_def: operand_def.registered_index + ) + for operand_def in all_operand_defs: + self.add_operand(operand_def) + + # Add all shape-only tensor index_dim annotations and all tensor uses. + for definition in collected_index_defs: + self.add_indexed_operand(definition) + for use in collected_tensor_uses: + self.add_tensor_use(use) + + # Normalize all shape and indexing maps now that full count of dims and + # symbols are known. + for cuse in self.uses.values(): + cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map) + for definition in collected_index_defs: + self.operands[definition].indexing_map = self._normalize_affine_map( + self.operands[definition].indexing_map + ) + for operand_config in self.operands.values(): + if operand_config.shape_map: + operand_config.shape_map = self._normalize_affine_map( + operand_config.shape_map, with_dims=False + ) + if operand_config.index_attr_map: + operand_config.index_attr_map = self._normalize_affine_map( + operand_config.index_attr_map, with_dims=False + ) + + # Now for each write use, propagate the indexing maps from the use to the + # tensor, ensuring that there are not conflicts. 
+ for write_use, _ in self.writes: + write_tensor_config = self.operands[write_use.operand_def] + if write_tensor_config.indexing_map: + raise ValueError( + f"Unexpected multi-write to a single tensor: {write_tensor_config}" + ) + write_tensor_config.indexing_map = self.uses[write_use].indexing_map + + # For each read use, propagate the indexing maps from the use to the + # tensor, ensuring that there are not conflicts. + for _, read_expr in self.writes: + read_uses = set() # type: Set[TensorUse] + read_expr.collect_tensor_uses(read_uses) + for read_use in read_uses: + read_operand_config = self.operands[read_use.operand_def] + if ( + read_operand_config.indexing_map + and read_operand_config.indexing_map + != self.uses[read_use].indexing_map + ): + raise ValueError( + f"Unexpected multi-read of a tensor with different accesses:" + f"{read_operand_config} vs {read_use}" + ) + read_operand_config.indexing_map = self.uses[read_use].indexing_map + + # Set the indexing map of all scalar uses to the empty map. + for operand_config in self.operands.values(): + if operand_config.operand_def.kind == OperandKind.SCALAR: + operand_config.indexing_map = self._get_scalar_map() + + # Check all registered tensor and scalar operands have an indexing map. + for operand in registered_operands: + if operand.is_attribute(): + continue + if not (operand in self.operands and self.operands[operand].indexing_map): + raise ValueError( + f"Failed to compute an indexing map for operand " f"{operand.name}" + ) + + # Collect reduction dims and ensure all the same. + all_reduction_dims = set(comprehension.all_reduction_dims) + if len(all_reduction_dims) != 1: + raise ValueError( + f"All writes within a generic must have the same reduction " + f"dims. Got: {all_reduction_dims}" + ) + self.reduction_dims = next(iter(all_reduction_dims)) + + # Check the index dimension exists and resolve. 
+ for index in collected_indices: + if index.dim_def.dimname not in self.affine_state.all_dims: + raise ValueError( + f"The dimension {index.dim_def.dimname} is not part of the " + f"iteration domain {self.affine_state.all_dims}" + ) + index.resolve_dimension_name(self.affine_state) + + # Generate the scalar assignments (used to build a body). + self.assignments = [ + ScalarAssign(write_use.tensor_name, read_expr.to_scalar_expression()) + for write_use, read_expr in self.writes + ] + + @property + def ordered_operands(self) -> Sequence[OperandDefConfig]: + return sorted( + self.operands.values(), + key=lambda operand: operand.operand_def.registered_index, + ) + + @property + def ordered_dims(self) -> Sequence[Tuple[str, int]]: + """Gets the ordered list of dim bindings (symbolic name, position). + + TODO: The original parser relies on parse ordering to arrive at the + iterator types, but that ordering is not defined on the Python side, so + this may be ambiguous. + """ + return list(self.affine_state.all_dims.items()) + + @property + def indexing_maps(self) -> Sequence[_ir.AffineMap]: + return [o.indexing_map for o in self.ordered_operands if o.indexing_map] + + @property + def iterator_types(self) -> Sequence[str]: + def get_type(symbolic_name, position): + for reduction_dim_expr in self.reduction_dims: + if reduction_dim_expr.dimname == symbolic_name: + return "reduction" + return "parallel" + + return [get_type(*dim) for dim in self.ordered_dims] + + def add_operand(self, operand_def: OperandDef): + if operand_def in self.operands: + return + if not (operand_def.is_tensor() or operand_def.kind == OperandKind.INDEX_ATTR): + self.operands[operand_def] = OperandDefConfig(operand_def) + return + with self.context: + local_state = AffineBuildState( + global_state=self.affine_state, allow_new_dims=False + ) + exprs = [] + for expr in operand_def.size_exprs: + exprs.append(expr.build(state=local_state)) + assert local_state.local_dim_count == 0 + affine_map = 
_ir.AffineMap.get( + dim_count=0, symbol_count=local_state.symbol_count, exprs=exprs + ) + if operand_def.kind == OperandKind.INDEX_ATTR: + self.operands[operand_def] = OperandDefConfig( + operand_def, index_attr_map=affine_map + ) + else: + self.operands[operand_def] = OperandDefConfig( + operand_def, shape_map=affine_map + ) + + def add_indexed_operand(self, operand_def: OperandDef): + with self.context: + local_state = AffineBuildState( + global_state=self.affine_state, allow_new_symbols=False + ) + exprs = [] + for expr in operand_def.index_dims: + exprs.append(expr.build(state=local_state)) + self.operands[operand_def].indexing_map = _ir.AffineMap.get( + dim_count=local_state.dim_count, + symbol_count=local_state.symbol_count, + exprs=exprs, + ) + + def add_tensor_use(self, tensor_use: TensorUse): + if tensor_use in self.uses: + return + with self.context: + local_state = AffineBuildState( + global_state=self.affine_state, allow_new_symbols=False + ) + exprs = [] + for expr in tensor_use.indices: + exprs.append(expr.build(state=local_state)) + indexing_map = _ir.AffineMap.get( + dim_count=local_state.dim_count, + symbol_count=local_state.symbol_count, + exprs=exprs, + ) + + use_config = TensorUseConfig(tensor_use, indexing_map) + self.uses[tensor_use] = use_config + + def _get_scalar_map(self) -> _ir.AffineMap: + """Create an empty affine map used to index a scalar.""" + with self.context: + return _ir.AffineMap.get( + dim_count=self.affine_state.dim_count, + symbol_count=self.affine_state.symbol_count, + exprs=list(), + ) + + def _normalize_affine_map( + self, affine_map: _ir.AffineMap, with_dims: bool = True + ) -> _ir.AffineMap: + """Normalizes an indexing map to have the max known symbols and dims.""" + with self.context: + return _ir.AffineMap.get( + dim_count=self.affine_state.dim_count if with_dims else 0, + symbol_count=self.affine_state.symbol_count, + exprs=list(affine_map.results), + ) + + def to_yaml_custom_dict(self): + self_dict = 
dict(args=self.ordered_operands) + # TODO: Refactor the hierarchy internally when supporting more + # than static (preserving this serialized form). + self_dict["indexing_maps"] = LinalgIndexingMapsConfig( + static_indexing_maps=self.indexing_maps + ) + self_dict["iterator_types"] = self.iterator_types + self_dict["assignments"] = self.assignments + return self_dict + + def __repr__(self): + lines = [f"LinalgGenericOpConfig(reduction_dims={self.reduction_dims},"] + lines.append("operands=[") + for def_config in self.ordered_operands: + lines.append(f" {repr(def_config)}") + lines.append("], indexing_maps=[") + for m in self.indexing_maps: + lines.append(f" {repr(m)}") + lines.append(f"], iterator_types=[") + for t in self.iterator_types: + lines.append(f" {t}") + lines.append("])") + return "\n".join(lines) class LinalgOpConfig(YAMLObject): - """Container for any supported linalg op type. - - This includes the concrete type by name for ease of parsing by systems - that ignore tags. - """ - yaml_tag = "!LinalgOpConfig" - - def __init__(self, - metadata: OpMetadataDef, - *, - structured_op: Optional[LinalgStructuredOpConfig] = None): - self.metadata = metadata - self.structured_op = structured_op - - def to_yaml_custom_dict(self): - self_dict = dict(metadata=self.metadata,) - if self.structured_op: - self_dict["structured_op"] = self.structured_op - return self_dict - - @staticmethod - def from_linalg_op_def( - op_def: LinalgOpDef, - context: Optional[_ir.Context] = None) -> Sequence["LinalgOpConfig"]: - """Expands a LinalgOpDef into corresponding Linalg configured ops.""" - # TODO: Many LinalgOpDef patterns need to expand to multiple generics. 
- assert len(op_def.comprehensions) == 1, "Only one comprehension supported" - return [ - LinalgOpConfig( - op_def.metadata, - structured_op=LinalgStructuredOpConfig( - op_def.comprehensions[0], op_def.domain, - op_def.registered_operands.values(), context)), - ] - - def __repr__(self): - return (f"LinalgOpConfig(metadata={self.metadata},\n" - f"structured_op={self.structured_op})") + """Container for any supported linalg op type. + + This includes the concrete type by name for ease of parsing by systems + that ignore tags. + """ + + yaml_tag = "!LinalgOpConfig" + + def __init__( + self, + metadata: OpMetadataDef, + *, + structured_op: Optional[LinalgStructuredOpConfig] = None, + ): + self.metadata = metadata + self.structured_op = structured_op + + def to_yaml_custom_dict(self): + self_dict = dict( + metadata=self.metadata, + ) + if self.structured_op: + self_dict["structured_op"] = self.structured_op + return self_dict + + @staticmethod + def from_linalg_op_def( + op_def: LinalgOpDef, context: Optional[_ir.Context] = None + ) -> Sequence["LinalgOpConfig"]: + """Expands a LinalgOpDef into corresponding Linalg configured ops.""" + # TODO: Many LinalgOpDef patterns need to expand to multiple generics. + assert len(op_def.comprehensions) == 1, "Only one comprehension supported" + return [ + LinalgOpConfig( + op_def.metadata, + structured_op=LinalgStructuredOpConfig( + op_def.comprehensions[0], + op_def.domain, + op_def.registered_operands.values(), + context, + ), + ), + ] + + def __repr__(self): + return ( + f"LinalgOpConfig(metadata={self.metadata},\n" + f"structured_op={self.structured_op})" + ) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py index 45b8d5ccd..8b8726f8f 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py @@ -10,160 +10,192 @@ import threading from ..... 
import ir -from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values +from ...._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, +) from .comprehension import * from .config import * from .emitter import * _CONTEXT = threading.local() -StructuredOpOuts = Union[ir.Operation, ir.OpView, ir.OpResultList, - Sequence[Union[ir.Value, ir.Operation, ir.OpView]]] +StructuredOpOuts = Union[ + ir.Operation, + ir.OpView, + ir.OpResultList, + Sequence[Union[ir.Value, ir.Operation, ir.OpView]], +] @contextmanager def bind_op_def(op_def: LinalgOpDef): - if hasattr(_CONTEXT, "current_op_def"): - raise ValueError("Cannot recursively define an operation") - _CONTEXT.current_op_def = op_def - try: - yield op_def - finally: - del _CONTEXT.current_op_def + if hasattr(_CONTEXT, "current_op_def"): + raise ValueError("Cannot recursively define an operation") + _CONTEXT.current_op_def = op_def + try: + yield op_def + finally: + del _CONTEXT.current_op_def def current_op_def() -> LinalgOpDef: - try: - return _CONTEXT.current_op_def - except AttributeError: - raise ValueError( - "Attempt to access the current op definition being defined " - "but none is set. Did you mean to call this in an op definition?") + try: + return _CONTEXT.current_op_def + except AttributeError: + raise ValueError( + "Attempt to access the current op definition being defined " + "but none is set. Did you mean to call this in an op definition?" 
+ ) def _prepare_structured_op_outs(outs: StructuredOpOuts) -> ValueList: - if isinstance(outs, (ir.Operation, ir.OpView)): - return _get_op_results_or_values(outs) - elif isinstance(outs, ir.OpResultList): - return outs + if isinstance(outs, (ir.Operation, ir.OpView)): + return _get_op_results_or_values(outs) + elif isinstance(outs, ir.OpResultList): + return outs - return [_get_op_result_or_value(o) for o in outs] + return [_get_op_result_or_value(o) for o in outs] class DefinedOpCallable: - """Callable that wraps any defined op function.""" - - def __init__(self, op_name: str, op_def: LinalgOpDef): - self.op_name = op_name - self.op_def = op_def - - def __call__(self, *ins: Union[ir.Operation, ir.OpView, ir.Value], - outs: StructuredOpOuts, **kwargs): - """Emits the corresponding op definition as IR. - - Most arguments are passed through to the underlying emitter. The following - keyword argument is interpreted here: - emit_generic: Emits a generic form as appropriate (default True). If - False, a named form is emitted (which must have been built in to the - compiler). - """ - emit_generic = kwargs.pop("emit_generic", False) - if not isinstance(emit_generic, bool): - raise ValueError(f"The named argument 'emit_generic' needs to be " - f" of type bool but got {type(emit_generic)}") - - op_configs = LinalgOpConfig.from_linalg_op_def( - self.op_def, context=ir.Context.current) - - if len(op_configs) != 1: - # TODO: Support composite ops. - raise NotImplementedError( - f"Emission of composite linalg ops not supported: {op_configs}") - - ctx = ir.Context.current - linalgDialect = ctx.get_dialect_descriptor("linalg") - fully_qualified_name = "linalg." 
+ self.op_name - emit_generic = ( - emit_generic or not ctx.is_registered_operation(fully_qualified_name)) - - op_config = op_configs[0] - out_values = _prepare_structured_op_outs(outs) - in_values = [_get_op_result_or_value(i) for i in ins] - if op_config.structured_op: - if emit_generic: - return emit_generic_structured_op( - op_config.structured_op, *in_values, outs=out_values, **kwargs) - else: - return emit_named_structured_op( - op_config.structured_op, - self.op_name, - self.op_def.metadata.cpp_class_name, - *in_values, - outs=out_values, - **kwargs) - - raise NotImplementedError( - f"Emission of linalg op type not supported: {op_config}") - - -def linalg_structured_op(dsl_func=None, - *, - op_name=None, - op_class_name=None) -> DefinedOpCallable: - if dsl_func is None: - # Curry the keyword args in for delayed application. - return functools.partial( - linalg_structured_op, op_name=op_name, op_class_name=op_class_name) - # Determine default names by introspecting the function. - if op_name is None: - op_name = dsl_func.__name__ - if op_class_name is None: - # Camel case it. - op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op" - - op_def = LinalgOpDef( - name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func)) - - # Extract arguments and TensorDefs from the signature. - dsl_func_args = list() - sig = inspect.signature(dsl_func) - for param_name, param in sig.parameters.items(): - param_default = param.default - if isinstance(param_default, - (TensorDef, ScalarDef, IndexAttrDef, UnaryFnAttrDef, - BinaryFnAttrDef, TypeFnAttrDef)): - op_def.add_operand(param_name, param_default.operand_def) - else: - raise ValueError( - f"@linalg_structured_op function parameters must be defaulted as " - f"TensorDef(...), ScalarDef(...), or IndexAttrDef(...): " - f"Found {param_name}: {param_default}") - dsl_func_args.append(param_default) - - # Invoke the DSL func to finish populating the op definition. 
- with bind_op_def(op_def): - dsl_func(*dsl_func_args) - - # TODO: The returned callable should be an IR emitter but that is not - # upstreamed yet. - return DefinedOpCallable(op_name, op_def) + """Callable that wraps any defined op function.""" + + def __init__(self, op_name: str, op_def: LinalgOpDef): + self.op_name = op_name + self.op_def = op_def + + def __call__( + self, + *ins: Union[ir.Operation, ir.OpView, ir.Value], + outs: StructuredOpOuts, + **kwargs, + ): + """Emits the corresponding op definition as IR. + + Most arguments are passed through to the underlying emitter. The following + keyword argument is interpreted here: + emit_generic: Emits a generic form as appropriate (default True). If + False, a named form is emitted (which must have been built in to the + compiler). + """ + emit_generic = kwargs.pop("emit_generic", False) + if not isinstance(emit_generic, bool): + raise ValueError( + f"The named argument 'emit_generic' needs to be " + f" of type bool but got {type(emit_generic)}" + ) + + op_configs = LinalgOpConfig.from_linalg_op_def( + self.op_def, context=ir.Context.current + ) + + if len(op_configs) != 1: + # TODO: Support composite ops. + raise NotImplementedError( + f"Emission of composite linalg ops not supported: {op_configs}" + ) + + ctx = ir.Context.current + linalgDialect = ctx.get_dialect_descriptor("linalg") + fully_qualified_name = "linalg." 
+ self.op_name + emit_generic = emit_generic or not ctx.is_registered_operation( + fully_qualified_name + ) + + op_config = op_configs[0] + out_values = _prepare_structured_op_outs(outs) + in_values = [_get_op_result_or_value(i) for i in ins] + if op_config.structured_op: + if emit_generic: + return emit_generic_structured_op( + op_config.structured_op, *in_values, outs=out_values, **kwargs + ) + else: + return emit_named_structured_op( + op_config.structured_op, + self.op_name, + self.op_def.metadata.cpp_class_name, + *in_values, + outs=out_values, + **kwargs, + ) + + raise NotImplementedError( + f"Emission of linalg op type not supported: {op_config}" + ) + + +def linalg_structured_op( + dsl_func=None, *, op_name=None, op_class_name=None +) -> DefinedOpCallable: + if dsl_func is None: + # Curry the keyword args in for delayed application. + return functools.partial( + linalg_structured_op, op_name=op_name, op_class_name=op_class_name + ) + # Determine default names by introspecting the function. + if op_name is None: + op_name = dsl_func.__name__ + if op_class_name is None: + # Camel case it. + op_class_name = f"{''.join(x.title() for x in op_name.split('_'))}Op" + + op_def = LinalgOpDef( + name=op_name, cpp_class_name=op_class_name, doc=inspect.getdoc(dsl_func) + ) + + # Extract arguments and TensorDefs from the signature. 
+ dsl_func_args = list() + sig = inspect.signature(dsl_func) + for param_name, param in sig.parameters.items(): + param_default = param.default + if isinstance( + param_default, + ( + TensorDef, + ScalarDef, + IndexAttrDef, + UnaryFnAttrDef, + BinaryFnAttrDef, + TypeFnAttrDef, + ), + ): + op_def.add_operand(param_name, param_default.operand_def) + else: + raise ValueError( + f"@linalg_structured_op function parameters must be defaulted as " + f"TensorDef(...), ScalarDef(...), or IndexAttrDef(...): " + f"Found {param_name}: {param_default}" + ) + dsl_func_args.append(param_default) + + # Invoke the DSL func to finish populating the op definition. + with bind_op_def(op_def): + dsl_func(*dsl_func_args) + + # TODO: The returned callable should be an IR emitter but that is not + # upstreamed yet. + return DefinedOpCallable(op_name, op_def) def domain(*dimensions: DimDef): - if any(not isinstance(d, DimDef) for d in dimensions): - raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}") - current_op_def().domain.extend(dimensions) + if any(not isinstance(d, DimDef) for d in dimensions): + raise ValueError(f"Expected dimensions of type DimDef but got {dimensions}") + current_op_def().domain.extend(dimensions) def implements(*interfaces: OpInterfaceDef): - if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces): - raise ValueError( - f"Expected interfaces of type OpInterfaceDef but got {interfaces}") - current_op_def().metadata.implements.extend(interfaces) + if any(not isinstance(intr, OpInterfaceDef) for intr in interfaces): + raise ValueError( + f"Expected interfaces of type OpInterfaceDef but got {interfaces}" + ) + current_op_def().metadata.implements.extend(interfaces) def defines(*definitions: OpDefinitionDef): - if any(not isinstance(defi, OpDefinitionDef) for defi in definitions): - raise ValueError( - f"Expected definitions of type OpDefinitionDef but got {definitions}") - current_op_def().metadata.defines.extend(definitions) + if 
any(not isinstance(defi, OpDefinitionDef) for defi in definitions): + raise ValueError( + f"Expected definitions of type OpDefinitionDef but got {definitions}" + ) + current_op_def().metadata.defines.extend(definitions) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index b63cb4071..62730d9ca 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -11,7 +11,10 @@ from .... import math from .... import arith from .... import complex -from ...._ods_common import get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values +from ...._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, +) from .scalar_expr import * from .config import * @@ -29,529 +32,618 @@ def isa(cls: Type, ty: Type): - try: - cls(ty) - return True - except ValueError: - return False - - -def prepare_common_structured_op(op_config: LinalgStructuredOpConfig, - *ins: Value, outs: ValueList, - **attrs: Union[Sequence[int], TypeFnType]): - all_arg_defs = op_config.ordered_operands - in_arg_defs = [ - d for d in all_arg_defs - if d.kind in [OperandKind.SCALAR, OperandKind.INPUT_TENSOR] - ] - out_arg_defs = [ - d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR - ] - index_attr_arg_defs = [ - d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR - ] - fn_attr_arg_defs = [ - d for d in all_arg_defs if d.kind in [ - OperandKind.UNARY_FN_ATTR, OperandKind.BINARY_FN_ATTR, - OperandKind.TYPE_FN_ATTR - ] - ] - - # Verify outs is a sequence or a list of results. - if not isinstance(outs, (Sequence, OpResultList)): - raise ValueError(f"Expected named argument outs to have type Sequence or " - f"OpResultLis but got {type(outs)}") - - # Arity validation. 
- if len(ins) != len(in_arg_defs): - raise ValueError(f"Expected {len(in_arg_defs)} inputs but got " - f"{len(ins)} for {op_config}") - if outs and len(outs) != len(out_arg_defs): - raise ValueError(f"Expected {len(out_arg_defs)} outputs but got " - f"{len(outs)} for {op_config}") - - # Compute a replacement list for all index attribute symbols. - expressions = [] # type: Sequence[AffineExpr] - replacements = [] # type: Sequence[AffineExpr] - for index_attr in index_attr_arg_defs: - index_attr_vals = index_attr.operand_def.default_indices - if index_attr.name in attrs: - index_attr_vals = attrs.get(index_attr.name) - assert index_attr_vals, "Index attribute has no value" - if not all(isinstance(value, int) for value in index_attr_vals): - raise ValueError(f"Attribute {index_attr.name} needs to be of type " - f"Sequence[int] but got {type(index_attr_vals)}") - results = index_attr.index_attr_map.results # type: AffineExprList - if len(index_attr_vals) != len(results): - raise ValueError(f"Attribute {index_attr.name} has length {len(results)} " - f"but got {len(index_attr_vals)} values") - for expr, value in zip(results, index_attr_vals): - expressions.append(expr) - replacements.append(AffineConstantExpr.get(value)) - - # Replace all index attribute symbols by their value. - # TODO: Add support for shape symbols. - indexing_maps = [] # type: Sequence[AffineMap] - for curr in op_config.indexing_maps: - for expression, replacement in zip(expressions, replacements): - curr = curr.replace(expression, replacement, curr.n_dims, curr.n_symbols) - indexing_maps.append(curr) - - # TODO: Linalg verification does not currently allow symbols. - # Compress them for now and verify none are left. 
- indexing_maps = AffineMap.compress_unused_symbols(indexing_maps, - Context.current) - if any(indexing_map.n_symbols != 0 for indexing_map in indexing_maps): - raise ValueError(f"Expected indexing_maps to use no symbols after " - f"replacement and compression but got {indexing_maps}") - - outs, out_types = _infer_structured_outs(op_config, in_arg_defs, ins, - out_arg_defs, outs) - - result_types = [t for t in out_types if isa(RankedTensorType, t)] - - # Initialize the type dictionary with the predefined types. - type_mapping = dict() # type: Dict[str, Type] - type_mapping["F32"] = F32Type.get() - type_mapping["F64"] = F64Type.get() - type_mapping["I32"] = IntegerType.get_signless(32) - type_mapping["I64"] = IntegerType.get_signless(64) - - # Extract type vars for input/output based types. - block_arg_types = list() # type: List[Type] - for arg_def, arg_element_type in zip(in_arg_defs + out_arg_defs, - _get_types_from_values(*ins, *outs)): - _add_type_mapping(arg_def, arg_element_type, type_mapping, block_arg_types) - - # Emit the generic op. - # TODO: Support emission of pure memref form. - indexing_maps_attr = ArrayAttr.get( - [AffineMapAttr.get(am) for am in indexing_maps]) - iterator_types_attr = ArrayAttr.get([ - Attribute.parse(f"#linalg.iterator_type<{s}>") - for s in op_config.iterator_types - ]) - - # Compute the index attributes used when emitting a named structured op. - index_attrs = {} # type: Dict[str, DenseElementAttr] - for index_attr in index_attr_arg_defs: - index_attr_vals = attrs.get(index_attr.name) - # Only forward attributes set to a non-default value. - if index_attr_vals: - array = np.array(index_attr_vals, dtype=np.int64) - index_attrs[index_attr.name] = DenseElementsAttr.get(array) - - # Compute the function attribute mapping. 
- fn_attr_mapping = {} - for fn_attr in fn_attr_arg_defs: - attr_val = fn_attr.operand_def.default_fn - attr_kind = fn_attr.kind - if fn_attr.name in attrs: - fn = attrs.get(fn_attr.name) - if attr_kind == OperandKind.UNARY_FN_ATTR: - if not isinstance(fn, UnaryFnType): - raise ValueError(f"Attribute {fn_attr.name} needs to be of type " - f"UnaryFnType but got {type(attr_val)}") - elif attr_kind == OperandKind.BINARY_FN_ATTR: - if not isinstance(fn, BinaryFnType): - raise ValueError(f"Attribute {fn_attr.name} needs to be of type " - f"BinaryFnType but got {type(attr_val)}") - else: - if not isinstance(fn, TypeFnType): - raise ValueError(f"Attribute {fn_attr.name} needs to be of type " - f"TypeFnType but got {type(attr_val)}") - attr_val = fn.fn_name - assert attr_val, "Function attribute has no value" - fn_attr_mapping[fn_attr.name] = (attr_val, attr_kind) - - return (all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, - type_mapping, indexing_maps_attr, iterator_types_attr, index_attrs, - fn_attr_mapping, block_arg_types) - - -def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value, - outs: ValueList, **attrs: Sequence[int]): - all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \ - block_arg_types = \ - prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) - - # An operation that accesses only scalars and scalar/rank zero tensors is - # rank polymorhpic. We implement rank polymorphism by generating different - # indexing maps and iterators that match the rank of the first output tensor. - # An operation is rank polymorphic if the iteration domain has rank zero. 
- if not iterator_types_attr: - rank = ShapedType(outs[0].type).rank - iterator_types_attr = ArrayAttr.get( - [Attribute.parse("#linalg.iterator_type")] * rank) - scalar_map = AffineMap.get(rank, 0, []) - tensor_map = AffineMap.get_identity(rank) - indexing_maps = [] - for arg_def in all_arg_defs: - if arg_def.operand_def.kind == OperandKind.SCALAR: - indexing_maps.append(scalar_map) - if arg_def.operand_def.is_tensor(): - idx = arg_def.operand_def.registered_index - if idx < len(ins) and ShapedType(ins[idx].type).rank == 0: - indexing_maps.append(scalar_map) - else: - indexing_maps.append(tensor_map) - indexing_maps_attr = ArrayAttr.get( - [AffineMapAttr.get(am) for am in indexing_maps]) - - generic_op = linalg.GenericOp( - result_tensors=result_types, - inputs=ins, - outputs=outs, - indexing_maps=indexing_maps_attr, - iterator_types=iterator_types_attr, - doc=None, # TODO: Make optional. - library_call=None) # TODO: Make optional. - - # Construct the body. - block_arg_names = _get_operand_def_names(*in_arg_defs, *out_arg_defs) - block = generic_op.regions[0].blocks.append(*block_arg_types) - block_arg_mapping = dict(zip(block_arg_names, block.arguments)) - with InsertionPoint(block): - body_builder = _BodyBuilder(type_mapping, block_arg_mapping, - fn_attr_mapping) - for assignment in op_config.assignments: - body_builder.assign(assignment) - body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs)) - - if len(result_types) == 1: - return generic_op.result - else: - return generic_op.results - - -def emit_named_structured_op(op_config: LinalgStructuredOpConfig, op_name: str, - op_class_name: str, *ins: Value, outs: ValueList, - **attrs: Sequence[int]): - all_arg_defs, in_arg_defs, out_arg_defs, outs, result_types, type_mapping, \ - indexing_maps_attr, iterator_types_attr, index_attrs, fn_attr_mapping, \ - block_arg_types = \ - prepare_common_structured_op(op_config, *ins, outs = outs, **attrs) - - # If we get here, there must exist a builtin class 
`op_class_name`. - ctx = Context.current - fully_qualified_name = "linalg." + op_name - if (not ctx.is_registered_operation(fully_qualified_name) or - not op_class_name in linalg.__dict__.keys()): - raise NotImplementedError( - f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}") - - # Set the index attributes used to compute the indexing maps. - named_op = getattr(linalg, op_class_name)(ins, outs, result_types) - for name, value in index_attrs.items(): - named_op.operation.attributes[name] = value - - # Compute the function attributes by combining operand kind and function name. - for name, (fn_name, kind) in fn_attr_mapping.items(): - assert kind.name.lower().endswith("_attr") - enum_name = kind.name.lower()[:-5] - named_op.operation.attributes[name] = Attribute.parse( - f"#linalg.{enum_name}<{fn_name}>") + try: + cls(ty) + return True + except ValueError: + return False - linalg.fill_builtin_region(named_op.operation) - if len(result_types) == 1: - return named_op.result - else: - return named_op.results +def prepare_common_structured_op( + op_config: LinalgStructuredOpConfig, + *ins: Value, + outs: ValueList, + **attrs: Union[Sequence[int], TypeFnType], +): + all_arg_defs = op_config.ordered_operands + in_arg_defs = [ + d + for d in all_arg_defs + if d.kind in [OperandKind.SCALAR, OperandKind.INPUT_TENSOR] + ] + out_arg_defs = [d for d in all_arg_defs if d.kind == OperandKind.OUTPUT_TENSOR] + index_attr_arg_defs = [d for d in all_arg_defs if d.kind == OperandKind.INDEX_ATTR] + fn_attr_arg_defs = [ + d + for d in all_arg_defs + if d.kind + in [ + OperandKind.UNARY_FN_ATTR, + OperandKind.BINARY_FN_ATTR, + OperandKind.TYPE_FN_ATTR, + ] + ] + + # Verify outs is a sequence or a list of results. + if not isinstance(outs, (Sequence, OpResultList)): + raise ValueError( + f"Expected named argument outs to have type Sequence or " + f"OpResultLis but got {type(outs)}" + ) + + # Arity validation. 
+ if len(ins) != len(in_arg_defs): + raise ValueError( + f"Expected {len(in_arg_defs)} inputs but got " f"{len(ins)} for {op_config}" + ) + if outs and len(outs) != len(out_arg_defs): + raise ValueError( + f"Expected {len(out_arg_defs)} outputs but got " + f"{len(outs)} for {op_config}" + ) + + # Compute a replacement list for all index attribute symbols. + expressions = [] # type: Sequence[AffineExpr] + replacements = [] # type: Sequence[AffineExpr] + for index_attr in index_attr_arg_defs: + index_attr_vals = index_attr.operand_def.default_indices + if index_attr.name in attrs: + index_attr_vals = attrs.get(index_attr.name) + assert index_attr_vals, "Index attribute has no value" + if not all(isinstance(value, int) for value in index_attr_vals): + raise ValueError( + f"Attribute {index_attr.name} needs to be of type " + f"Sequence[int] but got {type(index_attr_vals)}" + ) + results = index_attr.index_attr_map.results # type: AffineExprList + if len(index_attr_vals) != len(results): + raise ValueError( + f"Attribute {index_attr.name} has length {len(results)} " + f"but got {len(index_attr_vals)} values" + ) + for expr, value in zip(results, index_attr_vals): + expressions.append(expr) + replacements.append(AffineConstantExpr.get(value)) + + # Replace all index attribute symbols by their value. + # TODO: Add support for shape symbols. + indexing_maps = [] # type: Sequence[AffineMap] + for curr in op_config.indexing_maps: + for expression, replacement in zip(expressions, replacements): + curr = curr.replace(expression, replacement, curr.n_dims, curr.n_symbols) + indexing_maps.append(curr) + + # TODO: Linalg verification does not currently allow symbols. + # Compress them for now and verify none are left. 
+ indexing_maps = AffineMap.compress_unused_symbols(indexing_maps, Context.current) + if any(indexing_map.n_symbols != 0 for indexing_map in indexing_maps): + raise ValueError( + f"Expected indexing_maps to use no symbols after " + f"replacement and compression but got {indexing_maps}" + ) + + outs, out_types = _infer_structured_outs( + op_config, in_arg_defs, ins, out_arg_defs, outs + ) + + result_types = [t for t in out_types if isa(RankedTensorType, t)] + + # Initialize the type dictionary with the predefined types. + type_mapping = dict() # type: Dict[str, Type] + type_mapping["F32"] = F32Type.get() + type_mapping["F64"] = F64Type.get() + type_mapping["I32"] = IntegerType.get_signless(32) + type_mapping["I64"] = IntegerType.get_signless(64) + + # Extract type vars for input/output based types. + block_arg_types = list() # type: List[Type] + for arg_def, arg_element_type in zip( + in_arg_defs + out_arg_defs, _get_types_from_values(*ins, *outs) + ): + _add_type_mapping(arg_def, arg_element_type, type_mapping, block_arg_types) + + # Emit the generic op. + # TODO: Support emission of pure memref form. + indexing_maps_attr = ArrayAttr.get([AffineMapAttr.get(am) for am in indexing_maps]) + iterator_types_attr = ArrayAttr.get( + [ + Attribute.parse(f"#linalg.iterator_type<{s}>") + for s in op_config.iterator_types + ] + ) + + # Compute the index attributes used when emitting a named structured op. + index_attrs = {} # type: Dict[str, DenseElementAttr] + for index_attr in index_attr_arg_defs: + index_attr_vals = attrs.get(index_attr.name) + # Only forward attributes set to a non-default value. + if index_attr_vals: + array = np.array(index_attr_vals, dtype=np.int64) + index_attrs[index_attr.name] = DenseElementsAttr.get(array) + + # Compute the function attribute mapping. 
+ fn_attr_mapping = {} + for fn_attr in fn_attr_arg_defs: + attr_val = fn_attr.operand_def.default_fn + attr_kind = fn_attr.kind + if fn_attr.name in attrs: + fn = attrs.get(fn_attr.name) + if attr_kind == OperandKind.UNARY_FN_ATTR: + if not isinstance(fn, UnaryFnType): + raise ValueError( + f"Attribute {fn_attr.name} needs to be of type " + f"UnaryFnType but got {type(attr_val)}" + ) + elif attr_kind == OperandKind.BINARY_FN_ATTR: + if not isinstance(fn, BinaryFnType): + raise ValueError( + f"Attribute {fn_attr.name} needs to be of type " + f"BinaryFnType but got {type(attr_val)}" + ) + else: + if not isinstance(fn, TypeFnType): + raise ValueError( + f"Attribute {fn_attr.name} needs to be of type " + f"TypeFnType but got {type(attr_val)}" + ) + attr_val = fn.fn_name + assert attr_val, "Function attribute has no value" + fn_attr_mapping[fn_attr.name] = (attr_val, attr_kind) + + return ( + all_arg_defs, + in_arg_defs, + out_arg_defs, + outs, + result_types, + type_mapping, + indexing_maps_attr, + iterator_types_attr, + index_attrs, + fn_attr_mapping, + block_arg_types, + ) + + +def emit_generic_structured_op( + op_config: LinalgStructuredOpConfig, + *ins: Value, + outs: ValueList, + **attrs: Sequence[int], +): + ( + all_arg_defs, + in_arg_defs, + out_arg_defs, + outs, + result_types, + type_mapping, + indexing_maps_attr, + iterator_types_attr, + index_attrs, + fn_attr_mapping, + block_arg_types, + ) = prepare_common_structured_op(op_config, *ins, outs=outs, **attrs) + + # An operation that accesses only scalars and scalar/rank zero tensors is + # rank polymorhpic. We implement rank polymorphism by generating different + # indexing maps and iterators that match the rank of the first output tensor. + # An operation is rank polymorphic if the iteration domain has rank zero. 
+ if not iterator_types_attr: + rank = ShapedType(outs[0].type).rank + iterator_types_attr = ArrayAttr.get( + [Attribute.parse("#linalg.iterator_type")] * rank + ) + scalar_map = AffineMap.get(rank, 0, []) + tensor_map = AffineMap.get_identity(rank) + indexing_maps = [] + for arg_def in all_arg_defs: + if arg_def.operand_def.kind == OperandKind.SCALAR: + indexing_maps.append(scalar_map) + if arg_def.operand_def.is_tensor(): + idx = arg_def.operand_def.registered_index + if idx < len(ins) and ShapedType(ins[idx].type).rank == 0: + indexing_maps.append(scalar_map) + else: + indexing_maps.append(tensor_map) + indexing_maps_attr = ArrayAttr.get( + [AffineMapAttr.get(am) for am in indexing_maps] + ) + + generic_op = linalg.GenericOp( + result_tensors=result_types, + inputs=ins, + outputs=outs, + indexing_maps=indexing_maps_attr, + iterator_types=iterator_types_attr, + doc=None, # TODO: Make optional. + library_call=None, + ) # TODO: Make optional. + + # Construct the body. + block_arg_names = _get_operand_def_names(*in_arg_defs, *out_arg_defs) + block = generic_op.regions[0].blocks.append(*block_arg_types) + block_arg_mapping = dict(zip(block_arg_names, block.arguments)) + with InsertionPoint(block): + body_builder = _BodyBuilder(type_mapping, block_arg_mapping, fn_attr_mapping) + for assignment in op_config.assignments: + body_builder.assign(assignment) + body_builder.yield_outputs(*_get_operand_def_names(*out_arg_defs)) + + if len(result_types) == 1: + return generic_op.result + else: + return generic_op.results + + +def emit_named_structured_op( + op_config: LinalgStructuredOpConfig, + op_name: str, + op_class_name: str, + *ins: Value, + outs: ValueList, + **attrs: Sequence[int], +): + ( + all_arg_defs, + in_arg_defs, + out_arg_defs, + outs, + result_types, + type_mapping, + indexing_maps_attr, + iterator_types_attr, + index_attrs, + fn_attr_mapping, + block_arg_types, + ) = prepare_common_structured_op(op_config, *ins, outs=outs, **attrs) + + # If we get here, there 
must exist a builtin class `op_class_name`. + ctx = Context.current + fully_qualified_name = "linalg." + op_name + if ( + not ctx.is_registered_operation(fully_qualified_name) + or not op_class_name in linalg.__dict__.keys() + ): + raise NotImplementedError( + f"Unknown named op_name / op_class_name: {op_name} / {op_class_name}" + ) + + # Set the index attributes used to compute the indexing maps. + named_op = getattr(linalg, op_class_name)(ins, outs, result_types) + for name, value in index_attrs.items(): + named_op.operation.attributes[name] = value + + # Compute the function attributes by combining operand kind and function name. + for name, (fn_name, kind) in fn_attr_mapping.items(): + assert kind.name.lower().endswith("_attr") + enum_name = kind.name.lower()[:-5] + named_op.operation.attributes[name] = Attribute.parse( + f"#linalg.{enum_name}<{fn_name}>" + ) + + linalg.fill_builtin_region(named_op.operation) + + if len(result_types) == 1: + return named_op.result + else: + return named_op.results class _BodyBuilder: - """Constructs a structured op body by evaluating assignments.""" - - def __init__(self, type_mapping: Dict[str, Type], - block_arg_mapping: Dict[str, Value], fn_attr_mapping: Dict[str, - str]): - self.type_mapping = type_mapping - self.block_arg_mapping = block_arg_mapping - self.fn_attr_mapping = fn_attr_mapping - self.yield_mapping = dict() # type: Dict[str, Value] - - def assign(self, assignment: ScalarAssign): - if assignment.arg in self.yield_mapping: - raise ValueError( - f"Multiple assignments to the same argument are forbidden: " - f"{assignment}") - self.yield_mapping[assignment.arg] = self.expression(assignment.value) - - def expression(self, expr: ScalarExpression) -> Value: - if expr.scalar_arg: - try: - return self.block_arg_mapping[expr.scalar_arg.arg] - except KeyError: - raise ValueError(f"Argument {expr.scalar_arg.arg} is not bound for " - f"this structured op.") - elif expr.scalar_const: - value_attr = 
Attribute.parse(expr.scalar_const.value) - return arith.ConstantOp(value_attr.type, value_attr).result - elif expr.scalar_index: - dim_attr = IntegerAttr.get( - IntegerType.get_signless(64), expr.scalar_index.dim) - return linalg.IndexOp(dim_attr).result - elif expr.scalar_fn: - kind = expr.scalar_fn.kind.name.lower() - fn_name = expr.scalar_fn.fn_name - if expr.scalar_fn.attr_name: - fn_name, _ = self.fn_attr_mapping[expr.scalar_fn.attr_name] - fn = self._get_function(f"_{kind}_{fn_name}") - operand_values = [ - self.expression(operand) for operand in expr.scalar_fn.operands - ] - if expr.scalar_fn.kind == FunctionKind.TYPE: - operand_values = [expr.scalar_fn.type_var.name] + operand_values - return fn(*operand_values) - raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") - - def yield_outputs(self, *output_names: str): - output_values = [] - for n in output_names: - try: - output_values.append(self.yield_mapping[n]) - except KeyError: - raise ValueError(f"Body assignments do not assign all outputs: " - f"missing '{n}'") - linalg.YieldOp(output_values) - - def _get_function(self, fn_name: str) -> Callable: - try: - fn = getattr(self, f"{fn_name}") - except AttributeError: - raise ValueError(f"Function '{fn_name}' is not a known function") - return fn - - def _cast(self, - type_var_name: str, - operand: Value, - is_unsigned_cast: bool = False) -> Value: - try: - to_type = self.type_mapping[type_var_name] - except KeyError: - raise ValueError(f"Unbound type variable '{type_var_name}' (" - f"expected one of {self.type_mapping.keys()}") - if operand.type == to_type: - return operand - if _is_integer_type(to_type): - return self._cast_to_integer(to_type, operand, is_unsigned_cast) - elif _is_floating_point_type(to_type): - return self._cast_to_floating_point(to_type, operand, is_unsigned_cast) - - def _cast_to_integer(self, to_type: Type, operand: Value, - is_unsigned_cast: bool) -> Value: - to_width = IntegerType(to_type).width - operand_type = 
operand.type - if _is_floating_point_type(operand_type): - if is_unsigned_cast: - return arith.FPToUIOp(to_type, operand).result - return arith.FPToSIOp(to_type, operand).result - if _is_index_type(operand_type): - return arith.IndexCastOp(to_type, operand).result - # Assume integer. - from_width = IntegerType(operand_type).width - if to_width > from_width: - if is_unsigned_cast: - return arith.ExtUIOp(to_type, operand).result - return arith.ExtSIOp(to_type, operand).result - elif to_width < from_width: - return arith.TruncIOp(to_type, operand).result - raise ValueError(f"Unable to cast body expression from {operand_type} to " - f"{to_type}") - - def _cast_to_floating_point(self, to_type: Type, operand: Value, - is_unsigned_cast: bool) -> Value: - operand_type = operand.type - if _is_integer_type(operand_type): - if is_unsigned_cast: - return arith.UIToFPOp(to_type, operand).result - return arith.SIToFPOp(to_type, operand).result - # Assume FloatType. - to_width = _get_floating_point_width(to_type) - from_width = _get_floating_point_width(operand_type) - if to_width > from_width: - return arith.ExtFOp(to_type, operand).result - elif to_width < from_width: - return arith.TruncFOp(to_type, operand).result - raise ValueError(f"Unable to cast body expression from {operand_type} to " - f"{to_type}") - - def _type_cast_signed(self, type_var_name: str, operand: Value) -> Value: - return self._cast(type_var_name, operand, False) - - def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value: - return self._cast(type_var_name, operand, True) - - def _unary_exp(self, x: Value) -> Value: - if _is_floating_point_type(x.type): - return math.ExpOp(x).result - raise NotImplementedError("Unsupported 'exp' operand: {x}") - - def _unary_log(self, x: Value) -> Value: - if _is_floating_point_type(x.type): - return math.LogOp(x).result - raise NotImplementedError("Unsupported 'log' operand: {x}") - - def _unary_abs(self, x: Value) -> Value: - if 
_is_floating_point_type(x.type): - return math.AbsFOp(x).result - raise NotImplementedError("Unsupported 'abs' operand: {x}") - - def _unary_ceil(self, x: Value) -> Value: - if _is_floating_point_type(x.type): - return math.CeilOp(x).result - raise NotImplementedError("Unsupported 'ceil' operand: {x}") - - def _unary_floor(self, x: Value) -> Value: - if _is_floating_point_type(x.type): - return math.FloorOp(x).result - raise NotImplementedError("Unsupported 'floor' operand: {x}") - - def _unary_negf(self, x: Value) -> Value: - if _is_floating_point_type(x.type): - return arith.NegFOp(x).result - if _is_complex_type(x.type): - return complex.NegOp(x).result - raise NotImplementedError("Unsupported 'negf' operand: {x}") - - def _binary_add(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): - return arith.AddFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.AddIOp(lhs, rhs).result - if _is_complex_type(lhs.type): - return complex.AddOp(lhs, rhs).result - raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}") - - def _binary_sub(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): - return arith.SubFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.SubIOp(lhs, rhs).result - if _is_complex_type(lhs.type): - return complex.SubOp(lhs, rhs).result - raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}") - - def _binary_mul(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): - return arith.MulFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.MulIOp(lhs, rhs).result - if _is_complex_type(lhs.type): - return complex.MulOp(lhs, rhs).result - raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}") - - def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): - return 
arith.MaxFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.MaxSIOp(lhs, rhs).result - raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}") - - def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): - return arith.MaxFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.MaxUIOp(lhs, rhs).result - raise NotImplementedError( - "Unsupported 'max_unsigned' operands: {lhs}, {rhs}") - - def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): - return arith.MinFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.MinSIOp(lhs, rhs).result - raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}") - - def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value: - if _is_floating_point_type(lhs.type): - return arith.MinFOp(lhs, rhs).result - if _is_integer_type(lhs.type) or _is_index_type(lhs.type): - return arith.MinUIOp(lhs, rhs).result - raise NotImplementedError( - "Unsupported 'min_unsigned' operands: {lhs}, {rhs}") + """Constructs a structured op body by evaluating assignments.""" + + def __init__( + self, + type_mapping: Dict[str, Type], + block_arg_mapping: Dict[str, Value], + fn_attr_mapping: Dict[str, str], + ): + self.type_mapping = type_mapping + self.block_arg_mapping = block_arg_mapping + self.fn_attr_mapping = fn_attr_mapping + self.yield_mapping = dict() # type: Dict[str, Value] + + def assign(self, assignment: ScalarAssign): + if assignment.arg in self.yield_mapping: + raise ValueError( + f"Multiple assignments to the same argument are forbidden: " + f"{assignment}" + ) + self.yield_mapping[assignment.arg] = self.expression(assignment.value) + + def expression(self, expr: ScalarExpression) -> Value: + if expr.scalar_arg: + try: + return self.block_arg_mapping[expr.scalar_arg.arg] + except KeyError: + 
raise ValueError( + f"Argument {expr.scalar_arg.arg} is not bound for " + f"this structured op." + ) + elif expr.scalar_const: + value_attr = Attribute.parse(expr.scalar_const.value) + return arith.ConstantOp(value_attr.type, value_attr).result + elif expr.scalar_index: + dim_attr = IntegerAttr.get( + IntegerType.get_signless(64), expr.scalar_index.dim + ) + return linalg.IndexOp(dim_attr).result + elif expr.scalar_fn: + kind = expr.scalar_fn.kind.name.lower() + fn_name = expr.scalar_fn.fn_name + if expr.scalar_fn.attr_name: + fn_name, _ = self.fn_attr_mapping[expr.scalar_fn.attr_name] + fn = self._get_function(f"_{kind}_{fn_name}") + operand_values = [ + self.expression(operand) for operand in expr.scalar_fn.operands + ] + if expr.scalar_fn.kind == FunctionKind.TYPE: + operand_values = [expr.scalar_fn.type_var.name] + operand_values + return fn(*operand_values) + raise NotImplementedError(f"Unimplemented scalar body expression: {expr}") + + def yield_outputs(self, *output_names: str): + output_values = [] + for n in output_names: + try: + output_values.append(self.yield_mapping[n]) + except KeyError: + raise ValueError( + f"Body assignments do not assign all outputs: " f"missing '{n}'" + ) + linalg.YieldOp(output_values) + + def _get_function(self, fn_name: str) -> Callable: + try: + fn = getattr(self, f"{fn_name}") + except AttributeError: + raise ValueError(f"Function '{fn_name}' is not a known function") + return fn + + def _cast( + self, type_var_name: str, operand: Value, is_unsigned_cast: bool = False + ) -> Value: + try: + to_type = self.type_mapping[type_var_name] + except KeyError: + raise ValueError( + f"Unbound type variable '{type_var_name}' (" + f"expected one of {self.type_mapping.keys()}" + ) + if operand.type == to_type: + return operand + if _is_integer_type(to_type): + return self._cast_to_integer(to_type, operand, is_unsigned_cast) + elif _is_floating_point_type(to_type): + return self._cast_to_floating_point(to_type, operand, is_unsigned_cast) 
+ + def _cast_to_integer( + self, to_type: Type, operand: Value, is_unsigned_cast: bool + ) -> Value: + to_width = IntegerType(to_type).width + operand_type = operand.type + if _is_floating_point_type(operand_type): + if is_unsigned_cast: + return arith.FPToUIOp(to_type, operand).result + return arith.FPToSIOp(to_type, operand).result + if _is_index_type(operand_type): + return arith.IndexCastOp(to_type, operand).result + # Assume integer. + from_width = IntegerType(operand_type).width + if to_width > from_width: + if is_unsigned_cast: + return arith.ExtUIOp(to_type, operand).result + return arith.ExtSIOp(to_type, operand).result + elif to_width < from_width: + return arith.TruncIOp(to_type, operand).result + raise ValueError( + f"Unable to cast body expression from {operand_type} to " f"{to_type}" + ) + + def _cast_to_floating_point( + self, to_type: Type, operand: Value, is_unsigned_cast: bool + ) -> Value: + operand_type = operand.type + if _is_integer_type(operand_type): + if is_unsigned_cast: + return arith.UIToFPOp(to_type, operand).result + return arith.SIToFPOp(to_type, operand).result + # Assume FloatType. 
+ to_width = _get_floating_point_width(to_type) + from_width = _get_floating_point_width(operand_type) + if to_width > from_width: + return arith.ExtFOp(to_type, operand).result + elif to_width < from_width: + return arith.TruncFOp(to_type, operand).result + raise ValueError( + f"Unable to cast body expression from {operand_type} to " f"{to_type}" + ) + + def _type_cast_signed(self, type_var_name: str, operand: Value) -> Value: + return self._cast(type_var_name, operand, False) + + def _type_cast_unsigned(self, type_var_name: str, operand: Value) -> Value: + return self._cast(type_var_name, operand, True) + + def _unary_exp(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return math.ExpOp(x).result + raise NotImplementedError("Unsupported 'exp' operand: {x}") + + def _unary_log(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return math.LogOp(x).result + raise NotImplementedError("Unsupported 'log' operand: {x}") + + def _unary_abs(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return math.AbsFOp(x).result + raise NotImplementedError("Unsupported 'abs' operand: {x}") + + def _unary_ceil(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return math.CeilOp(x).result + raise NotImplementedError("Unsupported 'ceil' operand: {x}") + + def _unary_floor(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return math.FloorOp(x).result + raise NotImplementedError("Unsupported 'floor' operand: {x}") + + def _unary_negf(self, x: Value) -> Value: + if _is_floating_point_type(x.type): + return arith.NegFOp(x).result + if _is_complex_type(x.type): + return complex.NegOp(x).result + raise NotImplementedError("Unsupported 'negf' operand: {x}") + + def _binary_add(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.AddFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.AddIOp(lhs, rhs).result + if 
_is_complex_type(lhs.type): + return complex.AddOp(lhs, rhs).result + raise NotImplementedError("Unsupported 'add' operands: {lhs}, {rhs}") + + def _binary_sub(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.SubFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.SubIOp(lhs, rhs).result + if _is_complex_type(lhs.type): + return complex.SubOp(lhs, rhs).result + raise NotImplementedError("Unsupported 'sub' operands: {lhs}, {rhs}") + + def _binary_mul(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.MulFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.MulIOp(lhs, rhs).result + if _is_complex_type(lhs.type): + return complex.MulOp(lhs, rhs).result + raise NotImplementedError("Unsupported 'mul' operands: {lhs}, {rhs}") + + def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.MaxFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.MaxSIOp(lhs, rhs).result + raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}") + + def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.MaxFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.MaxUIOp(lhs, rhs).result + raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}") + + def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return arith.MinFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.MinSIOp(lhs, rhs).result + raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}") + + def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value: + if _is_floating_point_type(lhs.type): + return 
arith.MinFOp(lhs, rhs).result + if _is_integer_type(lhs.type) or _is_index_type(lhs.type): + return arith.MinUIOp(lhs, rhs).result + raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}") def _infer_structured_outs( op_config: LinalgStructuredOpConfig, - in_arg_defs: Sequence[OperandDefConfig], ins: Sequence[Value], + in_arg_defs: Sequence[OperandDefConfig], + ins: Sequence[Value], out_arg_defs: Sequence[OperandDefConfig], - outs: Union[Sequence[Value], OpResultList]) -> Tuple[ValueList, List[Type]]: - """Infers implicit outs and output types. + outs: Union[Sequence[Value], OpResultList], +) -> Tuple[ValueList, List[Type]]: + """Infers implicit outs and output types. - Respects existing contents of outs if not empty. + Respects existing contents of outs if not empty. - Returns: - normalized outs, output types - """ - # If outs were explicitly provided, we accept them verbatim. - if outs: - return outs, [out.type for out in outs] + Returns: + normalized outs, output types + """ + # If outs were explicitly provided, we accept them verbatim. + if outs: + return outs, [out.type for out in outs] - raise NotImplementedError(f"Output tensor inference not yet supported for " - "structured ops") + raise NotImplementedError( + f"Output tensor inference not yet supported for " "structured ops" + ) def _get_types_from_values(*values: Value) -> Sequence[Type]: - types = [] - for v in values: - types.append(v.type) - return types + types = [] + for v in values: + types.append(v.type) + return types def _get_operand_def_names(*operand_configs: OperandDefConfig) -> Sequence[str]: - return [odc.operand_def.name for odc in operand_configs] - - -def _add_type_mapping(operand_config: OperandDefConfig, operand_type: Type, - type_mapping: Dict[str, Type], - block_arg_types: Sequence[Type]): - element_or_self_type = operand_type - # Get the element type for tensor operands and the type itself for scalars. 
- if operand_config.shape_map: - try: - element_or_self_type = ShapedType(operand_type).element_type - except Exception as e: - raise ValueError(f"Expected ShapedType but got {operand_type}") from e - name = operand_config.type_var.name - if name in type_mapping: - if type_mapping[name] != element_or_self_type: - raise ValueError(f"Cannot overwrite type mapping {name} = " - f"{type_mapping[name]} by type {element_or_self_type}") - type_mapping[name] = element_or_self_type - block_arg_types.append(element_or_self_type) + return [odc.operand_def.name for odc in operand_configs] + + +def _add_type_mapping( + operand_config: OperandDefConfig, + operand_type: Type, + type_mapping: Dict[str, Type], + block_arg_types: Sequence[Type], +): + element_or_self_type = operand_type + # Get the element type for tensor operands and the type itself for scalars. + if operand_config.shape_map: + try: + element_or_self_type = ShapedType(operand_type).element_type + except Exception as e: + raise ValueError(f"Expected ShapedType but got {operand_type}") from e + name = operand_config.type_var.name + if name in type_mapping: + if type_mapping[name] != element_or_self_type: + raise ValueError( + f"Cannot overwrite type mapping {name} = " + f"{type_mapping[name]} by type {element_or_self_type}" + ) + type_mapping[name] = element_or_self_type + block_arg_types.append(element_or_self_type) def _is_complex_type(t: Type) -> bool: - return ComplexType.isinstance(t) + return ComplexType.isinstance(t) def _is_floating_point_type(t: Type) -> bool: - # TODO: Create a FloatType in the Python API and implement the switch - # there. - return (F64Type.isinstance(t) or F32Type.isinstance(t) or - F16Type.isinstance(t) or BF16Type.isinstance(t)) + # TODO: Create a FloatType in the Python API and implement the switch + # there. 
+ return ( + F64Type.isinstance(t) + or F32Type.isinstance(t) + or F16Type.isinstance(t) + or BF16Type.isinstance(t) + ) def _is_integer_type(t: Type) -> bool: - return IntegerType.isinstance(t) + return IntegerType.isinstance(t) def _is_index_type(t: Type) -> bool: - return IndexType.isinstance(t) + return IndexType.isinstance(t) def _get_floating_point_width(t: Type) -> int: - # TODO: Create a FloatType in the Python API and implement the switch - # there. - if F64Type.isinstance(t): - return 64 - if F32Type.isinstance(t): - return 32 - if F16Type.isinstance(t): - return 16 - if BF16Type.isinstance(t): - return 16 - raise NotImplementedError(f"Unhandled floating point type switch {t}") + # TODO: Create a FloatType in the Python API and implement the switch + # there. + if F64Type.isinstance(t): + return 64 + if F32Type.isinstance(t): + return 32 + if F16Type.isinstance(t): + return 16 + if BF16Type.isinstance(t): + return 16 + raise NotImplementedError(f"Unhandled floating point type switch {t}") diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py index aa894dc10..86853994c 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py @@ -30,123 +30,137 @@ class ScalarFn: - """A type of ScalarExpression that applies a function.""" - - def __init__(self, kind: "FunctionKind", fn_name: Optional[str], - attr_name: Optional[str], type_var: Optional["TypeVar"], - operands: Sequence["ScalarExpression"]): - if bool(fn_name) + bool(attr_name) != 1: - raise ValueError("One of 'fn_name', 'attr_name' must be specified") - self.kind = kind - self.fn_name = fn_name - self.attr_name = attr_name - self.type_var = type_var - self.operands = operands - - def expr(self) -> "ScalarExpression": - return ScalarExpression(scalar_fn=self) - - def __repr__(self): - name = self.fn_name if self.fn_name else self.attr_name - return 
(f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, " - f"operands=[{', '.join(self.operands)}])") + """A type of ScalarExpression that applies a function.""" + + def __init__( + self, + kind: "FunctionKind", + fn_name: Optional[str], + attr_name: Optional[str], + type_var: Optional["TypeVar"], + operands: Sequence["ScalarExpression"], + ): + if bool(fn_name) + bool(attr_name) != 1: + raise ValueError("One of 'fn_name', 'attr_name' must be specified") + self.kind = kind + self.fn_name = fn_name + self.attr_name = attr_name + self.type_var = type_var + self.operands = operands + + def expr(self) -> "ScalarExpression": + return ScalarExpression(scalar_fn=self) + + def __repr__(self): + name = self.fn_name if self.fn_name else self.attr_name + return ( + f"ScalarFn<{self.kind.name}.{name}>(type_var={self.type_var}, " + f"operands=[{', '.join(self.operands)}])" + ) class ScalarArg: - """A type of ScalarExpression that references a named argument.""" + """A type of ScalarExpression that references a named argument.""" - def __init__(self, arg: str): - self.arg = arg + def __init__(self, arg: str): + self.arg = arg - def expr(self) -> "ScalarExpression": - return ScalarExpression(scalar_arg=self) + def expr(self) -> "ScalarExpression": + return ScalarExpression(scalar_arg=self) - def __repr__(self): - return f"(ScalarArg({self.arg})" + def __repr__(self): + return f"(ScalarArg({self.arg})" class ScalarConst: - """A type of ScalarExpression representing a constant.""" + """A type of ScalarExpression representing a constant.""" - def __init__(self, value: str): - self.value = value + def __init__(self, value: str): + self.value = value - def expr(self) -> "ScalarExpression": - return ScalarExpression(scalar_const=self) + def expr(self) -> "ScalarExpression": + return ScalarExpression(scalar_const=self) - def __repr__(self): - return f"(ScalarConst({self.value})" + def __repr__(self): + return f"(ScalarConst({self.value})" class ScalarIndex: - """A type of 
ScalarExpression accessing an iteration index.""" + """A type of ScalarExpression accessing an iteration index.""" - def __init__(self, dim: int): - self.dim = dim + def __init__(self, dim: int): + self.dim = dim - def expr(self) -> "ScalarExpression": - return ScalarExpression(scalar_index=self) + def expr(self) -> "ScalarExpression": + return ScalarExpression(scalar_index=self) - def __repr__(self): - return f"(ScalarIndex({self.dim})" + def __repr__(self): + return f"(ScalarIndex({self.dim})" class ScalarExpression(YAMLObject): - """An expression on scalar values. - - Can be one of: - - ScalarFn - - ScalarArg - - ScalarConst - - ScalarIndex - """ - yaml_tag = "!ScalarExpression" - - def __init__(self, - scalar_fn: Optional[ScalarFn] = None, - scalar_arg: Optional[ScalarArg] = None, - scalar_const: Optional[ScalarConst] = None, - scalar_index: Optional[ScalarIndex] = None): - if (bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) + - bool(scalar_index)) != 1: - raise ValueError("One of 'scalar_fn', 'scalar_arg', 'scalar_const', or " - "'scalar_index' must be specified") - self.scalar_fn = scalar_fn - self.scalar_arg = scalar_arg - self.scalar_const = scalar_const - self.scalar_index = scalar_index - - def to_yaml_custom_dict(self): - if self.scalar_fn: - scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower()) - if self.scalar_fn.fn_name: - scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name - if self.scalar_fn.attr_name: - scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name - if self.scalar_fn.type_var: - scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name - scalar_fn_dict["operands"] = list(self.scalar_fn.operands) - return dict(scalar_fn=scalar_fn_dict) - elif self.scalar_arg: - return dict(scalar_arg=self.scalar_arg.arg) - elif self.scalar_const: - return dict(scalar_const=self.scalar_const.value) - elif self.scalar_index: - return dict(scalar_index=self.scalar_index.dim) - else: - raise ValueError(f"Unexpected ScalarExpression type: {self}") 
+ """An expression on scalar values. + + Can be one of: + - ScalarFn + - ScalarArg + - ScalarConst + - ScalarIndex + """ + + yaml_tag = "!ScalarExpression" + + def __init__( + self, + scalar_fn: Optional[ScalarFn] = None, + scalar_arg: Optional[ScalarArg] = None, + scalar_const: Optional[ScalarConst] = None, + scalar_index: Optional[ScalarIndex] = None, + ): + if ( + bool(scalar_fn) + bool(scalar_arg) + bool(scalar_const) + bool(scalar_index) + ) != 1: + raise ValueError( + "One of 'scalar_fn', 'scalar_arg', 'scalar_const', or " + "'scalar_index' must be specified" + ) + self.scalar_fn = scalar_fn + self.scalar_arg = scalar_arg + self.scalar_const = scalar_const + self.scalar_index = scalar_index + + def to_yaml_custom_dict(self): + if self.scalar_fn: + scalar_fn_dict = dict(kind=self.scalar_fn.kind.name.lower()) + if self.scalar_fn.fn_name: + scalar_fn_dict["fn_name"] = self.scalar_fn.fn_name + if self.scalar_fn.attr_name: + scalar_fn_dict["attr_name"] = self.scalar_fn.attr_name + if self.scalar_fn.type_var: + scalar_fn_dict["type_var"] = self.scalar_fn.type_var.name + scalar_fn_dict["operands"] = list(self.scalar_fn.operands) + return dict(scalar_fn=scalar_fn_dict) + elif self.scalar_arg: + return dict(scalar_arg=self.scalar_arg.arg) + elif self.scalar_const: + return dict(scalar_const=self.scalar_const.value) + elif self.scalar_index: + return dict(scalar_index=self.scalar_index.dim) + else: + raise ValueError(f"Unexpected ScalarExpression type: {self}") class ScalarAssign(YAMLObject): - """An assignment to a named argument (LHS of a comprehension).""" - yaml_tag = "!ScalarAssign" + """An assignment to a named argument (LHS of a comprehension).""" + + yaml_tag = "!ScalarAssign" - def __init__(self, arg: str, value: ScalarExpression): - self.arg = arg - self.value = value + def __init__(self, arg: str, value: ScalarExpression): + self.arg = arg + self.value = value - def to_yaml_custom_dict(self): - return dict(arg=self.arg, value=self.value) + def 
to_yaml_custom_dict(self): + return dict(arg=self.arg, value=self.value) - def __repr__(self): - return f"ScalarAssign({self.arg}, {self.value})" + def __repr__(self): + return f"ScalarAssign({self.arg}, {self.value})" diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py index ddac87287..4f36029b7 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/types.py @@ -21,13 +21,11 @@ __all__ = [ "TypeVar", "TV", - # Predefined types. "I32", "I64", "F32", "F64", - # TypeVar aliases. "T", "U", @@ -36,34 +34,34 @@ class TypeVar: - """A replaceable type variable. + """A replaceable type variable. - Type variables are uniqued by name. - """ - ALL_TYPEVARS = dict() # type: Dict[str, "TypeVar"] + Type variables are uniqued by name. + """ - def __new__(cls, name: str): - existing = cls.ALL_TYPEVARS.get(name) - if existing is not None: - return existing - new = super().__new__(cls) - new.name = name - cls.ALL_TYPEVARS[name] = new - return new + ALL_TYPEVARS = dict() # type: Dict[str, "TypeVar"] - def __repr__(self): - return f"TypeVar({self.name})" + def __new__(cls, name: str): + existing = cls.ALL_TYPEVARS.get(name) + if existing is not None: + return existing + new = super().__new__(cls) + new.name = name + cls.ALL_TYPEVARS[name] = new + return new - @classmethod - def create_expando(cls): - """Create an expando class that creates unique type vars on attr access.""" + def __repr__(self): + return f"TypeVar({self.name})" - class ExpandoTypeVars: + @classmethod + def create_expando(cls): + """Create an expando class that creates unique type vars on attr access.""" - def __getattr__(self, n): - return cls(n) + class ExpandoTypeVars: + def __getattr__(self, n): + return cls(n) - return ExpandoTypeVars() + return ExpandoTypeVars() # Expando access via TV.foo diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py 
b/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py index 1945eea53..1672656b3 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/yaml_helper.py @@ -6,11 +6,12 @@ import sys try: - import yaml + import yaml except ModuleNotFoundError as e: - raise ModuleNotFoundError( - f"This tool requires PyYAML but it was not installed. " - f"Recommend: {sys.executable} -m pip install PyYAML") from e + raise ModuleNotFoundError( + f"This tool requires PyYAML but it was not installed. " + f"Recommend: {sys.executable} -m pip install PyYAML" + ) from e __all__ = [ "yaml_dump", @@ -20,35 +21,33 @@ class YAMLObject(yaml.YAMLObject): + @classmethod + def to_yaml(cls, dumper, self): + """Default to a custom dictionary mapping.""" + return dumper.represent_mapping(cls.yaml_tag, self.to_yaml_custom_dict()) - @classmethod - def to_yaml(cls, dumper, self): - """Default to a custom dictionary mapping.""" - return dumper.represent_mapping(cls.yaml_tag, self.to_yaml_custom_dict()) + def to_yaml_custom_dict(self): + raise NotImplementedError() - def to_yaml_custom_dict(self): - raise NotImplementedError() - - def as_linalg_yaml(self): - return yaml_dump(self) + def as_linalg_yaml(self): + return yaml_dump(self) def multiline_str_representer(dumper, data): - if len(data.splitlines()) > 1: - return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') - else: - return dumper.represent_scalar('tag:yaml.org,2002:str', data) + if len(data.splitlines()) > 1: + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") + else: + return dumper.represent_scalar("tag:yaml.org,2002:str", data) yaml.add_representer(str, multiline_str_representer) def yaml_dump(data, sort_keys=False, **kwargs): - return yaml.dump(data, sort_keys=sort_keys, **kwargs) + return yaml.dump(data, sort_keys=sort_keys, **kwargs) def yaml_dump_all(data, sort_keys=False, explicit_start=True, **kwargs): - return 
yaml.dump_all(data, - sort_keys=sort_keys, - explicit_start=explicit_start, - **kwargs) + return yaml.dump_all( + data, sort_keys=sort_keys, explicit_start=explicit_start, **kwargs + ) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 9c96868c1..bac22a2e5 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -7,99 +7,113 @@ @linalg_structured_op -def copy(I=TensorDef(T1), - O=TensorDef(U, output=True), - cast=TypeFnAttrDef(default=TypeFn.cast_signed)): - """Copies the tensor elementwise. +def copy( + I=TensorDef(T1), + O=TensorDef(U, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + """Copies the tensor elementwise. - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - O[None] = cast(U, I[None]) + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + O[None] = cast(U, I[None]) @linalg_structured_op -def elemwise_unary(I=TensorDef(T1), - O=TensorDef(U, output=True), - fun=UnaryFnAttrDef(default=UnaryFn.exp), - cast=TypeFnAttrDef(default=TypeFn.cast_signed)): - """Applies the unary function fun elementwise. +def elemwise_unary( + I=TensorDef(T1), + O=TensorDef(U, output=True), + fun=UnaryFnAttrDef(default=UnaryFn.exp), + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + """Applies the unary function fun elementwise. - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - O[None] = fun(cast(U, I[None])) + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + O[None] = fun(cast(U, I[None])) @linalg_structured_op -def elemwise_binary(lhs=TensorDef(T1), - rhs=TensorDef(T2), - O=TensorDef(U, output=True), - fun=BinaryFnAttrDef(default=BinaryFn.add), - cast=TypeFnAttrDef(default=TypeFn.cast_signed)): - """Applies the binary function fun elementwise. +def elemwise_binary( + lhs=TensorDef(T1), + rhs=TensorDef(T2), + O=TensorDef(U, output=True), + fun=BinaryFnAttrDef(default=BinaryFn.add), + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + """Applies the binary function fun elementwise. - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None])) + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None])) @linalg_structured_op -def matmul(A=TensorDef(T1, S.M, S.K), - B=TensorDef(T2, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True), - cast=TypeFnAttrDef(default=TypeFn.cast_signed)): - """Performs a matrix multiplication of two 2D inputs. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) +def matmul( + A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + """Performs a matrix multiplication of two 2D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + domain(D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) @linalg_structured_op -def matmul_unsigned(A=TensorDef(T1, S.M, S.K), - B=TensorDef(T2, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True)): - """Performs an unsigned matrix multiplication of two 2D inputs. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned( - U, B[D.k, D.n]) +def matmul_unsigned( + A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True), +): + """Performs an unsigned matrix multiplication of two 2D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned( + U, B[D.k, D.n] + ) @linalg_structured_op -def quantized_matmul(A=TensorDef(T1, S.M, S.K), - B=TensorDef(T2, S.K, S.N), - AZp=ScalarDef(I32), - BZp=ScalarDef(I32), - C=TensorDef(U, S.M, S.N, output=True)): - """Performs a matrix multiplication of two 2D inputs. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. The quantized variant - includes zero-point adjustments for the left and right operands of the - matmul. 
- """ - domain(D.m, D.n, D.k) - C[D.m, - D.n] += (TypeFn.cast_signed(U, A[D.m, D.k]) - - TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed(U, B[D.k, D.n]) - - TypeFn.cast_signed(U, BZp)) +def quantized_matmul( + A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + AZp=ScalarDef(I32), + BZp=ScalarDef(I32), + C=TensorDef(U, S.M, S.N, output=True), +): + """Performs a matrix multiplication of two 2D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. The quantized variant + includes zero-point adjustments for the left and right operands of the + matmul. + """ + domain(D.m, D.n, D.k) + C[D.m, D.n] += (TypeFn.cast_signed(U, A[D.m, D.k]) - TypeFn.cast_signed(U, AZp)) * ( + TypeFn.cast_signed(U, B[D.k, D.n]) - TypeFn.cast_signed(U, BZp) + ) @linalg_structured_op -def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), - rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0), - accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True)): - """Performs a matrix-matrix-transpose multiplication of two 4D inputs. +def mmt4d( + lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), + rhs=TensorDef(TV.RhsType, S.N, S.K, S.N0, S.K0), + accum=TensorDef(TV.AccumType, S.M, S.N, S.M0, S.N0, output=True), +): + """Performs a matrix-matrix-transpose multiplication of two 4D inputs. Differences from linalg.matmul: * The right hand side is transposed, whence the 't' in 'mmt'. @@ -108,1132 +122,1201 @@ def mmt4d(lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), whence the 2+2=4 dimensions. The inner tile dimensions are identified with '0' suffixes below, for instance the LHS matrix shape (M, K, M0, K0) reads as: MxK tiles, each of shape M0xK0. 
- """ - domain(D.m, D.n, D.k, D.m0, D.n0, D.k0) - implements(ContractionOpInterface) - accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed( - TV.AccumType, lhs[D.m, D.k, D.m0, D.k0]) * TypeFn.cast_signed( - TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) + """ + domain(D.m, D.n, D.k, D.m0, D.n0, D.k0) + implements(ContractionOpInterface) + accum[D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed( + TV.AccumType, lhs[D.m, D.k, D.m0, D.k0] + ) * TypeFn.cast_signed(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) @linalg_structured_op -def batch_matmul(A=TensorDef(T1, Batch, S.M, S.K), - B=TensorDef(T2, Batch, S.K, S.N), - C=TensorDef(U, Batch, S.M, S.N, output=True)): - """Performs a batched matrix multiplication of two 3D inputs. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.b, D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.b, D.m, - D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( - U, B[D.b, D.k, D.n]) +def batch_matmul( + A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K, S.N), + C=TensorDef(U, Batch, S.M, S.N, output=True), +): + """Performs a batched matrix multiplication of two 3D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.b, D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( + U, B[D.b, D.k, D.n] + ) @linalg_structured_op -def quantized_batch_matmul(A=TensorDef(T1, Batch, S.M, S.K), - B=TensorDef(T2, Batch, S.K, S.N), - AZp=ScalarDef(I32), - BZp=ScalarDef(I32), - C=TensorDef(U, Batch, S.M, S.N, output=True)): - """Performs a batched matrix multiplication of two 3D inputs. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. 
The quantized variant - includes zero-point adjustments for the left and right operands of the - matmul. - """ - domain(D.b, D.m, D.n, D.k) - C[D.b, D.m, D.n] += (TypeFn.cast_signed(U, A[D.b, D.m, D.k]) - - TypeFn.cast_signed(U, AZp)) * (TypeFn.cast_signed( - U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp)) +def quantized_batch_matmul( + A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K, S.N), + AZp=ScalarDef(I32), + BZp=ScalarDef(I32), + C=TensorDef(U, Batch, S.M, S.N, output=True), +): + """Performs a batched matrix multiplication of two 3D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. The quantized variant + includes zero-point adjustments for the left and right operands of the + matmul. + """ + domain(D.b, D.m, D.n, D.k) + C[D.b, D.m, D.n] += ( + TypeFn.cast_signed(U, A[D.b, D.m, D.k]) - TypeFn.cast_signed(U, AZp) + ) * (TypeFn.cast_signed(U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp)) @linalg_structured_op -def batch_reduce_matmul(A=TensorDef(T1, Batch, S.M, S.K), - B=TensorDef(T2, Batch, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True)): - """Performs a batch-reduce matrix multiplication of two 3D inputs. - The partial multiplication results are reduced into a 2D output. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.b, D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.m, D.n] += TypeFn.cast_signed( - U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n])) +def batch_reduce_matmul( + A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True), +): + """Performs a batch-reduce matrix multiplication of two 3D inputs. + The partial multiplication results are reduced into a 2D output. 
+ + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.b, D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += TypeFn.cast_signed( + U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n]) + ) @linalg_structured_op -def matvec(A=TensorDef(T1, S.M, S.N), - y=TensorDef(T2, S.N), - x=TensorDef(U, S.M, output=True)): - """Performs a matrix-vector multiplication. +def matvec( + A=TensorDef(T1, S.M, S.N), y=TensorDef(T2, S.N), x=TensorDef(U, S.M, output=True) +): + """Performs a matrix-vector multiplication. - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.m, D.n) - implements(ContractionOpInterface) - x[D.m] += TypeFn.cast_signed(U, A[D.m, D.n]) * TypeFn.cast_signed(U, y[D.n]) + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.m, D.n) + implements(ContractionOpInterface) + x[D.m] += TypeFn.cast_signed(U, A[D.m, D.n]) * TypeFn.cast_signed(U, y[D.n]) @linalg_structured_op -def vecmat(y=TensorDef(T1, S.M), - A=TensorDef(T2, S.M, S.N), - x=TensorDef(U, S.N, output=True)): - """Performs a vector-matrix multiplication. +def vecmat( + y=TensorDef(T1, S.M), A=TensorDef(T2, S.M, S.N), x=TensorDef(U, S.N, output=True) +): + """Performs a vector-matrix multiplication. - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.n, D.m) - implements(ContractionOpInterface) - x[D.n] += TypeFn.cast_signed(U, y[D.m]) * TypeFn.cast_signed(U, A[D.m, D.n]) + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + domain(D.n, D.m) + implements(ContractionOpInterface) + x[D.n] += TypeFn.cast_signed(U, y[D.m]) * TypeFn.cast_signed(U, A[D.m, D.n]) @linalg_structured_op -def batch_matvec(A=TensorDef(T1, Batch, S.M, S.K), - B=TensorDef(T2, Batch, S.K), - C=TensorDef(U, Batch, S.M, output=True)): - """Performs a batched matrix-vector multiplication. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.b, D.m, D.k) - implements(ContractionOpInterface) - C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( - U, B[D.b, D.k]) +def batch_matvec( + A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.K), + C=TensorDef(U, Batch, S.M, output=True), +): + """Performs a batched matrix-vector multiplication. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.b, D.m, D.k) + implements(ContractionOpInterface) + C[D.b, D.m] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( + U, B[D.b, D.k] + ) @linalg_structured_op -def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, - output=True)): - """Performs a dot product of two vectors to a scalar result. +def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)): + """Performs a dot product of two vectors to a scalar result. - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ContractionOpInterface) - C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m]) + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + implements(ContractionOpInterface) + C[None] += TypeFn.cast_signed(U, A[D.m]) * TypeFn.cast_signed(U, B[D.m]) @linalg_structured_op -def conv_1d(I=TensorDef(T1, S.OW + S.KW), - K=TensorDef(T2, S.KW), - O=TensorDef(U, S.OW, output=True)): - """Performs 1-D convolution with no channels. +def conv_1d( + I=TensorDef(T1, S.OW + S.KW), + K=TensorDef(T2, S.KW), + O=TensorDef(U, S.OW, output=True), +): + """Performs 1-D convolution with no channels. - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.ow, D.kw) - O[D.ow] += TypeFn.cast_signed(U, I[D.ow + D.kw]) * TypeFn.cast_signed( - U, K[D.kw]) + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.ow, D.kw) + O[D.ow] += TypeFn.cast_signed(U, I[D.ow + D.kw]) * TypeFn.cast_signed(U, K[D.kw]) @linalg_structured_op -def conv_2d(I=TensorDef(T1, S.OH + S.KH, S.OW + S.KW), - K=TensorDef(T2, S.KH, S.KW), - O=TensorDef(U, S.OH, S.OW, output=True)): - """Performs 2-D convolution with no channels. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.oh, D.ow, D.kh, D.kw) - O[D.oh, D.ow] += TypeFn.cast_signed( - U, I[D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed(U, K[D.kh, D.kw]) +def conv_2d( + I=TensorDef(T1, S.OH + S.KH, S.OW + S.KW), + K=TensorDef(T2, S.KH, S.KW), + O=TensorDef(U, S.OH, S.OW, output=True), +): + """Performs 2-D convolution with no channels. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.oh, D.ow, D.kh, D.kw) + O[D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.oh + D.kh, D.ow + D.kw] + ) * TypeFn.cast_signed(U, K[D.kh, D.kw]) @linalg_structured_op -def conv_3d(I=TensorDef(T1, S.OD + S.KD, S.OH + S.KH, S.OW + S.KW), - K=TensorDef(T2, S.KD, S.KH, S.KW), - O=TensorDef(U, S.OD, S.OH, S.OW, output=True)): - """Performs 3-D convolution with no channels. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw) - O[D.od, D.oh, D.ow] += TypeFn.cast_signed( - U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw]) * TypeFn.cast_signed( - U, K[D.kd, D.kh, D.kw]) +def conv_3d( + I=TensorDef(T1, S.OD + S.KD, S.OH + S.KH, S.OW + S.KW), + K=TensorDef(T2, S.KD, S.KH, S.KW), + O=TensorDef(U, S.OD, S.OH, S.OW, output=True), +): + """Performs 3-D convolution with no channels. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.od, D.oh, D.ow, D.kd, D.kh, D.kw) + O[D.od, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.od + D.kd, D.oh + D.kh, D.ow + D.kw] + ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw]) @linalg_structured_op -def conv_1d_nwc_wcf(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KW, S.C, S.F), - O=TensorDef(U, S.N, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs 1-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.ow, D.f, D.kw, D.c) - O[D.n, D.ow, D.f] += TypeFn.cast_signed( - U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed( - U, K[D.kw, D.c, D.f]) +def conv_1d_nwc_wcf( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, S.C, S.F), + O=TensorDef(U, S.N, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs 1-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.f, D.kw, D.c) + O[D.n, D.ow, D.f] += TypeFn.cast_signed( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c] + ) * TypeFn.cast_signed(U, K[D.kw, D.c, D.f]) @linalg_structured_op -def conv_1d_ncw_fcw(I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.F, S.C, S.KW), - O=TensorDef(U, S.N, S.F, S.OW, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs 1-D convolution. - - Layout: - * Input: NCW. - * Kernel: FCW. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.f, D.ow, D.c, D.kw) - O[D.n, D.f, D.ow] += TypeFn.cast_signed( - U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed( - U, K[D.f, D.c, D.kw]) +def conv_1d_ncw_fcw( + I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.F, S.C, S.KW), + O=TensorDef(U, S.N, S.F, S.OW, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs 1-D convolution. + + Layout: + * Input: NCW. + * Kernel: FCW. 
+ + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.f, D.ow, D.c, D.kw) + O[D.n, D.f, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kw]) @linalg_structured_op -def conv_2d_nhwc_hwcf(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KH, S.KW, S.C, S.F), - O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - """Performs 2-D convolution. - - Layout: - * Input: NHWC. - * Kernel: HWCF. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.c]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) +def conv_2d_nhwc_hwcf( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, S.C, S.F), + O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D convolution. + + Layout: + * Input: NHWC. + * Kernel: HWCF. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) @linalg_structured_op -def conv_2d_nhwc_fhwc(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.F, S.KH, S.KW, S.C), - O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - """Performs 2-D convolution. - - Layout: - * Input: NHWC. - * Kernel: FHWC. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.c]) * TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) +def conv_2d_nhwc_fhwc( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.F, S.KH, S.KW, S.C), + O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D convolution. + + Layout: + * Input: NHWC. + * Kernel: FHWC. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.f] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) * TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) @linalg_structured_op -def conv_2d_nhwc_hwcf_q(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KH, S.KW, S.C, S.F), - IZp=ScalarDef(I32), - KZp=ScalarDef(I32), - O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - """Performs 2-D convolution with zero point offsets. - - Layout: - * Input: NHWC. - * Kernel: HWCF. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. This includes the zero - point offsets common to quantized operations. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) - O[D.n, D.oh, D.ow, - D.f] += (TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) - - TypeFn.cast_signed(U, IZp)) * (TypeFn.cast_signed( - U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp)) +def conv_2d_nhwc_hwcf_q( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, S.C, S.F), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D convolution with zero point offsets. + + Layout: + * Input: NHWC. + * Kernel: HWCF. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. This includes the zero + point offsets common to quantized operations. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.f] += ( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + - TypeFn.cast_signed(U, IZp) + ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp)) @linalg_structured_op -def conv_2d_nchw_fchw(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.F, S.C, S.KH, S.KW), - O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - """Performs 2-D convolution. - - Layout: - * Input: NCHW. - * Kernel: FCHW. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.f, D.oh, D.ow] += TypeFn.cast_signed( - U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + - D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) +def conv_2d_nchw_fchw( + I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.F, S.C, S.KH, S.KW), + O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D convolution. + + Layout: + * Input: NCHW. + * Kernel: FCHW. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.f, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) @linalg_structured_op -def conv_2d_ngchw_fgchw(I=TensorDef(T1, S.N, S.G, S.C, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW), - O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - """Performs 2-D grouped convolution. - - Layout: - * Input: NGCHW. - * Kernel: FGCHW. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed( - U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + - D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw]) +def conv_2d_ngchw_fgchw( + I=TensorDef( + T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW + ), + K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW), + O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D grouped convolution. + + Layout: + * Input: NGCHW. + * Kernel: FGCHW. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw]) @linalg_structured_op -def conv_3d_ndhwc_dhwcf(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F), - O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True), - strides=IndexAttrDef(S.SD, - S.SH, - S.SW, - default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, - S.DH, - S.DW, - default=[1, 1, 1])): - """Performs 3-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) - O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed( - U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, D.c]) * TypeFn.cast_signed( - U, K[D.kd, D.kh, D.kw, D.c, D.f]) +def conv_3d_ndhwc_dhwcf( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.C, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) + O[D.n, D.od, D.oh, D.ow, D.f] += TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.c, + ], + ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f]) @linalg_structured_op -def conv_3d_ndhwc_dhwcf_q(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F), - IZp=ScalarDef(I32), - KZp=ScalarDef(I32), - O=TensorDef(U, - S.N, - S.OD, - S.OH, - S.OW, - S.F, - output=True), - strides=IndexAttrDef(S.SD, - S.SH, - S.SW, - default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, - S.DH, - S.DW, - default=[1, 1, 1])): - """Performs 3-D convolution with zero point offsets. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. This includes the zero - point offsets common to quantized operations. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) - O[D.n, D.od, D.oh, D.ow, D.f] += (TypeFn.cast_signed( - U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, D.c]) - TypeFn.cast_signed(U, IZp)) * ( - TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f]) - - TypeFn.cast_signed(U, KZp)) +def conv_3d_ndhwc_dhwcf_q( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.C, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, S.C, S.F), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3-D convolution with zero point offsets. 
+ + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. This includes the zero + point offsets common to quantized operations. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) + O[D.n, D.od, D.oh, D.ow, D.f] += ( + TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.c, + ], + ) + - TypeFn.cast_signed(U, IZp) + ) * ( + TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.c, D.f]) + - TypeFn.cast_signed(U, KZp) + ) @linalg_structured_op -def conv_3d_ncdhw_fcdhw(I=TensorDef(T1, S.N, S.C, S.OD * S.SD + S.KD * S.DD, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.F, S.C, S.KD, S.KH, S.KW), - O=TensorDef(U, S.N, S.F, S.OD, S.OH, S.OW, output=True), - strides=IndexAttrDef(S.SD, - S.SH, - S.SW, - default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, - S.DH, - S.DW, - default=[1, 1, 1])): - """Performs 3-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) - O[D.n, D.f, D.od, D.oh, D.ow] += TypeFn.cast_signed( - U, I[D.n, D.c, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed( - U, K[D.f, D.c, D.kd, D.kh, D.kw]) +def conv_3d_ncdhw_fcdhw( + I=TensorDef( + T1, + S.N, + S.C, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + ), + K=TensorDef(T2, S.F, S.C, S.KD, S.KH, S.KW), + O=TensorDef(U, S.N, S.F, S.OD, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3-D convolution. 
+ + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) + O[D.n, D.f, D.od, D.oh, D.ow] += TypeFn.cast_signed( + U, + I[ + D.n, + D.c, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + ], + ) * TypeFn.cast_signed(U, K[D.f, D.c, D.kd, D.kh, D.kw]) @linalg_structured_op -def depthwise_conv_1d_nwc_wc(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, - S.IC), - K=TensorDef(T2, S.KW, S.IC), - O=TensorDef(U, S.N, S.OW, S.IC, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs depth-wise 1-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. Multiplier is set to 1 - which is a special case for most depthwise convolutions. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.ow, D.ic, D.kw) - O[D.n, D.ow, D.ic] += \ - TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \ - TypeFn.cast_signed(U, K[D.kw, D.ic]) +def depthwise_conv_1d_nwc_wc( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KW, S.IC), + O=TensorDef(U, S.N, S.OW, S.IC, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs depth-wise 1-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.ic, D.kw) + O[D.n, D.ow, D.ic] += TypeFn.cast_signed( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic] + ) * TypeFn.cast_signed(U, K[D.kw, D.ic]) @linalg_structured_op -def depthwise_conv_1d_ncw_cw(I=TensorDef(T1, S.N, S.IC, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.IC, S.KW), - O=TensorDef(U, S.N, S.IC, S.OW, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs depth-wise 1-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. Multiplier is set to 1 - which is a special case for most depthwise convolutions. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.ow, D.ic, D.kw) - O[D.n, D.ic, D.ow] += \ - TypeFn.cast_signed(U, I[D.n, D.ic, D.ow * S.SW + D.kw * S.DW]) * \ - TypeFn.cast_signed(U, K[D.ic, D.kw]) +def depthwise_conv_1d_ncw_cw( + I=TensorDef(T1, S.N, S.IC, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.IC, S.KW), + O=TensorDef(U, S.N, S.IC, S.OW, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs depth-wise 1-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.ic, D.kw) + O[D.n, D.ic, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.ic, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.ic, D.kw]) @linalg_structured_op -def depthwise_conv_1d_nwc_wcm(I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, - S.IC), - K=TensorDef(T2, S.KW, S.IC, S.CM), - O=TensorDef(U, S.N, S.OW, S.IC, S.CM, - output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs depth-wise 1-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.ow, D.ic, D.cm, D.kw) - O[D.n, D.ow, D.ic, D.cm] += \ - TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic]) * \ - TypeFn.cast_signed(U, K[D.kw, D.ic, D.cm]) +def depthwise_conv_1d_nwc_wcm( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KW, S.IC, S.CM), + O=TensorDef(U, S.N, S.OW, S.IC, S.CM, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs depth-wise 1-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.ic, D.cm, D.kw) + O[D.n, D.ow, D.ic, D.cm] += TypeFn.cast_signed( + U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.ic] + ) * TypeFn.cast_signed(U, K[D.kw, D.ic, D.cm]) @linalg_structured_op -def depthwise_conv_2d_nhwc_hwc(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.IC), - K=TensorDef(T2, S.KH, S.KW, S.IC), - O=TensorDef(U, - S.N, - S.OH, - S.OW, - S.IC, - output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, - S.DW, - default=[1, 1])): - """Performs depth-wise 2-D convolution. 
- - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. Multiplier is set to 1 - which is a special case for most depthwise convolutions. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) +def depthwise_conv_2d_nhwc_hwc( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC), + O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs depth-wise 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.ic] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic] + ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) @linalg_structured_op -def depthwise_conv_2d_nchw_chw(I=TensorDef(T1, S.N, S.IC, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.IC, S.KH, S.KW), - O=TensorDef(U, - S.N, - S.IC, - S.OH, - S.OW, - output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, - S.DW, - default=[1, 1])): - """Performs depth-wise 2-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. Multiplier is set to 1 - which is a special case for most depthwise convolutions. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) - O[D.n, D.ic, D.oh, D.ow] += TypeFn.cast_signed( - U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + - D.kw * S.DW]) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw]) +def depthwise_conv_2d_nchw_chw( + I=TensorDef(T1, S.N, S.IC, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.IC, S.KH, S.KW), + O=TensorDef(U, S.N, S.IC, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs depth-wise 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) + O[D.n, D.ic, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.ic, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.ic, D.kh, D.kw]) @linalg_structured_op -def depthwise_conv_2d_nhwc_hwc_q(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.IC), - K=TensorDef(T2, S.KH, S.KW, S.IC), - IZp=ScalarDef(I32), - KZp=ScalarDef(I32), - O=TensorDef(U, - S.N, - S.OH, - S.OW, - S.IC, - output=True), - strides=IndexAttrDef(S.SH, - S.SW, - default=[1, 1]), - dilations=IndexAttrDef(S.DH, - S.DW, - default=[1, 1])): - """Performs depth-wise 2-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.ic] += ((TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) - - TypeFn.cast_signed(U, IZp)) * - (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) - - TypeFn.cast_signed(U, KZp))) +def depthwise_conv_2d_nhwc_hwc_q( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.IC, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs depth-wise 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.ic, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.ic] += ( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic] + ) + - TypeFn.cast_signed(U, IZp) + ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic]) - TypeFn.cast_signed(U, KZp)) @linalg_structured_op -def depthwise_conv_2d_nhwc_hwcm(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.IC), - K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), - O=TensorDef(U, - S.N, - S.OH, - S.OW, - S.IC, - S.CM, - output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, - 1]), - dilations=IndexAttrDef(S.DH, - S.DW, - default=[1, 1])): - """Performs depth-wise 2-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, - D.ic]) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) +def depthwise_conv_2d_nhwc_hwcm( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), + O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs depth-wise 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic] + ) * TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) @linalg_structured_op -def depthwise_conv_2d_nhwc_hwcm_q(I=TensorDef(T1, S.N, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.IC), - K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), - IZp=ScalarDef(I32), - KZp=ScalarDef(I32), - O=TensorDef(U, - S.N, - S.OH, - S.OW, - S.IC, - S.CM, - output=True), - strides=IndexAttrDef(S.SH, - S.SW, - default=[1, 1]), - dilations=IndexAttrDef(S.DH, - S.DW, - default=[1, 1])): - """Performs depth-wise 2-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.ic, - D.cm] += ((TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic]) - - TypeFn.cast_signed(U, IZp)) * - (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) - - TypeFn.cast_signed(U, KZp))) +def depthwise_conv_2d_nhwc_hwcm_q( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.IC), + K=TensorDef(T2, S.KH, S.KW, S.IC, S.CM), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.IC, S.CM, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs depth-wise 2-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.ic, D.cm, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.ic, D.cm] += ( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.ic] + ) + - TypeFn.cast_signed(U, IZp) + ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.ic, D.cm]) - TypeFn.cast_signed(U, KZp)) @linalg_structured_op -def depthwise_conv_3d_ndhwc_dhwc(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.IC), - K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC), - O=TensorDef(U, - S.N, - S.OD, - S.OH, - S.OW, - output=True), - strides=IndexAttrDef(S.SD, - S.SH, - S.SW, - default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, - S.DH, - S.DW, - default=[1, 1, 1])): - """Performs depth-wise 3-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. Multiplier is set to 1 - which is a special case for most depthwise convolutions. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic) - O[D.n, D.od, D.oh, D.ow, D.ic] += TypeFn.cast_signed( - U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, D.ic]) * TypeFn.cast_signed( - U, K[D.kd, D.kh, D.kw, D.ic]) +def depthwise_conv_3d_ndhwc_dhwc( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.IC, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs depth-wise 3-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic) + O[D.n, D.od, D.oh, D.ow, D.ic] += TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.ic, + ], + ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic]) @linalg_structured_op -def depthwise_conv_3d_ncdhw_cdhw(I=TensorDef(T1, S.N, S.IC, - S.OD * S.SD + S.KD * S.DD, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.IC, S.KD, S.KH, S.KW), - O=TensorDef(U, - S.N, - S.IC, - S.OD, - S.OH, - S.OW, - output=True), - strides=IndexAttrDef(S.SD, - S.SH, - S.SW, - default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, - S.DH, - S.DW, - default=[1, 1, 1])): - """Performs depth-wise 3-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. Multiplier is set to 1 - which is a special case for most depthwise convolutions. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic) - O[D.n, D.ic, D.od, D.oh, D.ow] += TypeFn.cast_signed( - U, I[D.n, D.ic, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW]) * TypeFn.cast_signed( - U, K[D.ic, D.kd, D.kh, D.kw]) +def depthwise_conv_3d_ncdhw_cdhw( + I=TensorDef( + T1, + S.N, + S.IC, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + ), + K=TensorDef(T2, S.IC, S.KD, S.KH, S.KW), + O=TensorDef(U, S.N, S.IC, S.OD, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs depth-wise 3-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. Multiplier is set to 1 + which is a special case for most depthwise convolutions. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.kd, D.kh, D.kw, D.ic) + O[D.n, D.ic, D.od, D.oh, D.ow] += TypeFn.cast_signed( + U, + I[ + D.n, + D.ic, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + ], + ) * TypeFn.cast_signed(U, K[D.ic, D.kd, D.kh, D.kw]) @linalg_structured_op -def depthwise_conv_3d_ndhwc_dhwcm(I=TensorDef(T1, S.N, - S.OD * S.SD + S.KD * S.DD, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.IC), - K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC, S.CM), - O=TensorDef(U, - S.N, - S.OD, - S.OH, - S.OW, - S.CM, - output=True), - strides=IndexAttrDef(S.SD, - S.SH, - S.SW, - default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, - S.DH, - S.DW, - default=[1, 1, 1])): - """Performs depth-wise 3-D convolution. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.cm, D.kd, D.kh, D.kw, D.ic) - O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed( - U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, D.ic]) * TypeFn.cast_signed( - U, K[D.kd, D.kh, D.kw, D.ic, D.cm]) +def depthwise_conv_3d_ndhwc_dhwcm( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.IC, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, S.IC, S.CM), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.CM, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs depth-wise 3-D convolution. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.cm, D.kd, D.kh, D.kw, D.ic) + O[D.n, D.od, D.oh, D.ow, D.ic, D.cm] += TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.ic, + ], + ) * TypeFn.cast_signed(U, K[D.kd, D.kh, D.kw, D.ic, D.cm]) @linalg_structured_op -def pooling_nhwc_sum(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - """Performs sum pooling. - - Layout: - * Input: NHWC. - * Kernel: HW. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]) +def pooling_nhwc_sum( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs sum pooling. + + Layout: + * Input: NHWC. + * Kernel: HW. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.c] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) @linalg_structured_op -def pooling_nchw_sum(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - """Performs sum pooling. - - Layout: - * Input: NCHW. - * Kernel: HW. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) - O[D.n, D.c, D.oh, D.ow] += TypeFn.cast_signed( - U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW]) +def pooling_nchw_sum( + I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs sum pooling. + + Layout: + * Input: NCHW. + * Kernel: HW. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) + O[D.n, D.c, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] + ) @linalg_structured_op -def pooling_nhwc_max(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - """Performs max pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw](TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) +def pooling_nhwc_max( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kh, D.kw]( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + ) @linalg_structured_op -def pooling_nhwc_max_unsigned(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, - S.KH, - S.KW, - index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, - 1])): - """Performs unsigned max pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.oh, D.ow, - D.c] = ReduceFn.max_unsigned[D.kh, D.kw](TypeFn.cast_unsigned( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) +def pooling_nhwc_max_unsigned( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs unsigned max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned[D.kh, D.kw]( + TypeFn.cast_unsigned( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + ) @linalg_structured_op -def pooling_nchw_max(I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - """Performs max pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) - O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw](TypeFn.cast_signed( - U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,])) +def pooling_nchw_max( + I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.C, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.c, D.oh, D.ow, D.kh, D.kw) + O[D.n, D.c, D.oh, D.ow] = ReduceFn.max_signed[D.kh, D.kw]( + TypeFn.cast_signed( + U, + I[ + D.n, + D.c, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + ], + ) + ) @linalg_structured_op -def pooling_nhwc_min(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1])): - """Performs min pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw](TypeFn.cast_signed( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) +def pooling_nhwc_min( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kh, D.kw]( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + ) @linalg_structured_op -def pooling_nhwc_min_unsigned(I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, - S.KH, - S.KW, - index_dims=[D.kh, D.kw]), - O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), - dilations=IndexAttrDef(S.DH, S.DW, default=[1, - 1])): - """Performs unsigned min pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) - O[D.n, D.oh, D.ow, - D.c] = ReduceFn.min_unsigned[D.kh, D.kw](TypeFn.cast_unsigned( - U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])) +def pooling_nhwc_min_unsigned( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]), + O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs unsigned min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned[D.kh, D.kw]( + TypeFn.cast_unsigned( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + ) + @linalg_structured_op -def pooling_nwc_sum(I=TensorDef(T1, S.N, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KW, index_dims=[D.kw]), - O=TensorDef(U, S.N, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs sum pooling. - - Layout: - * Input: NWC. - * Kernel: W. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.ow, D.c, D.kw) - O[D.n, D.ow, D.c] += TypeFn.cast_signed( - U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) +def pooling_nwc_sum( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs sum pooling. + + Layout: + * Input: NWC. + * Kernel: W. 
+ + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] += TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) @linalg_structured_op -def pooling_ncw_sum(I=TensorDef(T1, S.N, S.C, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.KW, index_dims=[D.kw]), - O=TensorDef(U, S.N, S.C, S.OW, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs sum pooling. - - Layout: - * Input: NCW. - * Kernel: W. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.c, D.ow, D.kw) - O[D.n, D.c, D.ow] += TypeFn.cast_signed( - U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]) +def pooling_ncw_sum( + I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.C, S.OW, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs sum pooling. + + Layout: + * Input: NCW. + * Kernel: W. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.c, D.ow, D.kw) + O[D.n, D.c, D.ow] += TypeFn.cast_signed(U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW]) @linalg_structured_op -def pooling_nwc_max(I=TensorDef(T1, S.N, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KW, index_dims=[D.kw]), - O=TensorDef(U, S.N, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs max pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.ow, D.c, D.kw) - O[D.n, D.ow, D.c] = ReduceFn.max_signed[[D.kw]](TypeFn.cast_signed( - U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])) +def pooling_nwc_max( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] = ReduceFn.max_signed[[D.kw]]( + TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) + ) @linalg_structured_op -def pooling_nwc_max_unsigned(I=TensorDef(T1, S.N, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, - S.KW, - index_dims=[D.kw]), - O=TensorDef(U, S.N, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs unsigned max pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.ow, D.c, D.kw) - O[D.n, D.ow, - D.c] = ReduceFn.max_unsigned[[D.kw]](TypeFn.cast_unsigned( - U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])) +def pooling_nwc_max_unsigned( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs unsigned max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] = ReduceFn.max_unsigned[[D.kw]]( + TypeFn.cast_unsigned(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) + ) @linalg_structured_op -def pooling_ncw_max(I=TensorDef(T1, S.N, S.C, - S.OW * S.SW + S.KW * S.DW), - K=TensorDef(T2, S.KW, index_dims=[D.kw]), - O=TensorDef(U, S.N, S.C, S.OW, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs max pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.c, D.ow, D.kw) - O[D.n, D.c, D.ow] = ReduceFn.max_signed[[D.kw]](TypeFn.cast_signed( - U, I[D.n, D.c, D.ow * S.SW + D.kw * S.DW,])) +def pooling_ncw_max( + I=TensorDef(T1, S.N, S.C, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.C, S.OW, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.c, D.ow, D.kw) + O[D.n, D.c, D.ow] = ReduceFn.max_signed[[D.kw]]( + TypeFn.cast_signed( + U, + I[ + D.n, + D.c, + D.ow * S.SW + D.kw * S.DW, + ], + ) + ) @linalg_structured_op -def pooling_nwc_min(I=TensorDef(T1, S.N, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, S.KW, index_dims=[D.kw]), - O=TensorDef(U, S.N, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs min pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. 
- """ - implements(ConvolutionOpInterface) - domain(D.n, D.ow, D.c, D.kw) - O[D.n, D.ow, D.c] = ReduceFn.min_signed[[D.kw]](TypeFn.cast_signed( - U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])) +def pooling_nwc_min( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] = ReduceFn.min_signed[[D.kw]]( + TypeFn.cast_signed(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) + ) @linalg_structured_op -def pooling_nwc_min_unsigned(I=TensorDef(T1, S.N, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, - S.KW, - index_dims=[D.kw]), - O=TensorDef(U, S.N, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SW, default=[1]), - dilations=IndexAttrDef(S.DW, default=[1])): - """Performs unsigned min pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.ow, D.c, D.kw) - O[D.n, D.ow, - D.c] = ReduceFn.min_unsigned[[D.kw]](TypeFn.cast_unsigned( - U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c])) - +def pooling_nwc_min_unsigned( + I=TensorDef(T1, S.N, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.KW, index_dims=[D.kw]), + O=TensorDef(U, S.N, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SW, default=[1]), + dilations=IndexAttrDef(S.DW, default=[1]), +): + """Performs unsigned min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.ow, D.c, D.kw) + O[D.n, D.ow, D.c] = ReduceFn.min_unsigned[[D.kw]]( + TypeFn.cast_unsigned(U, I[D.n, D.ow * S.SW + D.kw * S.DW, D.c]) + ) @linalg_structured_op -def pooling_ndhwc_sum(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, - S.KD, - S.KH, - S.KW, - index_dims=[D.kd, D.kh, D.kw]), - O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, - S.DH, - S.DW, - default=[1, 1, 1])): - """Performs 3D sum pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) - O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast_signed( - U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, D.c]) +def pooling_ndhwc_sum( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.C, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3D sum pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) + O[D.n, D.od, D.oh, D.ow, D.c] += TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.c, + ], + ) @linalg_structured_op -def pooling_ndhwc_max(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, - S.KD, - S.KH, - S.KW, - index_dims=[D.kd, D.kh, D.kw]), - O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, - S.DH, - S.DW, - default=[1, 1, 1])): - """Performs 3D max pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) - O[D.n, D.od, D.oh, D.ow, - D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw](TypeFn.cast_signed( - U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, D.c])) +def pooling_ndhwc_max( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.C, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3D max pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) + O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.max_signed[D.kd, D.kh, D.kw]( + TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.c, + ], + ) + ) @linalg_structured_op -def pooling_ndhwc_min(I=TensorDef(T1, S.N, S.OD * S.SD + S.KD * S.DD, - S.OH * S.SH + S.KH * S.DH, - S.OW * S.SW + S.KW * S.DW, S.C), - K=TensorDef(T2, - S.KD, - S.KH, - S.KW, - index_dims=[D.kd, D.kh, D.kw]), - O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), - strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), - dilations=IndexAttrDef(S.DD, - S.DH, - S.DW, - default=[1, 1, 1])): - """Performs 3D min pooling. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) - O[D.n, D.od, D.oh, D.ow, - D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw](TypeFn.cast_signed( - U, I[D.n, D.od * S.SD + D.kd * S.DD, D.oh * S.SH + D.kh * S.DH, - D.ow * S.SW + D.kw * S.DW, D.c])) +def pooling_ndhwc_min( + I=TensorDef( + T1, + S.N, + S.OD * S.SD + S.KD * S.DD, + S.OH * S.SH + S.KH * S.DH, + S.OW * S.SW + S.KW * S.DW, + S.C, + ), + K=TensorDef(T2, S.KD, S.KH, S.KW, index_dims=[D.kd, D.kh, D.kw]), + O=TensorDef(U, S.N, S.OD, S.OH, S.OW, S.C, output=True), + strides=IndexAttrDef(S.SD, S.SH, S.SW, default=[1, 1, 1]), + dilations=IndexAttrDef(S.DD, S.DH, S.DW, default=[1, 1, 1]), +): + """Performs 3D min pooling. + + Numeric casting is performed on the input operand, promoting it to the same + data type as the accumulator/output. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) + O[D.n, D.od, D.oh, D.ow, D.c] = ReduceFn.min_signed[D.kd, D.kh, D.kw]( + TypeFn.cast_signed( + U, + I[ + D.n, + D.od * S.SD + D.kd * S.DD, + D.oh * S.SH + D.kh * S.DH, + D.ow * S.SW + D.kw * S.DW, + D.c, + ], + ) + ) @linalg_structured_op def fill(value=ScalarDef(T1), O=TensorDef(U, output=True)): - """Fills the output tensor with the given value. + """Fills the output tensor with the given value. - Works for arbitrary ranked output tensors since the operation performs scalar - accesses only and is thus rank polymorphic. Numeric casting is performed on - the value operand, promoting it to the same data type as the output. - """ - implements(FillOpInterface) - defines(Canonicalizer) - O[None] = TypeFn.cast_signed(U, value) + Works for arbitrary ranked output tensors since the operation performs scalar + accesses only and is thus rank polymorphic. Numeric casting is performed on + the value operand, promoting it to the same data type as the output. + """ + implements(FillOpInterface) + defines(Canonicalizer) + O[None] = TypeFn.cast_signed(U, value) @linalg_structured_op -def fill_rng_2d(min=ScalarDef(F64), - max=ScalarDef(F64), - seed=ScalarDef(I32), - O=TensorDef(T, S.M, S.N, output=True)): - """Fills the output tensor with pseudo random numbers. - - The operation generations pseudo random numbers using a linear congruential - generator. It provides no guarantees regarding the distribution of the - generated random numbers. Instead of generating the random numbers - sequentially, it instantiates one random number generator per data element - and runs them in parallel. The seed operand and the indices of the data - element seed the random number generation. The min and max operands limit - the range of the generated random numbers. 
- """ - domain(D.m, D.n) - multiplier = TypeFn.cast_signed(I32, const(1103515245)) - increment = TypeFn.cast_signed(I32, const(12345)) - rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment - rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment - inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10)) - offset = TypeFn.cast_signed(F64, const(2147483647)) - scaling = (max - min) * inv_range - O[D.m, D.n] = TypeFn.cast_signed( - T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min) +def fill_rng_2d( + min=ScalarDef(F64), + max=ScalarDef(F64), + seed=ScalarDef(I32), + O=TensorDef(T, S.M, S.N, output=True), +): + """Fills the output tensor with pseudo random numbers. + + The operation generations pseudo random numbers using a linear congruential + generator. It provides no guarantees regarding the distribution of the + generated random numbers. Instead of generating the random numbers + sequentially, it instantiates one random number generator per data element + and runs them in parallel. The seed operand and the indices of the data + element seed the random number generation. The min and max operands limit + the range of the generated random numbers. 
+ """ + domain(D.m, D.n) + multiplier = TypeFn.cast_signed(I32, const(1103515245)) + increment = TypeFn.cast_signed(I32, const(12345)) + rand1 = (TypeFn.cast_signed(I32, index(D.m)) + seed) * multiplier + increment + rand2 = (TypeFn.cast_signed(I32, index(D.n)) + rand1) * multiplier + increment + inv_range = TypeFn.cast_signed(F64, const(2.3283064e-10)) + offset = TypeFn.cast_signed(F64, const(2147483647)) + scaling = (max - min) * inv_range + O[D.m, D.n] = TypeFn.cast_signed( + T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min + ) diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py index ca0d479f1..980f237b1 100644 --- a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -5,6 +5,8 @@ from ._python_test_ops_gen import * from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestTensorType + def register_python_test_dialect(context, load=True): - from .._mlir_libs import _mlirPythonTest - _mlirPythonTest.register_python_test_dialect(context, load) + from .._mlir_libs import _mlirPythonTest + + _mlirPythonTest.register_python_test_dialect(context, load) diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index 78956c437..b505a490a 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -6,16 +6,18 @@ class FailurePropagationMode(Enum): - """Propagation mode for silenceable errors.""" - PROPAGATE = 1 - SUPPRESS = 2 + """Propagation mode for silenceable errors.""" - def _as_int(self): - if self is FailurePropagationMode.PROPAGATE: - return 1 + PROPAGATE = 1 + SUPPRESS = 2 + + def _as_int(self): + if self is FailurePropagationMode.PROPAGATE: + return 1 + + assert self is FailurePropagationMode.SUPPRESS + return 2 - assert self is FailurePropagationMode.SUPPRESS - return 2 from .._transform_ops_gen import * from 
..._mlir_libs._mlirDialectsTransform import * diff --git a/mlir/python/mlir/execution_engine.py b/mlir/python/mlir/execution_engine.py index 262545b9c..4739231c1 100644 --- a/mlir/python/mlir/execution_engine.py +++ b/mlir/python/mlir/execution_engine.py @@ -7,37 +7,37 @@ import ctypes __all__ = [ - "ExecutionEngine", + "ExecutionEngine", ] -class ExecutionEngine(_execution_engine.ExecutionEngine): - def lookup(self, name): - """Lookup a function emitted with the `llvm.emit_c_interface` - attribute and returns a ctype callable. - Raise a RuntimeError if the function isn't found. - """ - func = self.raw_lookup("_mlir_ciface_" + name) - if not func: - raise RuntimeError("Unknown function " + name) - prototype = ctypes.CFUNCTYPE(None, ctypes.c_void_p) - return prototype(func) +class ExecutionEngine(_execution_engine.ExecutionEngine): + def lookup(self, name): + """Lookup a function emitted with the `llvm.emit_c_interface` + attribute and returns a ctype callable. + Raise a RuntimeError if the function isn't found. + """ + func = self.raw_lookup("_mlir_ciface_" + name) + if not func: + raise RuntimeError("Unknown function " + name) + prototype = ctypes.CFUNCTYPE(None, ctypes.c_void_p) + return prototype(func) - def invoke(self, name, *ctypes_args): - """Invoke a function with the list of ctypes arguments. - All arguments must be pointers. - Raise a RuntimeError if the function isn't found. - """ - func = self.lookup(name) - packed_args = (ctypes.c_void_p * len(ctypes_args))() - for argNum in range(len(ctypes_args)): - packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p) - func(packed_args) + def invoke(self, name, *ctypes_args): + """Invoke a function with the list of ctypes arguments. + All arguments must be pointers. + Raise a RuntimeError if the function isn't found. 
+ """ + func = self.lookup(name) + packed_args = (ctypes.c_void_p * len(ctypes_args))() + for argNum in range(len(ctypes_args)): + packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p) + func(packed_args) - def register_runtime(self, name, ctypes_callback): - """Register a runtime function available to the jitted code - under the provided `name`. The `ctypes_callback` must be a - `CFuncType` that outlives the execution engine. - """ - callback = ctypes.cast(ctypes_callback, ctypes.c_void_p) - self.raw_register_runtime("_mlir_ciface_" + name, callback) + def register_runtime(self, name, ctypes_callback): + """Register a runtime function available to the jitted code + under the provided `name`. The `ctypes_callback` must be a + `CFuncType` that outlives the execution engine. + """ + callback = ctypes.cast(ctypes_callback, ctypes.c_void_p) + self.raw_register_runtime("_mlir_ciface_" + name, callback) diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index be065d463..99c21ff9a 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -8,124 +8,123 @@ # Convenience decorator for registering user-friendly Attribute builders. 
def register_attribute_builder(kind): + def decorator_builder(func): + AttrBuilder.insert(kind, func) + return func - def decorator_builder(func): - AttrBuilder.insert(kind, func) - return func - - return decorator_builder + return decorator_builder @register_attribute_builder("BoolAttr") def _boolAttr(x, context): - return BoolAttr.get(x, context=context) + return BoolAttr.get(x, context=context) @register_attribute_builder("IndexAttr") def _indexAttr(x, context): - return IntegerAttr.get(IndexType.get(context=context), x) + return IntegerAttr.get(IndexType.get(context=context), x) @register_attribute_builder("I16Attr") def _i16Attr(x, context): - return IntegerAttr.get(IntegerType.get_signless(16, context=context), x) + return IntegerAttr.get(IntegerType.get_signless(16, context=context), x) @register_attribute_builder("I32Attr") def _i32Attr(x, context): - return IntegerAttr.get(IntegerType.get_signless(32, context=context), x) + return IntegerAttr.get(IntegerType.get_signless(32, context=context), x) @register_attribute_builder("I64Attr") def _i64Attr(x, context): - return IntegerAttr.get(IntegerType.get_signless(64, context=context), x) + return IntegerAttr.get(IntegerType.get_signless(64, context=context), x) @register_attribute_builder("SI16Attr") def _si16Attr(x, context): - return IntegerAttr.get(IntegerType.get_signed(16, context=context), x) + return IntegerAttr.get(IntegerType.get_signed(16, context=context), x) @register_attribute_builder("SI32Attr") def _si32Attr(x, context): - return IntegerAttr.get(IntegerType.get_signed(32, context=context), x) + return IntegerAttr.get(IntegerType.get_signed(32, context=context), x) @register_attribute_builder("F32Attr") def _f32Attr(x, context): - return FloatAttr.get_f32(x, context=context) + return FloatAttr.get_f32(x, context=context) @register_attribute_builder("F64Attr") def _f64Attr(x, context): - return FloatAttr.get_f64(x, context=context) + return FloatAttr.get_f64(x, context=context) 
@register_attribute_builder("StrAttr") def _stringAttr(x, context): - return StringAttr.get(x, context=context) + return StringAttr.get(x, context=context) @register_attribute_builder("SymbolNameAttr") def _symbolNameAttr(x, context): - return StringAttr.get(x, context=context) + return StringAttr.get(x, context=context) @register_attribute_builder("SymbolRefAttr") def _symbolRefAttr(x, context): - return FlatSymbolRefAttr.get(x, context=context) + return FlatSymbolRefAttr.get(x, context=context) @register_attribute_builder("ArrayAttr") def _arrayAttr(x, context): - return ArrayAttr.get(x, context=context) + return ArrayAttr.get(x, context=context) @register_attribute_builder("I32ArrayAttr") def _i32ArrayAttr(x, context): - return ArrayAttr.get([_i32Attr(v, context) for v in x]) + return ArrayAttr.get([_i32Attr(v, context) for v in x]) @register_attribute_builder("I64ArrayAttr") def _i64ArrayAttr(x, context): - return ArrayAttr.get([_i64Attr(v, context) for v in x]) + return ArrayAttr.get([_i64Attr(v, context) for v in x]) @register_attribute_builder("F32ArrayAttr") def _f32ArrayAttr(x, context): - return ArrayAttr.get([_f32Attr(v, context) for v in x]) + return ArrayAttr.get([_f32Attr(v, context) for v in x]) @register_attribute_builder("F64ArrayAttr") def _f64ArrayAttr(x, context): - return ArrayAttr.get([_f64Attr(v, context) for v in x]) + return ArrayAttr.get([_f64Attr(v, context) for v in x]) @register_attribute_builder("DenseI64ArrayAttr") def _denseI64ArrayAttr(x, context): - return DenseI64ArrayAttr.get(x, context=context) + return DenseI64ArrayAttr.get(x, context=context) @register_attribute_builder("TypeAttr") def _typeAttr(x, context): - return TypeAttr.get(x, context=context) + return TypeAttr.get(x, context=context) @register_attribute_builder("TypeArrayAttr") def _typeArrayAttr(x, context): - return _arrayAttr([TypeAttr.get(t, context=context) for t in x], context) + return _arrayAttr([TypeAttr.get(t, context=context) for t in x], context) try: - 
import numpy as np + import numpy as np - @register_attribute_builder("IndexElementsAttr") - def _indexElementsAttr(x, context): - return DenseElementsAttr.get( - np.array(x, dtype=np.int64), - type=IndexType.get(context=context), - context=context, - ) + @register_attribute_builder("IndexElementsAttr") + def _indexElementsAttr(x, context): + return DenseElementsAttr.get( + np.array(x, dtype=np.int64), + type=IndexType.get(context=context), + context=context, + ) except ImportError: - pass + pass diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py index d70967983..51433d75a 100644 --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -9,131 +9,134 @@ class C128(ctypes.Structure): - """A ctype representation for MLIR's Double Complex.""" - _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)] + """A ctype representation for MLIR's Double Complex.""" + + _fields_ = [("real", ctypes.c_double), ("imag", ctypes.c_double)] class C64(ctypes.Structure): - """A ctype representation for MLIR's Float Complex.""" - _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)] + """A ctype representation for MLIR's Float Complex.""" + + _fields_ = [("real", ctypes.c_float), ("imag", ctypes.c_float)] class F16(ctypes.Structure): - """A ctype representation for MLIR's Float16.""" - _fields_ = [("f16", ctypes.c_int16)] + """A ctype representation for MLIR's Float16.""" + + _fields_ = [("f16", ctypes.c_int16)] # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype def as_ctype(dtp): - """Converts dtype to ctype.""" - if dtp == np.dtype(np.complex128): - return C128 - if dtp == np.dtype(np.complex64): - return C64 - if dtp == np.dtype(np.float16): - return F16 - return np.ctypeslib.as_ctypes_type(dtp) + """Converts dtype to ctype.""" + if dtp == np.dtype(np.complex128): + return C128 + if dtp == np.dtype(np.complex64): + return C64 + if dtp == 
np.dtype(np.float16): + return F16 + return np.ctypeslib.as_ctypes_type(dtp) def to_numpy(array): - """Converts ctypes array back to numpy dtype array.""" - if array.dtype == C128: - return array.view("complex128") - if array.dtype == C64: - return array.view("complex64") - if array.dtype == F16: - return array.view("float16") - return array + """Converts ctypes array back to numpy dtype array.""" + if array.dtype == C128: + return array.view("complex128") + if array.dtype == C64: + return array.view("complex64") + if array.dtype == F16: + return array.view("float16") + return array def make_nd_memref_descriptor(rank, dtype): + class MemRefDescriptor(ctypes.Structure): + """Builds an empty descriptor for the given rank/dtype, where rank>0.""" - class MemRefDescriptor(ctypes.Structure): - """Builds an empty descriptor for the given rank/dtype, where rank>0.""" + _fields_ = [ + ("allocated", ctypes.c_longlong), + ("aligned", ctypes.POINTER(dtype)), + ("offset", ctypes.c_longlong), + ("shape", ctypes.c_longlong * rank), + ("strides", ctypes.c_longlong * rank), + ] - _fields_ = [ - ("allocated", ctypes.c_longlong), - ("aligned", ctypes.POINTER(dtype)), - ("offset", ctypes.c_longlong), - ("shape", ctypes.c_longlong * rank), - ("strides", ctypes.c_longlong * rank), - ] - - return MemRefDescriptor + return MemRefDescriptor def make_zero_d_memref_descriptor(dtype): + class MemRefDescriptor(ctypes.Structure): + """Builds an empty descriptor for the given dtype, where rank=0.""" - class MemRefDescriptor(ctypes.Structure): - """Builds an empty descriptor for the given dtype, where rank=0.""" - - _fields_ = [ - ("allocated", ctypes.c_longlong), - ("aligned", ctypes.POINTER(dtype)), - ("offset", ctypes.c_longlong), - ] + _fields_ = [ + ("allocated", ctypes.c_longlong), + ("aligned", ctypes.POINTER(dtype)), + ("offset", ctypes.c_longlong), + ] - return MemRefDescriptor + return MemRefDescriptor class UnrankedMemRefDescriptor(ctypes.Structure): - """Creates a ctype struct for 
memref descriptor""" - _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)] + """Creates a ctype struct for memref descriptor""" + + _fields_ = [("rank", ctypes.c_longlong), ("descriptor", ctypes.c_void_p)] def get_ranked_memref_descriptor(nparray): - """Returns a ranked memref descriptor for the given numpy array.""" - ctp = as_ctype(nparray.dtype) - if nparray.ndim == 0: - x = make_zero_d_memref_descriptor(ctp)() + """Returns a ranked memref descriptor for the given numpy array.""" + ctp = as_ctype(nparray.dtype) + if nparray.ndim == 0: + x = make_zero_d_memref_descriptor(ctp)() + x.allocated = nparray.ctypes.data + x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp)) + x.offset = ctypes.c_longlong(0) + return x + + x = make_nd_memref_descriptor(nparray.ndim, ctp)() x.allocated = nparray.ctypes.data x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp)) x.offset = ctypes.c_longlong(0) - return x + x.shape = nparray.ctypes.shape - x = make_nd_memref_descriptor(nparray.ndim, ctp)() - x.allocated = nparray.ctypes.data - x.aligned = nparray.ctypes.data_as(ctypes.POINTER(ctp)) - x.offset = ctypes.c_longlong(0) - x.shape = nparray.ctypes.shape - - # Numpy uses byte quantities to express strides, MLIR OTOH uses the - # torch abstraction which specifies strides in terms of elements. - strides_ctype_t = ctypes.c_longlong * nparray.ndim - x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides]) - return x + # Numpy uses byte quantities to express strides, MLIR OTOH uses the + # torch abstraction which specifies strides in terms of elements. 
+ strides_ctype_t = ctypes.c_longlong * nparray.ndim + x.strides = strides_ctype_t(*[x // nparray.itemsize for x in nparray.strides]) + return x def get_unranked_memref_descriptor(nparray): - """Returns a generic/unranked memref descriptor for the given numpy array.""" - d = UnrankedMemRefDescriptor() - d.rank = nparray.ndim - x = get_ranked_memref_descriptor(nparray) - d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p) - return d + """Returns a generic/unranked memref descriptor for the given numpy array.""" + d = UnrankedMemRefDescriptor() + d.rank = nparray.ndim + x = get_ranked_memref_descriptor(nparray) + d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p) + return d def unranked_memref_to_numpy(unranked_memref, np_dtype): - """Converts unranked memrefs to numpy arrays.""" - ctp = as_ctype(np_dtype) - descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp) - val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor)) - np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape) - strided_arr = np.lib.stride_tricks.as_strided( - np_arr, - np.ctypeslib.as_array(val[0].shape), - np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize, - ) - return to_numpy(strided_arr) + """Converts unranked memrefs to numpy arrays.""" + ctp = as_ctype(np_dtype) + descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp) + val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor)) + np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape) + strided_arr = np.lib.stride_tricks.as_strided( + np_arr, + np.ctypeslib.as_array(val[0].shape), + np.ctypeslib.as_array(val[0].strides) * np_arr.itemsize, + ) + return to_numpy(strided_arr) def ranked_memref_to_numpy(ranked_memref): - """Converts ranked memrefs to numpy arrays.""" - np_arr = np.ctypeslib.as_array( - ranked_memref[0].aligned, shape=ranked_memref[0].shape) - strided_arr = np.lib.stride_tricks.as_strided( - np_arr, - 
np.ctypeslib.as_array(ranked_memref[0].shape), - np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize, - ) - return to_numpy(strided_arr) + """Converts ranked memrefs to numpy arrays.""" + np_arr = np.ctypeslib.as_array( + ranked_memref[0].aligned, shape=ranked_memref[0].shape + ) + strided_arr = np.lib.stride_tricks.as_strided( + np_arr, + np.ctypeslib.as_array(ranked_memref[0].shape), + np.ctypeslib.as_array(ranked_memref[0].strides) * np_arr.itemsize, + ) + return to_numpy(strided_arr) From 9c7e925400fd43ec9137158e48cc0cc76a751a57 Mon Sep 17 00:00:00 2001 From: Tres Popp Date: Fri, 26 May 2023 10:17:47 +0200 Subject: [PATCH 473/915] [mlir] Move casting calls from methods to function calls MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The MLIR classes Type/Attribute/Operation/Op/Value support cast/dyn_cast/isa/dyn_cast_or_null functionality through llvm's doCast functionality in addition to defining methods with the same name. This change begins the migration of uses of the method to the corresponding function call as has been decided as more consistent. Note that there still exist classes that only define methods directly, such as AffineExpr, and this does not include work currently to support a functional cast/isa call. Context: - https://mlir.llvm.org/deprecation/ at "Use the free function variants for dyn_cast/cast/isa/…" - Original discussion at https://discourse.llvm.org/t/preferred-casting-style-going-forward/68443 Implementation: This patch updates all remaining uses of the deprecated functionality in mlir/. This was done with clang-tidy as described below and further modifications to GPUBase.td and OpenMPOpsInterfaces.td. Steps are described per line, as comments are removed by git: 0. Retrieve the change from the following to build clang-tidy with an additional check: main...tpopp:llvm-project:tidy-cast-check 1. Build clang-tidy 2. 
Run clang-tidy over your entire codebase while disabling all checks and enabling the one relevant one. Run on all header files also. 3. Delete .inc files that were also modified, so the next build rebuilds them to a pure state. ``` ninja -C $BUILD_DIR clang-tidy run-clang-tidy -clang-tidy-binary=$BUILD_DIR/bin/clang-tidy -checks='-*,misc-cast-functions'\ -header-filter=mlir/ mlir/* -fix rm -rf $BUILD_DIR/tools/mlir/**/*.inc ``` Differential Revision: https://reviews.llvm.org/D151542 --- mlir/lib/CAPI/Interfaces/Interfaces.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/Interfaces/Interfaces.cpp b/mlir/lib/CAPI/Interfaces/Interfaces.cpp index 3144a338f..d3fd6b4c0 100644 --- a/mlir/lib/CAPI/Interfaces/Interfaces.cpp +++ b/mlir/lib/CAPI/Interfaces/Interfaces.cpp @@ -47,7 +47,7 @@ SmallVector unwrapOperands(intptr_t nOperands, MlirValue *operands) { DictionaryAttr unwrapAttributes(MlirAttribute attributes) { DictionaryAttr attributeDict; if (!mlirAttributeIsNull(attributes)) - attributeDict = unwrap(attributes).cast(); + attributeDict = llvm::cast(unwrap(attributes)); return attributeDict; } From ac4ebc4fc044f0caac4602362ac978d46b9026dc Mon Sep 17 00:00:00 2001 From: max Date: Fri, 26 May 2023 10:23:17 -0500 Subject: [PATCH 474/915] [MLIR][python bindings] Add TypeCaster for returning refined types from python APIs depends on D150839 This diff uses `MlirTypeID` to register `TypeCaster`s (i.e., `[](PyType pyType) -> DerivedTy { return pyType; }`) for all concrete types (i.e., `PyConcrete<...>`) that are then queried for (by `MlirTypeID`) and called in `struct type_caster::cast`. The result is that anywhere an `MlirType mlirType` is returned from a python binding, that `mlirType` is automatically cast to the correct concrete type. 
For example: ``` c0 = arith.ConstantOp(f32, 0.0) # CHECK: F32Type(f32) print(repr(c0.result.type)) unranked_tensor_type = UnrankedTensorType.get(f32) unranked_tensor = tensor.FromElementsOp(unranked_tensor_type, [c0]).result # CHECK: UnrankedTensorType print(type(unranked_tensor.type).__name__) # CHECK: UnrankedTensorType(tensor<*xf32>) print(repr(unranked_tensor.type)) ``` This functionality immediately extends to typed attributes (i.e., `attr.type`). The diff also implements similar functionality for `mlir_type_subclass`es but in a slightly different way - for such types (which have no cpp corresponding `class` or `struct`) the user must provide a type caster in python (similar to how `AttrBuilder` works) or in cpp as a `py::cpp_function`. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D150927 --- mlir/include/mlir-c/Bindings/Python/Interop.h | 17 +++++++ mlir/include/mlir-c/Dialect/Transform.h | 2 + mlir/include/mlir-c/IR.h | 3 ++ .../mlir/Bindings/Python/PybindAdaptors.h | 25 +++++++++-- mlir/include/mlir/CAPI/Support.h | 21 +++++++++ mlir/lib/Bindings/Python/DialectTransform.cpp | 3 +- mlir/lib/Bindings/Python/Globals.h | 23 +++++++--- mlir/lib/Bindings/Python/IRAttributes.cpp | 4 +- mlir/lib/Bindings/Python/IRCore.cpp | 39 +++++++++------- mlir/lib/Bindings/Python/IRInterfaces.cpp | 6 +-- mlir/lib/Bindings/Python/IRModule.cpp | 44 +++++++++++++++++++ mlir/lib/Bindings/Python/IRModule.h | 17 ++++--- mlir/lib/Bindings/Python/IRTypes.cpp | 20 +++------ mlir/lib/Bindings/Python/MainModule.cpp | 21 ++++++--- mlir/lib/CAPI/Dialect/Transform.cpp | 4 ++ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 4 ++ mlir/lib/CAPI/IR/IR.cpp | 4 ++ mlir/lib/CAPI/IR/Support.cpp | 1 - mlir/python/mlir/dialects/python_test.py | 2 +- mlir/python/mlir/ir.py | 1 + 20 files changed, 200 insertions(+), 61 deletions(-) diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h index 6ebb45808..33332d6a3 100644 --- 
a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -107,6 +107,23 @@ * delineated). */ #define MLIR_PYTHON_CAPI_FACTORY_ATTR "_CAPICreate" +/** Attribute on MLIR Python objects that expose a function for downcasting the + * corresponding Python object to a subclass if the object is in fact a subclass + * (Concrete or mlir_type_subclass) of ir.Type. The signature of the function + * is: def maybe_downcast(self) -> object where the resulting object will + * (possibly) be an instance of the subclass. + */ +#define MLIR_PYTHON_MAYBE_DOWNCAST_ATTR "maybe_downcast" + +/** Attribute on main C extension module (_mlir) that corresponds to the + * type caster registration binding. The signature of the function is: + * def register_type_caster(MlirTypeID mlirTypeID, py::function typeCaster, + * bool replace) + * where replace indicates the typeCaster should replace any existing registered + * type casters (such as those for upstream ConcreteTypes). + */ +#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR "register_type_caster" + /// Gets a void* from a wrapped struct. Needed because const cast is different /// between C/C++. 
#ifdef __cplusplus diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h index 864dffa3f..0409890b2 100644 --- a/mlir/include/mlir-c/Dialect/Transform.h +++ b/mlir/include/mlir-c/Dialect/Transform.h @@ -33,6 +33,8 @@ MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx); MLIR_CAPI_EXPORTED bool mlirTypeIsATransformOperationType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirTransformOperationTypeGetTypeID(void); + MLIR_CAPI_EXPORTED MlirType mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName); diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 13a3cb013..8253981b3 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -825,6 +825,9 @@ MLIR_CAPI_EXPORTED MlirContext mlirTypeGetContext(MlirType type); /// Gets the type ID of the type. MLIR_CAPI_EXPORTED MlirTypeID mlirTypeGetTypeID(MlirType type); +/// Gets the dialect a type belongs to. +MLIR_CAPI_EXPORTED MlirDialect mlirTypeGetDialect(MlirType type); + /// Checks whether a type is null. static inline bool mlirTypeIsNull(MlirType type) { return !type.ptr; } diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index ccca3aa01..272067a26 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -28,6 +28,7 @@ #include "llvm/ADT/Twine.h" namespace py = pybind11; +using namespace py::literals; // Raw CAPI type casters need to be declared before use, so always include them // first. 
@@ -272,6 +273,7 @@ struct type_caster { return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr("Type") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() .release(); } }; @@ -424,20 +426,24 @@ class mlir_attribute_subclass : public pure_subclass { class mlir_type_subclass : public pure_subclass { public: using IsAFunctionTy = bool (*)(MlirType); + using GetTypeIDFunctionTy = MlirTypeID (*)(); /// Subclasses by looking up the super-class dynamically. mlir_type_subclass(py::handle scope, const char *typeClassName, - IsAFunctionTy isaFunction) + IsAFunctionTy isaFunction, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) : mlir_type_subclass( scope, typeClassName, isaFunction, - py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type")) {} + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")).attr("Type"), + getTypeIDFunction) {} /// Subclasses with a provided mlir.ir.Type super-class. This must /// be used if the subclass is being defined in the same extension module /// as the mlir.ir class (otherwise, it will trigger a recursive /// initialization). mlir_type_subclass(py::handle scope, const char *typeClassName, - IsAFunctionTy isaFunction, const py::object &superCls) + IsAFunctionTy isaFunction, const py::object &superCls, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) : pure_subclass(scope, typeClassName, superCls) { // Casting constructor. 
Note that it hard, if not impossible, to properly // call chain to parent `__init__` in pybind11 due to its special handling @@ -471,6 +477,19 @@ class mlir_type_subclass : public pure_subclass { "isinstance", [isaFunction](MlirType other) { return isaFunction(other); }, py::arg("other_type")); + def("__repr__", [superCls, captureTypeName](py::object self) { + return py::repr(superCls(self)) + .attr("replace")(superCls.attr("__name__"), captureTypeName); + }); + if (getTypeIDFunction) { + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( + getTypeIDFunction(), + pybind11::cpp_function( + [thisClass = thisClass](const py::object &mlirType) { + return thisClass(mlirType); + })); + } } }; diff --git a/mlir/include/mlir/CAPI/Support.h b/mlir/include/mlir/CAPI/Support.h index f3e8a67e0..e42413dbe 100644 --- a/mlir/include/mlir/CAPI/Support.h +++ b/mlir/include/mlir/CAPI/Support.h @@ -44,4 +44,25 @@ inline mlir::LogicalResult unwrap(MlirLogicalResult res) { DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID) DEFINE_C_API_PTR_METHODS(MlirTypeIDAllocator, mlir::TypeIDAllocator) +namespace llvm { + +template <> +struct DenseMapInfo { + static inline MlirTypeID getEmptyKey() { + auto *pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlirTypeIDCreate(pointer); + } + static inline MlirTypeID getTombstoneKey() { + auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlirTypeIDCreate(pointer); + } + static inline unsigned getHashValue(const MlirTypeID &val) { + return mlirTypeIDHashValue(val); + } + static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) { + return mlirTypeIDEqual(lhs, rhs); + } +}; +} // namespace llvm + #endif // MLIR_CAPI_SUPPORT_H diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp index a9db2428c..e4b8cee73 100644 --- a/mlir/lib/Bindings/Python/DialectTransform.cpp +++ b/mlir/lib/Bindings/Python/DialectTransform.cpp @@ 
-36,7 +36,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) { //===-------------------------------------------------------------------===// auto operationType = - mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType); + mlir_type_subclass(m, "OperationType", mlirTypeIsATransformOperationType, + mlirTransformOperationTypeGetTypeID); operationType.def_classmethod( "get", [](py::object cls, const std::string &operationName, MlirContext ctx) { diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 45d036896..0fc7614cc 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -9,12 +9,15 @@ #ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H #define MLIR_BINDINGS_PYTHON_GLOBALS_H +#include #include #include -#include #include "PybindUtils.h" +#include "mlir-c/IR.h" +#include "mlir/CAPI/Support.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" @@ -54,16 +57,18 @@ class PyGlobals { /// entities. void loadDialectModule(llvm::StringRef dialectNamespace); - /// Decorator for registering a custom Dialect class. The class object must - /// have a DIALECT_NAMESPACE attribute. - pybind11::object registerDialectDecorator(pybind11::object pyClass); - /// Adds a user-friendly Attribute builder. /// Raises an exception if the mapping already exists. /// This is intended to be called by implementation code. void registerAttributeBuilder(const std::string &attributeKind, pybind11::function pyFunc); + /// Adds a user-friendly type caster. Raises an exception if the mapping + /// already exists and replace == false. This is intended to be called by + /// implementation code. + void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster, + bool replace = false); + /// Adds a concrete implementation dialect class. /// Raises an exception if the mapping already exists. /// This is intended to be called by implementation code. 
@@ -80,6 +85,10 @@ class PyGlobals { std::optional lookupAttributeBuilder(const std::string &attributeKind); + /// Returns the custom type caster for MlirTypeID mlirTypeID. + std::optional lookupTypeCaster(MlirTypeID mlirTypeID, + MlirDialect dialect); + /// Looks up a registered dialect class by namespace. Note that this may /// trigger loading of the defining module and can arbitrarily re-enter. std::optional @@ -101,6 +110,10 @@ class PyGlobals { llvm::StringMap operationClassMap; /// Map of attribute ODS name to custom builder. llvm::StringMap attributeBuilderMap; + /// Map of MlirTypeID to custom type caster. + llvm::DenseMap typeCasterMap; + /// Cache for map of MlirTypeID to custom type caster. + llvm::DenseMap typeCasterMapCache; /// Set of dialect namespaces that we have attempted to import implementation /// modules for. diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 0ab47cc24..3c7926e78 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -15,6 +15,7 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" namespace py = pybind11; using namespace mlir; @@ -1023,8 +1024,7 @@ class PyTypeAttribute : public PyConcreteAttribute { py::arg("value"), py::arg("context") = py::none(), "Gets a uniqued Type attribute"); c.def_property_readonly("value", [](PyTypeAttribute &self) { - return PyType(self.getContext()->getRef(), - mlirTypeAttrGetValue(self.get())); + return mlirTypeAttrGetValue(self.get()); }); } }; diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index a6bd4d849..ec9066aa1 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -25,6 +25,7 @@ #include namespace py = pybind11; +using namespace py::literals; using namespace mlir; using namespace mlir::python; @@ -2121,13 +2122,12 @@ class PyOpResult : 
public PyConcreteValue { /// Returns the list of types of the values held by container. template -static std::vector getValueTypes(Container &container, - PyMlirContextRef &context) { - std::vector result; +static std::vector getValueTypes(Container &container, + PyMlirContextRef &context) { + std::vector result; result.reserve(container.size()); for (int i = 0, e = container.size(); i < e; ++i) { - result.push_back( - PyType(context, mlirValueGetType(container.getElement(i).get()))); + result.push_back(mlirValueGetType(container.getElement(i).get())); } return result; } @@ -3148,11 +3148,8 @@ void mlir::python::populateIRCore(py::module &m) { "context", [](PyAttribute &self) { return self.getContext().getObject(); }, "Context that owns the Attribute") - .def_property_readonly("type", - [](PyAttribute &self) { - return PyType(self.getContext()->getRef(), - mlirAttributeGetType(self)); - }) + .def_property_readonly( + "type", [](PyAttribute &self) { return mlirAttributeGetType(self); }) .def( "get_named", [](PyAttribute &self, std::string name) { @@ -3247,7 +3244,7 @@ void mlir::python::populateIRCore(py::module &m) { mlirTypeParseGet(context->get(), toMlirStringRef(typeSpec)); if (mlirTypeIsNull(type)) throw MLIRError("Unable to parse type", errors.take()); - return PyType(context->getRef(), type); + return type; }, py::arg("asm"), py::arg("context") = py::none(), kContextParseTypeDocstring) @@ -3284,6 +3281,18 @@ void mlir::python::populateIRCore(py::module &m) { printAccum.parts.append(")"); return printAccum.join(); }) + .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](PyType &self) { + MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + std::optional typeCaster = + PyGlobals::get().lookupTypeCaster(mlirTypeID, + mlirTypeGetDialect(self)); + if (!typeCaster) + return py::cast(self); + return typeCaster.value()(self); + }) .def_property_readonly("typeid", [](PyType &self) -> 
MlirTypeID { MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); if (!mlirTypeIDIsNull(mlirTypeID)) @@ -3387,12 +3396,8 @@ void mlir::python::populateIRCore(py::module &m) { return printAccum.join(); }, py::arg("use_local_scope") = false, kGetNameAsOperand) - .def_property_readonly("type", - [](PyValue &self) { - return PyType( - self.getParentOperation()->getContext(), - mlirValueGetType(self.get())); - }) + .def_property_readonly( + "type", [](PyValue &self) { return mlirValueGetType(self.get()); }) .def( "replace_all_uses_with", [](PyValue &self, PyValue &with) { diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index 0a7a25c00..25fcaccd2 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -321,11 +321,7 @@ class PyShapedTypeComponents { py::module_local()) .def_property_readonly( "element_type", - [](PyShapedTypeComponents &self) { - return PyType(PyMlirContext::forContext( - mlirTypeGetContext(self.elementType)), - self.elementType); - }, + [](PyShapedTypeComponents &self) { return self.elementType; }, "Returns the element type of the shaped type components.") .def_static( "get", diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index 7c49f20f1..d9a66bce0 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -14,6 +14,7 @@ #include #include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/Support.h" namespace py = pybind11; using namespace mlir; @@ -72,6 +73,15 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, found = std::move(pyFunc); } +void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, + pybind11::function typeCaster, + bool replace) { + pybind11::object &found = typeCasterMap[mlirTypeID]; + if (found && !found.is_none() && !replace) + throw std::runtime_error("Type caster is already registered"); + found = std::move(typeCaster); +} + 
void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, py::object pyClass) { py::object &found = dialectClassMap[dialectNamespace]; @@ -110,6 +120,39 @@ PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { return std::nullopt; } +std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, + MlirDialect dialect) { + { + // Fast match against the class map first (common case). + const auto foundIt = typeCasterMapCache.find(mlirTypeID); + if (foundIt != typeCasterMapCache.end()) { + if (foundIt->second.is_none()) + return std::nullopt; + assert(foundIt->second && "py::function is defined"); + return foundIt->second; + } + } + + // Not found. Load the dialect namespace. + loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); + + // Attempt to find from the canonical map and cache. + { + const auto foundIt = typeCasterMap.find(mlirTypeID); + if (foundIt != typeCasterMap.end()) { + if (foundIt->second.is_none()) + return std::nullopt; + assert(foundIt->second && "py::object is defined"); + // Positive cache. + typeCasterMapCache[mlirTypeID] = foundIt->second; + return foundIt->second; + } + // Negative cache. 
+ typeCasterMap[mlirTypeID] = py::none(); + return std::nullopt; + } +} + std::optional PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { loadDialectModule(dialectNamespace); @@ -164,4 +207,5 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) { void PyGlobals::clearImportCache() { loadedDialectModulesCache.clear(); operationClassMapCache.clear(); + typeCasterMapCache.clear(); } diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index cfa3737cf..013bb7b92 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -13,6 +13,7 @@ #include #include +#include "Globals.h" #include "PybindUtils.h" #include "mlir-c/AffineExpr.h" @@ -868,9 +869,7 @@ class PyConcreteType : public BaseTy { PyConcreteType() = default; PyConcreteType(PyMlirContextRef contextRef, MlirType t) - : BaseTy(std::move(contextRef), t) { - pybind11::implicitly_convertible(); - } + : BaseTy(std::move(contextRef), t) {} PyConcreteType(PyType &orig) : PyConcreteType(orig.getContext(), castFrom(orig)) {} @@ -914,6 +913,13 @@ class PyConcreteType : public BaseTy { return printAccum.join(); }); + if (DerivedTy::getTypeIdFunction) { + PyGlobals::get().registerTypeCaster( + DerivedTy::getTypeIdFunction(), + pybind11::cpp_function( + [](PyType pyType) -> DerivedTy { return pyType; })); + } + DerivedTy::bindDerived(cls); } @@ -1009,9 +1015,8 @@ class PyConcreteAttribute : public BaseTy { return DerivedTy::isaFunction(otherAttr); }, pybind11::arg("other")); - cls.def_property_readonly("type", [](PyAttribute &attr) { - return PyType(attr.getContext(), mlirAttributeGetType(attr)); - }); + cls.def_property_readonly( + "type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); }); DerivedTy::bindDerived(cls); } diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 5c089b2f2..25307262b 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ 
b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -334,10 +334,7 @@ class PyComplexType : public PyConcreteType { "Create a complex type"); c.def_property_readonly( "element_type", - [](PyComplexType &self) -> PyType { - MlirType t = mlirComplexTypeGetElementType(self); - return PyType(self.getContext(), t); - }, + [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); }, "Returns element type."); } }; @@ -351,10 +348,7 @@ class PyShapedType : public PyConcreteType { static void bindDerived(ClassTy &c) { c.def_property_readonly( "element_type", - [](PyShapedType &self) { - MlirType t = mlirShapedTypeGetElementType(self); - return PyType(self.getContext(), t); - }, + [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); }, "Returns the element type of the shaped type."); c.def_property_readonly( "has_rank", @@ -641,9 +635,8 @@ class PyTupleType : public PyConcreteType { "Create a tuple type"); c.def( "get_type", - [](PyTupleType &self, intptr_t pos) -> PyType { - MlirType t = mlirTupleTypeGetType(self, pos); - return PyType(self.getContext(), t); + [](PyTupleType &self, intptr_t pos) { + return mlirTupleTypeGetType(self, pos); }, py::arg("pos"), "Returns the pos-th type in the tuple type."); c.def_property_readonly( @@ -686,7 +679,7 @@ class PyFunctionType : public PyConcreteType { py::list types; for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; ++i) { - types.append(PyType(contextRef, mlirFunctionTypeGetInput(t, i))); + types.append(mlirFunctionTypeGetInput(t, i)); } return types; }, @@ -698,8 +691,7 @@ class PyFunctionType : public PyConcreteType { py::list types; for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; ++i) { - types.append( - PyType(contextRef, mlirFunctionTypeGetResult(self, i))); + types.append(mlirFunctionTypeGetResult(self, i)); } return types; }, diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index b32b4186f..cdddfbe50 100644 --- 
a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -16,6 +16,7 @@ namespace py = pybind11; using namespace mlir; +using namespace py::literals; using namespace mlir::python; // ----------------------------------------------------------------------------- @@ -35,12 +36,12 @@ PYBIND11_MODULE(_mlir, m) { self.getDialectSearchPrefixes().push_back(std::move(moduleName)); self.clearImportCache(); }, - py::arg("module_name")) + "module_name"_a) .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, - py::arg("dialect_namespace"), py::arg("dialect_class"), + "dialect_namespace"_a, "dialect_class"_a, "Testing hook for directly registering a dialect") .def("_register_operation_impl", &PyGlobals::registerOperationImpl, - py::arg("operation_name"), py::arg("operation_class"), + "operation_name"_a, "operation_class"_a, "Testing hook for directly registering an operation"); // Aside from making the globals accessible to python, having python manage @@ -58,11 +59,11 @@ PYBIND11_MODULE(_mlir, m) { PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); return pyClass; }, - py::arg("dialect_class"), + "dialect_class"_a, "Class decorator for registering a custom Dialect wrapper"); m.def( "register_operation", - [](py::object dialectClass) -> py::cpp_function { + [](const py::object &dialectClass) -> py::cpp_function { return py::cpp_function( [dialectClass](py::object opClass) -> py::object { std::string operationName = @@ -75,9 +76,17 @@ PYBIND11_MODULE(_mlir, m) { return opClass; }); }, - py::arg("dialect_class"), + "dialect_class"_a, "Produce a class decorator for registering an Operation class as part of " "a dialect"); + m.def( + MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR, + [](MlirTypeID mlirTypeID, py::function typeCaster, bool replace) { + PyGlobals::get().registerTypeCaster(mlirTypeID, std::move(typeCaster), + replace); + }, + "typeid"_a, "type_caster"_a, "replace"_a = false, + "Register a type caster for casting 
MLIR types to custom user types."); // Define and populate IR submodule. auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp index 90594b67a..d3cd4e3d0 100644 --- a/mlir/lib/CAPI/Dialect/Transform.cpp +++ b/mlir/lib/CAPI/Dialect/Transform.cpp @@ -37,6 +37,10 @@ bool mlirTypeIsATransformOperationType(MlirType type) { return isa(unwrap(type)); } +MlirTypeID mlirTransformOperationTypeGetTypeID(void) { + return wrap(transform::OperationType::getTypeID()); +} + MlirType mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName) { return wrap( diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 1925478c6..82c5b5a61 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -324,6 +324,10 @@ MlirType mlirUnrankedTensorTypeGetChecked(MlirLocation loc, return wrap(UnrankedTensorType::getChecked(unwrap(loc), unwrap(elementType))); } +MlirType mlirUnrankedTensorTypeGetElementType(MlirType type) { + return wrap(llvm::cast(unwrap(type)).getElementType()); +} + //===----------------------------------------------------------------------===// // Ranked / Unranked MemRef type. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index c0cf59777..373e01a13 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -832,6 +832,10 @@ MlirTypeID mlirTypeGetTypeID(MlirType type) { return wrap(unwrap(type).getTypeID()); } +MlirDialect mlirTypeGetDialect(MlirType type) { + return wrap(&unwrap(type).getDialect()); +} + bool mlirTypeEqual(MlirType t1, MlirType t2) { return unwrap(t1) == unwrap(t2); } diff --git a/mlir/lib/CAPI/IR/Support.cpp b/mlir/lib/CAPI/IR/Support.cpp index cbfbb5476..ea081b2e9 100644 --- a/mlir/lib/CAPI/IR/Support.cpp +++ b/mlir/lib/CAPI/IR/Support.cpp @@ -23,7 +23,6 @@ bool mlirStringRefEqual(MlirStringRef string, MlirStringRef other) { //===----------------------------------------------------------------------===// // TypeID API. //===----------------------------------------------------------------------===// - MlirTypeID mlirTypeIDCreate(const void *ptr) { assert(reinterpret_cast(ptr) % 8 == 0 && "ptr must be 8 byte aligned"); diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py index 980f237b1..8465af048 100644 --- a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._python_test_ops_gen import * -from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestTensorType +from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestIntegerRankedTensorType def register_python_test_dialect(context, load=True): diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 99c21ff9a..10a0f5bd2 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -4,6 +4,7 @@ from ._mlir_libs._mlir.ir import * from ._mlir_libs._mlir.ir import _GlobalDebug +from ._mlir_libs._mlir import register_type_caster # Convenience 
decorator for registering user-friendly Attribute builders. From 25b91230371486000075d1eccbdafbce03c4fe2e Mon Sep 17 00:00:00 2001 From: max Date: Fri, 26 May 2023 14:39:03 -0500 Subject: [PATCH 475/915] [MLIR][python bindings] Fix inferReturnTypes + AttrSizedOperandSegments for optional operands Right now `inferTypeOpInterface.inferReturnTypes` fails because there's a cast in there to `py::sequence` which throws a `TypeError` when it tries to cast the `None`s. Note `None`s are inserted into `operands` for omitted operands passed to the generated builder: ``` operands.append(_get_op_result_or_value(start) if start is not None else None) operands.append(_get_op_result_or_value(stop) if stop is not None else None) operands.append(_get_op_result_or_value(step) if step is not None else None) ``` Note also that skipping appending to the list operands doesn't work either because [[ https://github.com/llvm/llvm-project/blob/27c37327da67020f938aabf0f6405f57d688441e/mlir/lib/Bindings/Python/IRCore.cpp#L1585 | build generic ]] checks against the number of operand segments expected. Currently the only way around is to handroll through `ir.Operation.create`. Reviewed By: rkayaith Differential Revision: https://reviews.llvm.org/D151409 --- mlir/lib/Bindings/Python/IRInterfaces.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index 25fcaccd2..dd4190016 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -53,6 +53,9 @@ llvm::SmallVector wrapOperands(std::optional operandList) { // Note: as the list may contain other lists this may not be final size. 
mlirOperands.reserve(operandList->size()); for (const auto &&it : llvm::enumerate(*operandList)) { + if (it.value().is_none()) + continue; + PyValue *val; try { val = py::cast(it.value()); From c662da82b1c5c659ef93136d9cddf85cf8fe883e Mon Sep 17 00:00:00 2001 From: max Date: Tue, 30 May 2023 10:46:55 -0500 Subject: [PATCH 476/915] [MLIR][CAPI] Move `DenseMapInfo` I mistakenly put this in `mlir/CAPI/Support.h` at some point during the flurry of refactoring of `TypeCaster`s but as @jpienaar rightly pointed out, it doesn't belong there. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D151669 --- mlir/include/mlir/CAPI/Support.h | 21 --------------------- mlir/lib/Bindings/Python/PybindUtils.h | 21 +++++++++++++++++++++ 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/mlir/include/mlir/CAPI/Support.h b/mlir/include/mlir/CAPI/Support.h index e42413dbe..f3e8a67e0 100644 --- a/mlir/include/mlir/CAPI/Support.h +++ b/mlir/include/mlir/CAPI/Support.h @@ -44,25 +44,4 @@ inline mlir::LogicalResult unwrap(MlirLogicalResult res) { DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID) DEFINE_C_API_PTR_METHODS(MlirTypeIDAllocator, mlir::TypeIDAllocator) -namespace llvm { - -template <> -struct DenseMapInfo { - static inline MlirTypeID getEmptyKey() { - auto *pointer = llvm::DenseMapInfo::getEmptyKey(); - return mlirTypeIDCreate(pointer); - } - static inline MlirTypeID getTombstoneKey() { - auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); - return mlirTypeIDCreate(pointer); - } - static inline unsigned getHashValue(const MlirTypeID &val) { - return mlirTypeIDHashValue(val); - } - static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) { - return mlirTypeIDEqual(lhs, rhs); - } -}; -} // namespace llvm - #endif // MLIR_CAPI_SUPPORT_H diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h index 41de7e9b4..2a8da20be 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ 
b/mlir/lib/Bindings/Python/PybindUtils.h @@ -354,4 +354,25 @@ class Sliceable { } // namespace mlir +namespace llvm { + +template <> +struct DenseMapInfo { + static inline MlirTypeID getEmptyKey() { + auto *pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlirTypeIDCreate(pointer); + } + static inline MlirTypeID getTombstoneKey() { + auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlirTypeIDCreate(pointer); + } + static inline unsigned getHashValue(const MlirTypeID &val) { + return mlirTypeIDHashValue(val); + } + static inline bool isEqual(const MlirTypeID &lhs, const MlirTypeID &rhs) { + return mlirTypeIDEqual(lhs, rhs); + } +}; +} // namespace llvm + #endif // MLIR_BINDINGS_PYTHON_PYBINDUTILS_H From 5c703bc0c30ba91cf0ebe58fd6d48ed5bcc89b8d Mon Sep 17 00:00:00 2001 From: wren romano <2998727+wrengr@users.noreply.github.com> Date: Tue, 30 May 2023 13:16:29 -0700 Subject: [PATCH 477/915] [mlir][sparse] Combining `dimOrdering`+`higherOrdering` fields into `dimToLvl` This is a major step along the way towards the new STEA design. While a great deal of this patch is simple renaming, there are several significant changes as well. I've done my best to ensure that this patch retains the previous behavior and error-conditions, even though those are at odds with the eventual intended semantics of the `dimToLvl` mapping. Since the majority of the compiler does not yet support non-permutations, I've also added explicit assertions in places that previously had implicitly assumed it was dealing with permutations. 
Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D151505 --- mlir/include/mlir-c/Dialect/SparseTensor.h | 14 ++++----- .../Bindings/Python/DialectSparseTensor.cpp | 30 ++++++------------- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 21 +++++-------- 3 files changed, 22 insertions(+), 43 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index 1ff6dc1b8..0ad1a315e 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -52,9 +52,8 @@ mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr); /// Creates a `sparse_tensor.encoding` attribute with the given parameters. MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet( MlirContext ctx, intptr_t lvlRank, - enum MlirSparseTensorDimLevelType const *lvlTypes, - MlirAffineMap dimOrdering, MlirAffineMap higherOrdering, int posWidth, - int crdWidth); + enum MlirSparseTensorDimLevelType const *lvlTypes, MlirAffineMap dimToLvl, + int posWidth, int crdWidth); /// Returns the level-rank of the `sparse_tensor.encoding` attribute. MLIR_CAPI_EXPORTED intptr_t @@ -64,13 +63,10 @@ mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr); MLIR_CAPI_EXPORTED enum MlirSparseTensorDimLevelType mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl); -/// Returns the dimension-ordering of the `sparse_tensor.encoding` attribute. +/// Returns the dimension-to-level mapping of the `sparse_tensor.encoding` +/// attribute. MLIR_CAPI_EXPORTED MlirAffineMap -mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr); - -/// Returns the higher-ordering of the `sparse_tensor.encoding` attribute. -MLIR_CAPI_EXPORTED MlirAffineMap -mlirSparseTensorEncodingAttrGetHigherOrdering(MlirAttribute attr); +mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr); /// Returns the position bitwidth of the `sparse_tensor.encoding` attribute. 
MLIR_CAPI_EXPORTED int diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 0f0e67604..2e8d53545 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -40,18 +40,16 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { .def_classmethod( "get", [](py::object cls, std::vector lvlTypes, - std::optional dimOrdering, - std::optional higherOrdering, int posWidth, - int crdWidth, MlirContext context) { + std::optional dimToLvl, int posWidth, int crdWidth, + MlirContext context) { return cls(mlirSparseTensorEncodingAttrGet( context, lvlTypes.size(), lvlTypes.data(), - dimOrdering ? *dimOrdering : MlirAffineMap{nullptr}, - higherOrdering ? *higherOrdering : MlirAffineMap{nullptr}, - posWidth, crdWidth)); + dimToLvl ? *dimToLvl : MlirAffineMap{nullptr}, posWidth, + crdWidth)); }, - py::arg("cls"), py::arg("lvl_types"), py::arg("dim_ordering"), - py::arg("higher_ordering"), py::arg("pos_width"), - py::arg("crd_width"), py::arg("context") = py::none(), + py::arg("cls"), py::arg("lvl_types"), py::arg("dim_to_lvl"), + py::arg("pos_width"), py::arg("crd_width"), + py::arg("context") = py::none(), "Gets a sparse_tensor.encoding from parameters.") .def_property_readonly( "lvl_types", @@ -64,19 +62,9 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { return ret; }) .def_property_readonly( - "dim_ordering", + "dim_to_lvl", [](MlirAttribute self) -> std::optional { - MlirAffineMap ret = - mlirSparseTensorEncodingAttrGetDimOrdering(self); - if (mlirAffineMapIsNull(ret)) - return {}; - return ret; - }) - .def_property_readonly( - "higher_ordering", - [](MlirAttribute self) -> std::optional { - MlirAffineMap ret = - mlirSparseTensorEncodingAttrGetHigherOrdering(self); + MlirAffineMap ret = mlirSparseTensorEncodingAttrGetDimToLvl(self); if (mlirAffineMapIsNull(ret)) return {}; return ret; diff --git 
a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index 8569acf43..e18da1027 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -45,26 +45,21 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { return isa(unwrap(attr)); } -MlirAttribute mlirSparseTensorEncodingAttrGet( - MlirContext ctx, intptr_t lvlRank, - MlirSparseTensorDimLevelType const *lvlTypes, MlirAffineMap dimOrdering, - MlirAffineMap higherOrdering, int posWidth, int crdWidth) { +MlirAttribute +mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank, + MlirSparseTensorDimLevelType const *lvlTypes, + MlirAffineMap dimToLvl, int posWidth, + int crdWidth) { SmallVector cppLvlTypes; cppLvlTypes.reserve(lvlRank); for (intptr_t l = 0; l < lvlRank; ++l) cppLvlTypes.push_back(static_cast(lvlTypes[l])); return wrap(SparseTensorEncodingAttr::get( - unwrap(ctx), cppLvlTypes, unwrap(dimOrdering), unwrap(higherOrdering), - posWidth, crdWidth)); + unwrap(ctx), cppLvlTypes, unwrap(dimToLvl), posWidth, crdWidth)); } -MlirAffineMap mlirSparseTensorEncodingAttrGetDimOrdering(MlirAttribute attr) { - return wrap(cast(unwrap(attr)).getDimOrdering()); -} - -MlirAffineMap -mlirSparseTensorEncodingAttrGetHigherOrdering(MlirAttribute attr) { - return wrap(cast(unwrap(attr)).getHigherOrdering()); +MlirAffineMap mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr) { + return wrap(cast(unwrap(attr)).getDimToLvl()); } intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) { From cadafe60da9ab50daa1133a30a90e58c50cd5281 Mon Sep 17 00:00:00 2001 From: Rafael Ubal Tena Date: Mon, 5 Jun 2023 11:56:44 -0700 Subject: [PATCH 478/915] Activate OpenMP translation in MLIR execution engine CAPI. We've observed that the MLIR Jit Engine fails when the `omp` dialect is used due to a failure to register OpenMP-related translations. This small patch addresses this issue. 
Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D151577 --- mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index 1075ec460..067cf677e 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -13,6 +13,7 @@ #include "mlir/ExecutionEngine/OptUtils.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" +#include "mlir/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.h" #include "llvm/ExecutionEngine/Orc/Mangling.h" #include "llvm/Support/TargetSelect.h" @@ -33,6 +34,7 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, auto &ctx = *unwrap(op)->getContext(); mlir::registerBuiltinDialectTranslation(ctx); mlir::registerLLVMDialectTranslation(ctx); + mlir::registerOpenMPDialectTranslation(ctx); auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); if (!tmBuilderOrError) { From 67a748b04e17a2cf4d9af967d8b27f74e1c71efc Mon Sep 17 00:00:00 2001 From: max Date: Wed, 31 May 2023 15:52:46 -0500 Subject: [PATCH 479/915] [MLIR][python bindings] TypeCasters for Attributes Differential Revision: https://reviews.llvm.org/D151840 --- mlir/include/mlir-c/BuiltinAttributes.h | 47 +++++++++++ mlir/include/mlir-c/IR.h | 3 + .../mlir/Bindings/Python/PybindAdaptors.h | 1 + mlir/lib/Bindings/Python/IRAttributes.cpp | 78 +++++++++++++++++++ mlir/lib/Bindings/Python/IRCore.cpp | 59 ++++++++------ mlir/lib/Bindings/Python/IRModule.h | 30 +++++++ mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 67 ++++++++++++++++ mlir/lib/CAPI/IR/IR.cpp | 4 + 8 files changed, 265 insertions(+), 24 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 2e6287939..b760dd0cd 100644 
--- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -45,6 +45,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirAffineMapAttrGet(MlirAffineMap map); /// Returns the affine map wrapped in the given affine map attribute. MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr); +/// Returns the typeID of an AffineMap attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirAffineMapAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Array attribute. //===----------------------------------------------------------------------===// @@ -64,6 +67,9 @@ MLIR_CAPI_EXPORTED intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos); +/// Returns the typeID of an Array attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirArrayAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Dictionary attribute. //===----------------------------------------------------------------------===// @@ -89,6 +95,9 @@ mlirDictionaryAttrGetElement(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, MlirStringRef name); +/// Returns the typeID of a Dictionary attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirDictionaryAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Floating point attribute. //===----------------------------------------------------------------------===// @@ -115,6 +124,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirFloatAttrDoubleGetChecked(MlirLocation loc, /// the value as double. MLIR_CAPI_EXPORTED double mlirFloatAttrGetValueDouble(MlirAttribute attr); +/// Returns the typeID of a Float attribute. 
+MLIR_CAPI_EXPORTED MlirTypeID mlirFloatAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Integer attribute. //===----------------------------------------------------------------------===// @@ -142,6 +154,9 @@ MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr); /// is of unsigned type and fits into an unsigned 64-bit integer. MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr); +/// Returns the typeID of an Integer attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Bool attribute. //===----------------------------------------------------------------------===// @@ -162,6 +177,9 @@ MLIR_CAPI_EXPORTED bool mlirBoolAttrGetValue(MlirAttribute attr); /// Checks whether the given attribute is an integer set attribute. MLIR_CAPI_EXPORTED bool mlirAttributeIsAIntegerSet(MlirAttribute attr); +/// Returns the typeID of an IntegerSet attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerSetAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Opaque attribute. //===----------------------------------------------------------------------===// @@ -185,6 +203,9 @@ mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr); /// the context in which the attribute lives. MLIR_CAPI_EXPORTED MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr); +/// Returns the typeID of an Opaque attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirOpaqueAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // String attribute. //===----------------------------------------------------------------------===// @@ -206,6 +227,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirStringAttrTypedGet(MlirType type, /// long as the context in which the attribute lives. 
MLIR_CAPI_EXPORTED MlirStringRef mlirStringAttrGetValue(MlirAttribute attr); +/// Returns the typeID of a String attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirStringAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // SymbolRef attribute. //===----------------------------------------------------------------------===// @@ -239,6 +263,9 @@ mlirSymbolRefAttrGetNumNestedReferences(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, intptr_t pos); +/// Returns the typeID of an SymbolRef attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirSymbolRefAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Flat SymbolRef attribute. //===----------------------------------------------------------------------===// @@ -256,6 +283,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MLIR_CAPI_EXPORTED MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr); +/// Returns the typeID of an FlatSymbolRef attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Type attribute. //===----------------------------------------------------------------------===// @@ -270,6 +300,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirTypeAttrGet(MlirType type); /// Returns the type stored in the given type attribute. MLIR_CAPI_EXPORTED MlirType mlirTypeAttrGetValue(MlirAttribute attr); +/// Returns the typeID of a Type attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirTypeAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Unit attribute. //===----------------------------------------------------------------------===// @@ -280,6 +313,9 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsAUnit(MlirAttribute attr); /// Creates a unit attribute in the given context. 
MLIR_CAPI_EXPORTED MlirAttribute mlirUnitAttrGet(MlirContext ctx); +/// Returns the typeID of a Unit attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirUnitAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Elements attributes. //===----------------------------------------------------------------------===// @@ -306,6 +342,8 @@ MLIR_CAPI_EXPORTED int64_t mlirElementsAttrGetNumElements(MlirAttribute attr); // Dense array attribute. //===----------------------------------------------------------------------===// +MLIR_CAPI_EXPORTED MlirTypeID mlirDenseArrayAttrGetTypeID(void); + /// Checks whether the given attribute is a dense array attribute. MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseBoolArray(MlirAttribute attr); MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseI8Array(MlirAttribute attr); @@ -370,6 +408,9 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseElements(MlirAttribute attr); MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseIntElements(MlirAttribute attr); MLIR_CAPI_EXPORTED bool mlirAttributeIsADenseFPElements(MlirAttribute attr); +/// Returns the typeID of an DenseIntOrFPElements attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void); + /// Creates a dense elements attribute with the given Shaped type and elements /// in the same context as the type. MLIR_CAPI_EXPORTED MlirAttribute mlirDenseElementsAttrGet( @@ -612,6 +653,9 @@ mlirSparseElementsAttrGetIndices(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr); +/// Returns the typeID of a SparseElements attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirSparseElementsAttrGetTypeID(void); + //===----------------------------------------------------------------------===// // Strided layout attribute. 
//===----------------------------------------------------------------------===// @@ -635,6 +679,9 @@ mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr); MLIR_CAPI_EXPORTED int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos); +/// Returns the typeID of a StridedLayout attribute. +MLIR_CAPI_EXPORTED MlirTypeID mlirStridedLayoutAttrGetTypeID(void); + #ifdef __cplusplus } #endif diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 8253981b3..6b5d8cc4b 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -860,6 +860,9 @@ MLIR_CAPI_EXPORTED MlirType mlirAttributeGetType(MlirAttribute attribute); /// Gets the type id of the attribute. MLIR_CAPI_EXPORTED MlirTypeID mlirAttributeGetTypeID(MlirAttribute attribute); +/// Gets the dialect of the attribute. +MLIR_CAPI_EXPORTED MlirDialect mlirAttributeGetDialect(MlirAttribute attribute); + /// Checks whether an attribute is null. static inline bool mlirAttributeIsNull(MlirAttribute attr) { return !attr.ptr; } diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 272067a26..44a10d619 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -97,6 +97,7 @@ struct type_caster { return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr("Attribute") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() .release(); } }; diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 3c7926e78..99881b35c 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -80,6 +80,8 @@ class PyAffineMapAttribute : public PyConcreteAttribute { static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; static constexpr const char *pyClassName = "AffineMapAttr"; using PyConcreteAttribute::PyConcreteAttribute; + 
static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirAffineMapAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -259,6 +261,8 @@ class PyArrayAttribute : public PyConcreteAttribute { static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAArray; static constexpr const char *pyClassName = "ArrayAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirArrayAttrGetTypeID; class PyArrayAttributeIterator { public: @@ -339,6 +343,8 @@ class PyFloatAttribute : public PyConcreteAttribute { static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFloat; static constexpr const char *pyClassName = "FloatAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloatAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -406,6 +412,10 @@ class PyIntegerAttribute : public PyConcreteAttribute { return mlirIntegerAttrGetValueUInt(self); }, "Returns the value of the integer attribute"); + c.def_property_readonly_static("static_typeid", + [](py::object & /*class*/) -> MlirTypeID { + return mlirIntegerAttrGetTypeID(); + }); } }; @@ -438,6 +448,8 @@ class PyFlatSymbolRefAttribute static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; static constexpr const char *pyClassName = "FlatSymbolRefAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFlatSymbolRefAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -464,6 +476,8 @@ class PyOpaqueAttribute : public PyConcreteAttribute { static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAOpaque; static constexpr const char *pyClassName = "OpaqueAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirOpaqueAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -501,6 +515,8 @@ 
class PyStringAttribute : public PyConcreteAttribute { static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAString; static constexpr const char *pyClassName = "StringAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirStringAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -921,6 +937,8 @@ class PyDictAttribute : public PyConcreteAttribute { static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; static constexpr const char *pyClassName = "DictAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirDictionaryAttrGetTypeID; intptr_t dunderLen() { return mlirDictionaryAttrGetNumElements(*this); } @@ -1013,6 +1031,8 @@ class PyTypeAttribute : public PyConcreteAttribute { static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAType; static constexpr const char *pyClassName = "TypeAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirTypeAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -1035,6 +1055,8 @@ class PyUnitAttribute : public PyConcreteAttribute { static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAUnit; static constexpr const char *pyClassName = "UnitAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUnitAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -1054,6 +1076,8 @@ class PyStridedLayoutAttribute static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAStridedLayout; static constexpr const char *pyClassName = "StridedLayoutAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirStridedLayoutAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -1099,6 +1123,50 @@ class PyStridedLayoutAttribute } }; +py::object 
denseArrayAttributeCaster(PyAttribute &pyAttribute) { + if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseBoolArrayAttribute(pyAttribute)); + if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseI8ArrayAttribute(pyAttribute)); + if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseI16ArrayAttribute(pyAttribute)); + if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseI32ArrayAttribute(pyAttribute)); + if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseI64ArrayAttribute(pyAttribute)); + if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseF32ArrayAttribute(pyAttribute)); + if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseF64ArrayAttribute(pyAttribute)); + std::string msg = + std::string("Can't cast unknown element type DenseArrayAttr (") + + std::string(py::repr(py::cast(pyAttribute))) + ")"; + throw py::cast_error(msg); +} + +py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { + if (PyDenseFPElementsAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseFPElementsAttribute(pyAttribute)); + if (PyDenseIntElementsAttribute::isaFunction(pyAttribute)) + return py::cast(PyDenseIntElementsAttribute(pyAttribute)); + std::string msg = + std::string( + "Can't cast unknown element type DenseIntOrFPElementsAttr (") + + std::string(py::repr(py::cast(pyAttribute))) + ")"; + throw py::cast_error(msg); +} + +py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { + if (PyBoolAttribute::isaFunction(pyAttribute)) + return py::cast(PyBoolAttribute(pyAttribute)); + if (PyIntegerAttribute::isaFunction(pyAttribute)) + return py::cast(PyIntegerAttribute(pyAttribute)); + std::string msg = + std::string("Can't cast unknown element type DenseArrayAttr (") + + std::string(py::repr(py::cast(pyAttribute))) + ")"; + throw py::cast_error(msg); +} + } 
// namespace void mlir::python::populateIRAttributes(py::module &m) { @@ -1118,6 +1186,9 @@ void mlir::python::populateIRAttributes(py::module &m) { PyDenseF32ArrayAttribute::PyDenseArrayIterator::bind(m); PyDenseF64ArrayAttribute::bind(m); PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m); + PyGlobals::get().registerTypeCaster( + mlirDenseArrayAttrGetTypeID(), + pybind11::cpp_function(denseArrayAttributeCaster)); PyArrayAttribute::bind(m); PyArrayAttribute::PyArrayAttributeIterator::bind(m); @@ -1125,6 +1196,10 @@ void mlir::python::populateIRAttributes(py::module &m) { PyDenseElementsAttribute::bind(m); PyDenseFPElementsAttribute::bind(m); PyDenseIntElementsAttribute::bind(m); + PyGlobals::get().registerTypeCaster( + mlirDenseIntOrFPElementsAttrGetTypeID(), + pybind11::cpp_function(denseIntOrFPElementsAttributeCaster)); + PyDictAttribute::bind(m); PyFlatSymbolRefAttribute::bind(m); PyOpaqueAttribute::bind(m); @@ -1132,6 +1207,9 @@ void mlir::python::populateIRAttributes(py::module &m) { PyIntegerAttribute::bind(m); PyStringAttribute::bind(m); PyTypeAttribute::bind(m); + PyGlobals::get().registerTypeCaster( + mlirIntegerAttrGetTypeID(), + pybind11::cpp_function(integerOrBoolAttributeCaster)); PyUnitAttribute::bind(m); PyStridedLayoutAttribute::bind(m); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index ec9066aa1..facd33c72 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2640,10 +2640,7 @@ void mlir::python::populateIRCore(py::module &m) { "Context that owns the Location") .def_property_readonly( "attr", - [](PyLocation &self) { - return PyAttribute(self.getContext(), - mlirLocationGetAttribute(self)); - }, + [](PyLocation &self) { return mlirLocationGetAttribute(self); }, "Get the underlying LocationAttr") .def( "emit_error", @@ -3139,7 +3136,7 @@ void mlir::python::populateIRCore(py::module &m) { context->get(), toMlirStringRef(attrSpec)); if (mlirAttributeIsNull(type)) 
throw MLIRError("Unable to parse attribute", errors.take()); - return PyAttribute(context->getRef(), type); + return type; }, py::arg("asm"), py::arg("context") = py::none(), "Parses an attribute from an assembly form. Raises an MLIRError on " @@ -3175,18 +3172,38 @@ void mlir::python::populateIRCore(py::module &m) { return printAccum.join(); }, "Returns the assembly form of the Attribute.") - .def("__repr__", [](PyAttribute &self) { - // Generally, assembly formats are not printed for __repr__ because - // this can cause exceptionally long debug output and exceptions. - // However, attribute values are generally considered useful and are - // printed. This may need to be re-evaluated if debug dumps end up - // being excessive. - PyPrintAccumulator printAccum; - printAccum.parts.append("Attribute("); - mlirAttributePrint(self, printAccum.getCallback(), - printAccum.getUserData()); - printAccum.parts.append(")"); - return printAccum.join(); + .def("__repr__", + [](PyAttribute &self) { + // Generally, assembly formats are not printed for __repr__ because + // this can cause exceptionally long debug output and exceptions. + // However, attribute values are generally considered useful and + // are printed. This may need to be re-evaluated if debug dumps end + // up being excessive. 
+ PyPrintAccumulator printAccum; + printAccum.parts.append("Attribute("); + mlirAttributePrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }) + .def_property_readonly( + "typeid", + [](PyAttribute &self) -> MlirTypeID { + MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + return mlirTypeID; + }) + .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) { + MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + std::optional typeCaster = + PyGlobals::get().lookupTypeCaster(mlirTypeID, + mlirAttributeGetDialect(self)); + if (!typeCaster) + return py::cast(self); + return typeCaster.value()(self); }); //---------------------------------------------------------------------------- @@ -3216,13 +3233,7 @@ void mlir::python::populateIRCore(py::module &m) { "The name of the NamedAttribute binding") .def_property_readonly( "attr", - [](PyNamedAttribute &self) { - // TODO: When named attribute is removed/refactored, also remove - // this constructor (it does an inefficient table lookup). 
- auto contextRef = PyMlirContext::forContext( - mlirAttributeGetContext(self.namedAttr.attribute)); - return PyAttribute(std::move(contextRef), self.namedAttr.attribute); - }, + [](PyNamedAttribute &self) { return self.namedAttr.attribute; }, py::keep_alive<0, 1>(), "The underlying generic attribute of the NamedAttribute binding"); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 013bb7b92..225580f0f 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -986,6 +986,8 @@ class PyConcreteAttribute : public BaseTy { // const char *pyClassName using ClassTy = pybind11::class_; using IsAFunctionTy = bool (*)(MlirAttribute); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; PyConcreteAttribute() = default; PyConcreteAttribute(PyMlirContextRef contextRef, MlirAttribute attr) @@ -1017,6 +1019,34 @@ class PyConcreteAttribute : public BaseTy { pybind11::arg("other")); cls.def_property_readonly( "type", [](PyAttribute &attr) { return mlirAttributeGetType(attr); }); + cls.def_property_readonly_static( + "static_typeid", [](py::object & /*class*/) -> MlirTypeID { + if (DerivedTy::getTypeIdFunction) + return DerivedTy::getTypeIdFunction(); + throw py::attribute_error( + (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str()); + }); + cls.def_property_readonly("typeid", [](PyAttribute &self) { + return py::cast(self).attr("typeid").cast(); + }); + cls.def("__repr__", [](DerivedTy &self) { + PyPrintAccumulator printAccum; + printAccum.parts.append(DerivedTy::pyClassName); + printAccum.parts.append("("); + mlirAttributePrint(self, printAccum.getCallback(), + printAccum.getUserData()); + printAccum.parts.append(")"); + return printAccum.join(); + }); + + if (DerivedTy::getTypeIdFunction) { + PyGlobals::get().registerTypeCaster( + DerivedTy::getTypeIdFunction(), + pybind11::cpp_function([](PyAttribute pyAttribute) -> DerivedTy 
{ + return pyAttribute; + })); + } + DerivedTy::bindDerived(cls); } diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index f2441e0b0..289913d4f 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -44,6 +44,10 @@ MlirAffineMap mlirAffineMapAttrGetValue(MlirAttribute attr) { return wrap(llvm::cast(unwrap(attr)).getValue()); } +MlirTypeID mlirAffineMapAttrGetTypeID(void) { + return wrap(AffineMapAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Array attribute. //===----------------------------------------------------------------------===// @@ -68,6 +72,8 @@ MlirAttribute mlirArrayAttrGetElement(MlirAttribute attr, intptr_t pos) { return wrap(llvm::cast(unwrap(attr)).getValue()[pos]); } +MlirTypeID mlirArrayAttrGetTypeID(void) { return wrap(ArrayAttr::getTypeID()); } + //===----------------------------------------------------------------------===// // Dictionary attribute. //===----------------------------------------------------------------------===// @@ -102,6 +108,10 @@ MlirAttribute mlirDictionaryAttrGetElementByName(MlirAttribute attr, return wrap(llvm::cast(unwrap(attr)).get(unwrap(name))); } +MlirTypeID mlirDictionaryAttrGetTypeID(void) { + return wrap(DictionaryAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Floating point attribute. //===----------------------------------------------------------------------===// @@ -124,6 +134,8 @@ double mlirFloatAttrGetValueDouble(MlirAttribute attr) { return llvm::cast(unwrap(attr)).getValueAsDouble(); } +MlirTypeID mlirFloatAttrGetTypeID(void) { return wrap(FloatAttr::getTypeID()); } + //===----------------------------------------------------------------------===// // Integer attribute. 
//===----------------------------------------------------------------------===// @@ -148,6 +160,10 @@ uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) { return llvm::cast(unwrap(attr)).getUInt(); } +MlirTypeID mlirIntegerAttrGetTypeID(void) { + return wrap(IntegerAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Bool attribute. //===----------------------------------------------------------------------===// @@ -172,6 +188,10 @@ bool mlirAttributeIsAIntegerSet(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } +MlirTypeID mlirIntegerSetAttrGetTypeID(void) { + return wrap(IntegerSetAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Opaque attribute. //===----------------------------------------------------------------------===// @@ -197,6 +217,10 @@ MlirStringRef mlirOpaqueAttrGetData(MlirAttribute attr) { return wrap(llvm::cast(unwrap(attr)).getAttrData()); } +MlirTypeID mlirOpaqueAttrGetTypeID(void) { + return wrap(OpaqueAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // String attribute. //===----------------------------------------------------------------------===// @@ -217,6 +241,10 @@ MlirStringRef mlirStringAttrGetValue(MlirAttribute attr) { return wrap(llvm::cast(unwrap(attr)).getValue()); } +MlirTypeID mlirStringAttrGetTypeID(void) { + return wrap(StringAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // SymbolRef attribute. 
//===----------------------------------------------------------------------===// @@ -257,6 +285,10 @@ MlirAttribute mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, llvm::cast(unwrap(attr)).getNestedReferences()[pos]); } +MlirTypeID mlirSymbolRefAttrGetTypeID(void) { + return wrap(SymbolRefAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Flat SymbolRef attribute. //===----------------------------------------------------------------------===// @@ -273,6 +305,10 @@ MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) { return wrap(llvm::cast(unwrap(attr)).getValue()); } +MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void) { + return wrap(FlatSymbolRefAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Type attribute. //===----------------------------------------------------------------------===// @@ -289,6 +325,8 @@ MlirType mlirTypeAttrGetValue(MlirAttribute attr) { return wrap(llvm::cast(unwrap(attr)).getValue()); } +MlirTypeID mlirTypeAttrGetTypeID(void) { return wrap(TypeAttr::getTypeID()); } + //===----------------------------------------------------------------------===// // Unit attribute. //===----------------------------------------------------------------------===// @@ -301,6 +339,8 @@ MlirAttribute mlirUnitAttrGet(MlirContext ctx) { return wrap(UnitAttr::get(unwrap(ctx))); } +MlirTypeID mlirUnitAttrGetTypeID(void) { return wrap(UnitAttr::getTypeID()); } + //===----------------------------------------------------------------------===// // Elements attributes. //===----------------------------------------------------------------------===// @@ -329,8 +369,13 @@ int64_t mlirElementsAttrGetNumElements(MlirAttribute attr) { // Dense array attribute. 
//===----------------------------------------------------------------------===// +MlirTypeID mlirDenseArrayAttrGetTypeID() { + return wrap(DenseArrayAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // IsA support. +//===----------------------------------------------------------------------===// bool mlirAttributeIsADenseBoolArray(MlirAttribute attr) { return llvm::isa(unwrap(attr)); @@ -356,6 +401,7 @@ bool mlirAttributeIsADenseF64Array(MlirAttribute attr) { //===----------------------------------------------------------------------===// // Constructors. +//===----------------------------------------------------------------------===// MlirAttribute mlirDenseBoolArrayGet(MlirContext ctx, intptr_t size, int const *values) { @@ -395,6 +441,7 @@ MlirAttribute mlirDenseF64ArrayGet(MlirContext ctx, intptr_t size, //===----------------------------------------------------------------------===// // Accessors. +//===----------------------------------------------------------------------===// intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) { return llvm::cast(unwrap(attr)).size(); @@ -402,6 +449,7 @@ intptr_t mlirDenseArrayGetNumElements(MlirAttribute attr) { //===----------------------------------------------------------------------===// // Indexed accessors. +//===----------------------------------------------------------------------===// bool mlirDenseBoolArrayGetElement(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr))[pos]; @@ -431,19 +479,27 @@ double mlirDenseF64ArrayGetElement(MlirAttribute attr, intptr_t pos) { //===----------------------------------------------------------------------===// // IsA support. 
+//===----------------------------------------------------------------------===// bool mlirAttributeIsADenseElements(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } + bool mlirAttributeIsADenseIntElements(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } + bool mlirAttributeIsADenseFPElements(MlirAttribute attr) { return llvm::isa(unwrap(attr)); } +MlirTypeID mlirDenseIntOrFPElementsAttrGetTypeID(void) { + return wrap(DenseIntOrFPElementsAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Constructors. +//===----------------------------------------------------------------------===// MlirAttribute mlirDenseElementsAttrGet(MlirType shapedType, intptr_t numElements, @@ -620,6 +676,7 @@ MlirAttribute mlirDenseElementsAttrReshapeGet(MlirAttribute attr, //===----------------------------------------------------------------------===// // Splat accessors. +//===----------------------------------------------------------------------===// bool mlirDenseElementsAttrIsSplat(MlirAttribute attr) { return llvm::cast(unwrap(attr)).isSplat(); @@ -663,6 +720,7 @@ MlirStringRef mlirDenseElementsAttrGetStringSplatValue(MlirAttribute attr) { //===----------------------------------------------------------------------===// // Indexed accessors. +//===----------------------------------------------------------------------===// bool mlirDenseElementsAttrGetBoolValue(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; @@ -705,6 +763,7 @@ MlirStringRef mlirDenseElementsAttrGetStringValue(MlirAttribute attr, //===----------------------------------------------------------------------===// // Raw data accessors. 
+//===----------------------------------------------------------------------===// const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) { return static_cast( @@ -876,6 +935,10 @@ MlirAttribute mlirSparseElementsAttrGetValues(MlirAttribute attr) { return wrap(llvm::cast(unwrap(attr)).getValues()); } +MlirTypeID mlirSparseElementsAttrGetTypeID(void) { + return wrap(SparseElementsAttr::getTypeID()); +} + //===----------------------------------------------------------------------===// // Strided layout attribute. //===----------------------------------------------------------------------===// @@ -903,3 +966,7 @@ intptr_t mlirStridedLayoutAttrGetNumStrides(MlirAttribute attr) { int64_t mlirStridedLayoutAttrGetStride(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getStrides()[pos]; } + +MlirTypeID mlirStridedLayoutAttrGetTypeID(void) { + return wrap(StridedLayoutAttr::getTypeID()); +} diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 373e01a13..16b333afc 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -870,6 +870,10 @@ MlirTypeID mlirAttributeGetTypeID(MlirAttribute attr) { return wrap(unwrap(attr).getTypeID()); } +MlirDialect mlirAttributeGetDialect(MlirAttribute attr) { + return wrap(&unwrap(attr).getDialect()); +} + bool mlirAttributeEqual(MlirAttribute a1, MlirAttribute a2) { return unwrap(a1) == unwrap(a2); } From 49eedcdbcdaaa321a2edc38a8c62b38ac0db795a Mon Sep 17 00:00:00 2001 From: Ashay Rane Date: Fri, 16 Jun 2023 17:11:43 -0500 Subject: [PATCH 480/915] [MLIR] Register all extensions in CAPI's RegisterEverything The patch for promised interfaces (a5ef51d7) doesn't register all extensions in the CAPI's `mlirRegisterAllDialects()` function. This is used by the MLIR Python bindings, causing downstream users of the Python bindings to terminate abruptly. This patch adds the call to register all extensions. 
Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D153174 --- mlir/lib/CAPI/RegisterEverything/CMakeLists.txt | 2 ++ mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp | 2 ++ 2 files changed, 4 insertions(+) diff --git a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt index 55fe49bce..8b9a39558 100644 --- a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt +++ b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt @@ -2,6 +2,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS) get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) add_mlir_upstream_c_api_library(MLIRCAPIRegisterEverything RegisterEverything.cpp @@ -9,6 +10,7 @@ add_mlir_upstream_c_api_library(MLIRCAPIRegisterEverything ${dialect_libs} ${translation_libs} ${conversion_libs} + ${extension_libs} MLIRBuiltinToLLVMIRTranslation MLIRCAPIIR diff --git a/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp index e4a751643..b63899bd5 100644 --- a/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp +++ b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp @@ -10,12 +10,14 @@ #include "mlir/CAPI/IR.h" #include "mlir/InitAllDialects.h" +#include "mlir/InitAllExtensions.h" #include "mlir/InitAllPasses.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" void mlirRegisterAllDialects(MlirDialectRegistry registry) { mlir::registerAllDialects(*unwrap(registry)); + mlir::registerAllExtensions(*unwrap(registry)); } void mlirRegisterAllLLVMTranslations(MlirContext context) { From 9eed0ac067cd0f0ea2ad68aec0fc15128665ac5c Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 21 Jun 2023 09:40:31 -0700 Subject: [PATCH 481/915] Fix a 
memory leak in the Python implementation of bytecode writer The bytecode writer config was heap-allocated, but was never freed, causing ASAN errors. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D153440 --- mlir/lib/Bindings/Python/IRCore.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index facd33c72..da8a58de7 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1171,6 +1171,7 @@ void PyOperationBase::writeBytecode(const py::object &fileObject, mlirBytecodeWriterConfigDesiredEmitVersion(config, *bytecodeVersion); MlirLogicalResult res = mlirOperationWriteBytecodeWithConfig( operation, config, accum.getCallback(), accum.getUserData()); + mlirBytecodeWriterConfigDestroy(config); if (mlirLogicalResultIsFailure(res)) throw py::value_error((Twine("Unable to honor desired bytecode version ") + Twine(*bytecodeVersion)) From a5f0243645be846bb273aa76c82fda739fa864ed Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Mon, 26 Jun 2023 19:43:58 +0000 Subject: [PATCH 482/915] [mlir][linalg] Add missing op to match the generated file D141430 added the generated yaml file for (batch_)?matmul_transpose_b ops, but the source of truth core_named_ops.py was not updated. This change fixes .py file to generate the same result as the yaml file. 
Differential revision: https://reviews.llvm.org/D150059 Authored-by: kon72 --- .../linalg/opdsl/ops/core_named_ops.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index bac22a2e5..263d2109b 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -107,6 +107,22 @@ def quantized_matmul( ) +@linalg_structured_op +def matmul_transpose_b(A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.N, S.K), + C=TensorDef(U, S.M, S.N, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast_signed)): + """Performs a matrix multiplication of two 2D inputs with rhs operand + transposed. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k]) + + @linalg_structured_op def mmt4d( lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), @@ -148,6 +164,23 @@ def batch_matmul( ) +@linalg_structured_op +def batch_matmul_transpose_b(A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.N, S.K), + C=TensorDef(U, Batch, S.M, S.N, output=True)): + """Performs a batched matrix multiplication of two 3D inputs where rhs operand + has its non-batch dimensions transposed. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + domain(D.b, D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.b, D.m, + D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( + U, B[D.b, D.n, D.k]) + + @linalg_structured_op def quantized_batch_matmul( A=TensorDef(T1, Batch, S.M, S.K), From 538ee5aa8b0de064f63eccb7eaae64fae4ae57a1 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Mon, 26 Jun 2023 20:07:51 +0000 Subject: [PATCH 483/915] [mlir][linalg] Add named op for matmul_transpose_a matmul with transposed LHS operand allows better memory access patterns on several architectures including common GPUs. Having a named op for it allows to handle this kind of matmul in a more explicit way. --- .../linalg/opdsl/ops/core_named_ops.py | 32 +++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 263d2109b..4c3e8fb25 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -107,6 +107,22 @@ def quantized_matmul( ) +@linalg_structured_op +def matmul_transpose_a(A=TensorDef(T1, S.K, S.N), + B=TensorDef(T2, S.K, S.M), + C=TensorDef(U, S.M, S.N, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast_signed)): + """Performs a matrix multiplication of two 2D inputs with lhs operand + transposed. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + domain(D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n]) + + @linalg_structured_op def matmul_transpose_b(A=TensorDef(T1, S.M, S.K), B=TensorDef(T2, S.N, S.K), @@ -164,6 +180,22 @@ def batch_matmul( ) +@linalg_structured_op +def batch_matmul_transpose_a(A=TensorDef(T1, Batch, S.K, S.M), + B=TensorDef(T2, Batch, S.K, S.N), + C=TensorDef(U, Batch, S.M, S.N, output=True)): + """Performs a batched matrix multiplication of two 3D inputs where lhs operand + has its non-batch dimensions transposed. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.b, D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) \ + * TypeFn.cast_signed(U, B[D.b, D.k, D.n]) + + @linalg_structured_op def batch_matmul_transpose_b(A=TensorDef(T1, Batch, S.M, S.K), B=TensorDef(T2, Batch, S.N, S.K), From 6360aa308b5409d5082aab5823d63a920eb1e376 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 4 Jul 2023 14:18:32 +0200 Subject: [PATCH 484/915] [mlir][linalg] Return tensor::PadOp handle from transform op "transform.structured.pad" now returns all `tensor::PadOp` in addition to the padded ops. Also add a test case that shows how to force an allocation for "tensor.pad" ops with a custom memory space. 
Differential Revision: https://reviews.llvm.org/D153555 --- .../dialects/_structured_transform_ops_ext.py | 75 ++++++++++--------- 1 file changed, 38 insertions(+), 37 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 30dafff6a..47c1bbb31 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -130,43 +130,44 @@ def __init__( class PadOp: - """Specialization for PadOp class.""" - - def __init__( - self, - target: Union[Operation, Value], - *, - padding_values: Optional[ - Optional[Union[ArrayAttr, Sequence[Attribute]]] - ] = None, - padding_dimensions: OptionalIntList = None, - pack_paddings: OptionalIntList = None, - transpose_paddings: Optional[ - Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]] - ] = None, - loc=None, - ip=None, - ): - if transpose_paddings is None: - transpose_paddings = [] - if pack_paddings is None: - pack_paddings = [] - if padding_dimensions is None: - padding_dimensions = [] - if padding_values is None: - padding_values = [] - pdl_operation_type = pdl.OperationType.get() - transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings) - super().__init__( - pdl_operation_type, - _get_op_result_or_value(target), - padding_values=padding_values, - padding_dimensions=padding_dimensions, - pack_paddings=pack_paddings, - transpose_paddings=transpose_paddings_attr, - loc=loc, - ip=ip, - ) + """Specialization for PadOp class.""" + + def __init__( + self, + target: Union[Operation, Value], + *, + padding_values: Optional[ + Optional[Union[ArrayAttr, Sequence[Attribute]]] + ] = None, + padding_dimensions: OptionalIntList = None, + pack_paddings: OptionalIntList = None, + transpose_paddings: Optional[ + Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]] + ] = None, + loc=None, + ip=None, + ): + if transpose_paddings is None: + 
transpose_paddings = [] + if pack_paddings is None: + pack_paddings = [] + if padding_dimensions is None: + padding_dimensions = [] + if padding_values is None: + padding_values = [] + pdl_operation_type = pdl.OperationType.get() + transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings) + super().__init__( + pdl_operation_type, + pdl_operation_type, + _get_op_result_or_value(target), + padding_values=padding_values, + padding_dimensions=padding_dimensions, + pack_paddings=pack_paddings, + transpose_paddings=transpose_paddings_attr, + loc=loc, + ip=ip, + ) class ScalarizeOp: From 1dbfafafe6258c534ceafc12669394fb00370dbf Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 4 Jul 2023 14:56:40 +0200 Subject: [PATCH 485/915] [mlir][transform] Improve transform.get_closest_isolated_parent * Rename op to `transform.get_parent_op` * Match parents by "is isolated from above" and/or op name, or just the direct parent. * Deduplication of result payload ops is optional. Differential Revision: https://reviews.llvm.org/D154071 --- .../mlir/dialects/_transform_ops_ext.py | 43 +++++++++++++++---- 1 file changed, 34 insertions(+), 9 deletions(-) diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py index 425ec6585..87f8d398c 100644 --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -15,17 +15,42 @@ class CastOp: - def __init__( - self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None - ): - super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip) + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + loc=None, + ip=None, + ): + super().__init__( + result_type, _get_op_result_or_value(target), loc=loc, ip=ip + ) -class GetClosestIsolatedParentOp: - def __init__( - self, result_type: Type, target: Union[Operation, Value], *, loc=None, ip=None - ): - 
super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip) + +class GetParentOp: + + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + isolated_from_above: bool = False, + op_name: Optional[str] = None, + deduplicate: bool = False, + loc=None, + ip=None, + ): + super().__init__( + result_type, + _get_op_result_or_value(target), + isolated_from_above=isolated_from_above, + op_name=op_name, + deduplicate=deduplicate, + loc=loc, + ip=ip, + ) class MergeHandlesOp: From 1fc60a4d80a2ef66d7038830dd8d3a391fe0d871 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Sun, 2 Jul 2023 14:43:14 +0100 Subject: [PATCH 486/915] [mlir][transform] Allow arbitrary indices to be scalable This change lifts the limitation that only the trailing dimensions/sizes in dynamic index lists can be scalable. It allows us to extend `MaskedVectorizeOp` and `TileOp` from the Transform dialect so that the following is allowed: %1, %loops:3 = transform.structured.tile %0 [4, [4], [4]] This is also a follow up for https://reviews.llvm.org/D153372 that will enable the following (middle vector dimension is scalable): transform.structured.masked_vectorize %0 vector_sizes [2, [4], 8] To facilitate this change, the hooks for parsing and printing dynamic index lists are updated accordingly (`printDynamicIndexList` and `parseDynamicIndexList`, respectively). `MaskedVectorizeOp` and `TileOp` are updated to include an array of attribute of bools that captures whether the corresponding vector dimension/tile size, respectively, are scalable or not. NOTE 1: I am re-landing this after the initial version was reverted. To fix the regression and in addition to the original patch, this revision updates the Python bindings for the transform dialect NOTE 2: This change is a part of a larger effort to enable scalable vectorisation in Linalg. 
See this RFC for more context: * https://discourse.llvm.org/t/rfc-scalable-vectorisation-in-linalg/ This relands 048764f23a380fd6f8cc562a0008dcc6095fb594 with fixes. Differential Revision: https://reviews.llvm.org/D154336 --- .../mlir/dialects/_structured_transform_ops_ext.py | 9 +++++++++ mlir/python/mlir/ir.py | 4 ++++ 2 files changed, 13 insertions(+) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 47c1bbb31..190b3bc91 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -14,6 +14,9 @@ IntOrAttrList = Sequence[Union[IntegerAttr, int]] OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]] +BoolOrAttrList = Sequence[Union[BoolAttr, bool]] +OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]] + def _get_int_int_array_attr( values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]] @@ -226,6 +229,7 @@ def __init__( Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr] ] = None, interchange: OptionalIntList = None, + scalable_sizes: OptionalBoolList = None, loc=None, ip=None, ): @@ -240,6 +244,7 @@ def __init__( Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr] ] = None, interchange: OptionalIntList = None, + scalable_sizes: OptionalBoolList = None, loc=None, ip=None, ): @@ -254,6 +259,7 @@ def __init__( Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr] ] = None, interchange: OptionalIntList = None, + scalable_sizes: OptionalBoolList = None, loc=None, ip=None, ): @@ -261,6 +267,8 @@ def __init__( interchange = [] if sizes is None: sizes = [] + if scalable_sizes is None: + scalable_sizes = [] static_sizes = [] dynamic_sizes = [] @@ -298,6 +306,7 @@ def __init__( dynamic_sizes=dynamic_sizes, static_sizes=sizes_attr, interchange=interchange, + scalable_sizes=scalable_sizes, loc=loc, ip=ip, ) diff --git 
a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 10a0f5bd2..76077acb6 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -105,6 +105,10 @@ def _f64ArrayAttr(x, context): def _denseI64ArrayAttr(x, context): return DenseI64ArrayAttr.get(x, context=context) +@register_attribute_builder("DenseBoolArrayAttr") +def _denseBoolArrayAttr(x, context): + return DenseBoolArrayAttr.get(x, context=context) + @register_attribute_builder("TypeAttr") def _typeAttr(x, context): From 59e7efef6db8bda8a80340ccc884ea759da7a003 Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Wed, 5 Jul 2023 12:31:37 +0100 Subject: [PATCH 487/915] [MLIR][Linalg] Named op 'add' element-wise This adds the first strict element-wise named op to Linalg. The semantics here is to not allow auto-cast, broadcast semantics and to restrict the operations only to identical types. The remaining semantics must come in the form of surrounding operations on operands, to avoid ambiguity. Examples: ``` // Cast int-to-fp %0 = linalg.copy ins(%in: tensor<32x32xi32>) outs(%out: tensor<32x32xf32>) %1 = linalg.add ins(%arg, %0: tensor<32x32xf32>, tensor<32x32xf32>) outs(%0: tensor<32x32xf32>) // This can be lowered to %1 = linalg.generic {...} ins(%arg, %in: tensor<32x32xf32>, tensor<32x32xi32>) outs(%0: tensor<32x32xf32>) { ^bb0(%a: f32, %i: i32, %out: f32): %f = arith.uitofp %i : f32 %0 = arith.addf %a, %f : f32 linalg.yield %0 : f32 } // Broadcast %0 = linalg.broadcast ins(%in: tensor<32xf32>) init(%out: tensor<32x32xf32>) %1 = linalg.add ins(%arg, %0: tensor<32x32xf32>, tensor<32x32xf32>) outs(%0: tensor<32x32xf32>) // This can be lowered to #bcast_map = affine_map<(d0, d1) -> (d0)> %1 = linalg.generic {... 
#bcast_map] } ins(%arg, %in: tensor<32x32xf32>, tensor<32xf32>) outs(%0: tensor<32x32xf32>) { ^bb0(%a: f32, %b: f32, %out: f32): %0 = arith.addf %a, %b : f32 linalg.yield %0 : f32 } ``` Once this gets accepted, other arithmetic and maths operations will be added accordingly, with the same semantics. Differential Revision: https://reviews.llvm.org/D154500 --- .../linalg/opdsl/ops/core_named_ops.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 4c3e8fb25..063165faf 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -51,6 +51,25 @@ def elemwise_binary( O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None])) +@linalg_structured_op +def add( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """ Adds two tensors elementwise. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.add` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = lhs[None] + rhs[None] + + @linalg_structured_op def matmul( A=TensorDef(T1, S.M, S.K), From 74f502bf7af4f19818e940a5aed4eb0064c93a64 Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Wed, 5 Jul 2023 17:13:42 +0100 Subject: [PATCH 488/915] [MLIR][Linalg] Add more arith named ops to linalg Following up the 'add' named op, here are the remaining basic arithmetic and maths, including a 'div_unsigned' for integer unsigned values. 
In the same pattern as 'matmul_unsigned', the simply named 'div' assumes signed values and the '_unsigned' variation handles the unsigned values. It's a bit odd, but there doesn't seem to be a easy way to restrict to specific types to make 'div_unsigned' only work with integers in the structured ops framework. Same as 'add', these have strict semantics regarding casts. Unary math ops will need some massaging, so I split these ones for now as I continue working on them. Differential Revision: https://reviews.llvm.org/D154524 --- .../linalg/opdsl/ops/core_named_ops.py | 59 ++++++++++++++++++- 1 file changed, 58 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 063165faf..5c591085e 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -57,7 +57,7 @@ def add( rhs=TensorDef(T1), O=TensorDef(T1, output=True), ): - """ Adds two tensors elementwise. + """Adds two tensors elementwise. The shapes and element types must be identical. The appropriate casts, broadcasts and reductions should be done previously to calling this op. @@ -70,6 +70,63 @@ def add( O[None] = lhs[None] + rhs[None] +@linalg_structured_op +def sub( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Subtracts two tensors elementwise. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.sub` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. 
+ """ + O[None] = lhs[None] - rhs[None] + + +@linalg_structured_op +def mul( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Multiplies two tensors elementwise. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.mul` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = lhs[None] * rhs[None] + + +@linalg_structured_op +def div( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Divides the first tensor by the second tensor, elementwise. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.div` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = lhs[None] / rhs[None] + + @linalg_structured_op def matmul( A=TensorDef(T1, S.M, S.K), From a2cf4e09874932c189f63b90f0eb96289f3924be Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Wed, 5 Jul 2023 22:02:23 +0100 Subject: [PATCH 489/915] Revert "[MLIR][Linalg] Add more arith named ops to linalg" This reverts commit 74f502bf7af4f19818e940a5aed4eb0064c93a64. It failed on NVidia, AMD and Windows bots. Investigating. 
--- .../linalg/opdsl/ops/core_named_ops.py | 59 +------------------ 1 file changed, 1 insertion(+), 58 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 5c591085e..063165faf 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -57,7 +57,7 @@ def add( rhs=TensorDef(T1), O=TensorDef(T1, output=True), ): - """Adds two tensors elementwise. + """ Adds two tensors elementwise. The shapes and element types must be identical. The appropriate casts, broadcasts and reductions should be done previously to calling this op. @@ -70,63 +70,6 @@ def add( O[None] = lhs[None] + rhs[None] -@linalg_structured_op -def sub( - lhs=TensorDef(T1), - rhs=TensorDef(T1), - O=TensorDef(T1, output=True), -): - """Subtracts two tensors elementwise. - - The shapes and element types must be identical. The appropriate casts, - broadcasts and reductions should be done previously to calling this op. - - This means reduction/broadcast/element cast semantics is explicit. Further - passes can take that into account when lowering this code. For example, - a `linalg.broadcast` + `linalg.sub` sequence can be lowered to a - `linalg.generic` with different affine maps for the two operands. - """ - O[None] = lhs[None] - rhs[None] - - -@linalg_structured_op -def mul( - lhs=TensorDef(T1), - rhs=TensorDef(T1), - O=TensorDef(T1, output=True), -): - """Multiplies two tensors elementwise. - - The shapes and element types must be identical. The appropriate casts, - broadcasts and reductions should be done previously to calling this op. - - This means reduction/broadcast/element cast semantics is explicit. Further - passes can take that into account when lowering this code. For example, - a `linalg.broadcast` + `linalg.mul` sequence can be lowered to a - `linalg.generic` with different affine maps for the two operands. 
- """ - O[None] = lhs[None] * rhs[None] - - -@linalg_structured_op -def div( - lhs=TensorDef(T1), - rhs=TensorDef(T1), - O=TensorDef(T1, output=True), -): - """Divides the first tensor by the second tensor, elementwise. - - The shapes and element types must be identical. The appropriate casts, - broadcasts and reductions should be done previously to calling this op. - - This means reduction/broadcast/element cast semantics is explicit. Further - passes can take that into account when lowering this code. For example, - a `linalg.broadcast` + `linalg.div` sequence can be lowered to a - `linalg.generic` with different affine maps for the two operands. - """ - O[None] = lhs[None] / rhs[None] - - @linalg_structured_op def matmul( A=TensorDef(T1, S.M, S.K), From 2a193f4d5270ff1faa8a8901c45043c185ebe1c0 Mon Sep 17 00:00:00 2001 From: max Date: Wed, 5 Jul 2023 15:02:59 -0500 Subject: [PATCH 490/915] Add SymbolRefAttr to python bindings Differential Revision: https://reviews.llvm.org/D154541 --- mlir/include/mlir-c/BuiltinAttributes.h | 3 -- mlir/lib/Bindings/Python/IRAttributes.cpp | 64 ++++++++++++++++++++++- mlir/lib/Bindings/Python/IRCore.cpp | 8 +-- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 4 -- mlir/python/mlir/ir.py | 9 ++++ 5 files changed, 75 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index b760dd0cd..631981924 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -283,9 +283,6 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MLIR_CAPI_EXPORTED MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr); -/// Returns the typeID of an FlatSymbolRef attribute. -MLIR_CAPI_EXPORTED MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void); - //===----------------------------------------------------------------------===// // Type attribute. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 99881b35c..4ee06fa7a 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -442,14 +442,59 @@ class PyBoolAttribute : public PyConcreteAttribute { } }; +class PySymbolRefAttribute : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsASymbolRef; + static constexpr const char *pyClassName = "SymbolRefAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static MlirAttribute fromList(const std::vector &symbols, + PyMlirContext &context) { + if (symbols.empty()) + throw std::runtime_error("SymbolRefAttr must be composed of at least " + "one symbol."); + MlirStringRef rootSymbol = toMlirStringRef(symbols[0]); + SmallVector referenceAttrs; + for (size_t i = 1; i < symbols.size(); ++i) { + referenceAttrs.push_back( + mlirFlatSymbolRefAttrGet(context.get(), toMlirStringRef(symbols[i]))); + } + return mlirSymbolRefAttrGet(context.get(), rootSymbol, + referenceAttrs.size(), referenceAttrs.data()); + } + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const std::vector &symbols, + DefaultingPyMlirContext context) { + return PySymbolRefAttribute::fromList(symbols, context.resolve()); + }, + py::arg("symbols"), py::arg("context") = py::none(), + "Gets a uniqued SymbolRef attribute from a list of symbol names"); + c.def_property_readonly( + "value", + [](PySymbolRefAttribute &self) { + std::vector symbols = { + unwrap(mlirSymbolRefAttrGetRootReference(self)).str()}; + for (int i = 0; i < mlirSymbolRefAttrGetNumNestedReferences(self); + ++i) + symbols.push_back( + unwrap(mlirSymbolRefAttrGetRootReference( + mlirSymbolRefAttrGetNestedReference(self, i))) + .str()); + return symbols; + }, + "Returns the value of the SymbolRef attribute as a list[str]"); + } +}; + class 
PyFlatSymbolRefAttribute : public PyConcreteAttribute { public: static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAFlatSymbolRef; static constexpr const char *pyClassName = "FlatSymbolRefAttr"; using PyConcreteAttribute::PyConcreteAttribute; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFlatSymbolRefAttrGetTypeID; static void bindDerived(ClassTy &c) { c.def_static( @@ -1167,6 +1212,16 @@ py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { throw py::cast_error(msg); } +py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { + if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute)) + return py::cast(PyFlatSymbolRefAttribute(pyAttribute)); + if (PySymbolRefAttribute::isaFunction(pyAttribute)) + return py::cast(PySymbolRefAttribute(pyAttribute)); + std::string msg = std::string("Can't cast unknown SymbolRef attribute (") + + std::string(py::repr(py::cast(pyAttribute))) + ")"; + throw py::cast_error(msg); +} + } // namespace void mlir::python::populateIRAttributes(py::module &m) { @@ -1201,6 +1256,11 @@ void mlir::python::populateIRAttributes(py::module &m) { pybind11::cpp_function(denseIntOrFPElementsAttributeCaster)); PyDictAttribute::bind(m); + PySymbolRefAttribute::bind(m); + PyGlobals::get().registerTypeCaster( + mlirSymbolRefAttrGetTypeID(), + pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster)); + PyFlatSymbolRefAttribute::bind(m); PyOpaqueAttribute::bind(m); PyFloatAttribute::bind(m); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index da8a58de7..3ab6d57b4 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3131,13 +3131,13 @@ void mlir::python::populateIRCore(py::module &m) { .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) .def_static( "parse", - [](std::string attrSpec, DefaultingPyMlirContext context) { + [](const std::string &attrSpec, DefaultingPyMlirContext context) { 
PyMlirContext::ErrorCapture errors(context->getRef()); - MlirAttribute type = mlirAttributeParseGet( + MlirAttribute attr = mlirAttributeParseGet( context->get(), toMlirStringRef(attrSpec)); - if (mlirAttributeIsNull(type)) + if (mlirAttributeIsNull(attr)) throw MLIRError("Unable to parse attribute", errors.take()); - return type; + return attr; }, py::arg("asm"), py::arg("context") = py::none(), "Parses an attribute from an assembly form. Raises an MLIRError on " diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 289913d4f..de221ddbf 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -305,10 +305,6 @@ MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) { return wrap(llvm::cast(unwrap(attr)).getValue()); } -MlirTypeID mlirFlatSymbolRefAttrGetTypeID(void) { - return wrap(FlatSymbolRefAttr::getTypeID()); -} - //===----------------------------------------------------------------------===// // Type attribute. 
//===----------------------------------------------------------------------===// diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 76077acb6..e36736f29 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -73,6 +73,14 @@ def _symbolNameAttr(x, context): @register_attribute_builder("SymbolRefAttr") def _symbolRefAttr(x, context): + if isinstance(x, list): + return SymbolRefAttr.get(x, context=context) + else: + return FlatSymbolRefAttr.get(x, context=context) + + +@register_attribute_builder("FlatSymbolRefAttr") +def _flatSymbolRefAttr(x, context): return FlatSymbolRefAttr.get(x, context=context) @@ -105,6 +113,7 @@ def _f64ArrayAttr(x, context): def _denseI64ArrayAttr(x, context): return DenseI64ArrayAttr.get(x, context=context) + @register_attribute_builder("DenseBoolArrayAttr") def _denseBoolArrayAttr(x, context): return DenseBoolArrayAttr.get(x, context=context) From 24e4e9cc62ea2e39964d2ae08fee4bda6add8719 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 6 Jul 2023 12:20:03 +0200 Subject: [PATCH 491/915] [mlir][linalg][transform] Fix Python build This should have been part of D154585. --- .../dialects/_structured_transform_ops_ext.py | 94 ++++++++++--------- 1 file changed, 48 insertions(+), 46 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 190b3bc91..b754034c8 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -217,10 +217,10 @@ def __init__( class TileOp: - """Specialization for TileOp class.""" + """Specialization for TileOp class.""" - @overload - def __init__( + @overload + def __init__( self, loop_types: Union[Type, List[Type]], target: Union[Operation, Value], @@ -233,10 +233,10 @@ def __init__( loc=None, ip=None, ): - ... + ... 
- @overload - def __init__( + @overload + def __init__( self, target: Union[Operation, Value, OpView], *, @@ -248,9 +248,9 @@ def __init__( loc=None, ip=None, ): - ... + ... - def __init__( + def __init__( self, loop_types_or_target: Union[Type, List[Type], Operation, Value], target_or_none: Optional[Union[Operation, Value, OpView]] = None, @@ -263,43 +263,45 @@ def __init__( loc=None, ip=None, ): - if interchange is None: - interchange = [] - if sizes is None: - sizes = [] - if scalable_sizes is None: - scalable_sizes = [] - - static_sizes = [] - dynamic_sizes = [] - if isinstance(sizes, ArrayAttr): - sizes_attr = sizes + if interchange is None: + interchange = [] + if sizes is None: + sizes = [] + + static_sizes = [] + dynamic_sizes = [] + if isinstance(sizes, ArrayAttr): + sizes_attr = sizes + else: + for size in sizes: + if isinstance(size, int): + static_sizes.append(size) else: - for size in sizes: - if isinstance(size, int): - static_sizes.append(size) - else: - static_sizes.append(ShapedType.get_dynamic_size()) - dynamic_sizes.append(_get_op_result_or_value(size)) - sizes_attr = DenseI64ArrayAttr.get(static_sizes) - - num_loops = sum(v if v == 0 else 1 for v in self.__extract_values(sizes_attr)) - - if isinstance(loop_types_or_target, (Operation, Value, OpView)): - loop_types = [transform.AnyOpType.get()] * num_loops - target = loop_types_or_target - assert target_or_none is None, "Cannot construct TileOp with two targets." 
- else: - loop_types = ( - ([loop_types_or_target] * num_loops) - if isinstance(loop_types_or_target, Type) - else loop_types_or_target - ) - target = target_or_none + static_sizes.append(ShapedType.get_dynamic_size()) + dynamic_sizes.append(_get_op_result_or_value(size)) + sizes_attr = DenseI64ArrayAttr.get(static_sizes) - target = _get_op_result_or_value(target) + num_loops = sum( + v if v == 0 else 1 for v in self.__extract_values(sizes_attr) + ) + if scalable_sizes is None: + scalable_sizes = [False] * len(self.__extract_values(sizes_attr)) + + if isinstance(loop_types_or_target, (Operation, Value, OpView)): + loop_types = [transform.AnyOpType.get()] * num_loops + target = loop_types_or_target + assert target_or_none is None, "Cannot construct TileOp with two targets." + else: + loop_types = ( + ([loop_types_or_target] * num_loops) + if isinstance(loop_types_or_target, Type) + else loop_types_or_target + ) + target = target_or_none + + target = _get_op_result_or_value(target) - super().__init__( + super().__init__( target.type, loop_types, target, @@ -311,10 +313,10 @@ def __init__( ip=ip, ) - def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]: - if not attr: - return [] - return [element for element in attr] + def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]: + if not attr: + return [] + return [element for element in attr] class VectorizeOp: From 96a39461a292adcad67f5587083a94ecd66d1e9a Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Thu, 6 Jul 2023 12:03:16 +0100 Subject: [PATCH 492/915] [MLIR][Linalg] Add more arith named ops to linalg (take 2) Re-apply 74f502bf7af4 after implementing __truediv__ for TensorUse. [MLIR][Linalg] Add more arith named ops to linalg Following up the 'add' named op, here are the remaining basic arithmetic and maths, including a 'div_unsigned' for integer unsigned values. 
In the same pattern as 'matmul_unsigned', the simply named 'div' assumes signed values and the '_unsigned' variation handles the unsigned values. It's a bit odd, but there doesn't seem to be a easy way to restrict to specific types to make 'div_unsigned' only work with integers in the structured ops framework. Same as 'add', these have strict semantics regarding casts. Unary math ops will need some massaging, so I split these ones for now as I continue working on them. Differential Revision: https://reviews.llvm.org/D154524 --- .../linalg/opdsl/lang/comprehension.py | 5 ++ .../linalg/opdsl/ops/core_named_ops.py | 79 ++++++++++++++++++- 2 files changed, 83 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 5d5866fde..d9698e8ab 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -85,6 +85,9 @@ def __mul__(self, rhs) -> "TensorExpression": def __sub__(self, rhs) -> "TensorExpression": return BinaryFn.sub(self, rhs) + def __truediv__(self, rhs) -> "TensorExpression": + return BinaryFn.div(self, rhs) + def __hash__(self): return hash(id(self)) @@ -321,6 +324,8 @@ class BinaryFn: add = BinaryFnType("add") sub = BinaryFnType("sub") mul = BinaryFnType("mul") + div = BinaryFnType("div") + div_unsigned = BinaryFnType("div_unsigned") max_signed = BinaryFnType("max_signed") min_signed = BinaryFnType("min_signed") max_unsigned = BinaryFnType("max_unsigned") diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 063165faf..ae40290eb 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -57,7 +57,7 @@ def add( rhs=TensorDef(T1), O=TensorDef(T1, output=True), ): - """ Adds two tensors elementwise. 
+ """Adds two tensors elementwise. The shapes and element types must be identical. The appropriate casts, broadcasts and reductions should be done previously to calling this op. @@ -70,6 +70,83 @@ def add( O[None] = lhs[None] + rhs[None] +@linalg_structured_op +def sub( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Subtracts two tensors elementwise. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.sub` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = lhs[None] - rhs[None] + + +@linalg_structured_op +def mul( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Multiplies two tensors elementwise. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.mul` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = lhs[None] * rhs[None] + + +@linalg_structured_op +def div( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Divides the first tensor by the second tensor, elementwise. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. 
For example, + a `linalg.broadcast` + `linalg.div` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = lhs[None] / rhs[None] + + +@linalg_structured_op +def div_unsigned( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Divides the first tensor by the second tensor, elementwise. For integer + types, performs an unsigned division. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.div` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = lhs[None] / rhs[None] + + @linalg_structured_op def matmul( A=TensorDef(T1, S.M, S.K), From f49af8e260c89781ce727c9be0aa2c1afd8f3513 Mon Sep 17 00:00:00 2001 From: Jeremy Furtek Date: Thu, 6 Jul 2023 08:56:05 -0700 Subject: [PATCH 493/915] [mlir] Add support for TF32 as a Builtin FloatType This diff adds support for TF32 as a Builtin floating point type. This supplements the recent addition of the TF32 semantic to the LLVM APFloat class by extending usage to MLIR. 
https://reviews.llvm.org/D151923 More information on the TF32 type can be found here: https://blogs.nvidia.com/blog/2020/05/14/tensorfloat-32-precision-format/ Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D153705 --- mlir/include/mlir-c/BuiltinTypes.h | 10 ++++++++++ mlir/lib/Bindings/Python/IRTypes.cpp | 21 +++++++++++++++++++++ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 10 ++++++++++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 9 +++++++++ 4 files changed, 50 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index c8ea44cd9..a6d8e10ef 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -163,6 +163,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAF64(MlirType type); /// context. MLIR_CAPI_EXPORTED MlirType mlirF64TypeGet(MlirContext ctx); +/// Returns the typeID of a TF32 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloatTF32TypeGetTypeID(void); + +/// Checks whether the given type is an TF32 type. +MLIR_CAPI_EXPORTED bool mlirTypeIsATF32(MlirType type); + +/// Creates a TF32 type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirTF32TypeGet(MlirContext ctx); + //===----------------------------------------------------------------------===// // None type. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 25307262b..caf215be8 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -247,6 +247,26 @@ class PyF16Type : public PyConcreteType { } }; +/// Floating Point Type subclass - TF32Type. 
+class PyTF32Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloatTF32TypeGetTypeID; + static constexpr const char *pyClassName = "FloatTF32Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirTF32TypeGet(context->get()); + return PyTF32Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a tf32 type."); + } +}; + /// Floating Point Type subclass - F32Type. class PyF32Type : public PyConcreteType { public: @@ -754,6 +774,7 @@ void mlir::python::populateIRTypes(py::module &m) { PyFloat8E5M2FNUZType::bind(m); PyBF16Type::bind(m); PyF16Type::bind(m); + PyTF32Type::bind(m); PyF32Type::bind(m); PyF64Type::bind(m); PyNoneType::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 82c5b5a61..50266b4b5 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -152,6 +152,16 @@ MlirType mlirF16TypeGet(MlirContext ctx) { return wrap(FloatType::getF16(unwrap(ctx))); } +MlirTypeID mlirFloatTF32TypeGetTypeID() { + return wrap(FloatTF32Type::getTypeID()); +} + +bool mlirTypeIsATF32(MlirType type) { return unwrap(type).isTF32(); } + +MlirType mlirTF32TypeGet(MlirContext ctx) { + return wrap(FloatType::getTF32(unwrap(ctx))); +} + MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); } bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 714935fe1..23f4687d0 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -56,6 +56,7 @@ __all__ = [ "Float8E4M3B11FNUZType", "Float8E5M2FNUZType", "F16Type", + "FloatTF32Type", "F32Type", "F64Type", "FlatSymbolRefAttr", @@ 
-627,6 +628,14 @@ class F16Type(Type): @staticmethod def isinstance(arg: Any) -> bool: ... +# TODO: Auto-generated. Audit and fix. +class FloatTF32Type(Type): + def __init__(self, cast_from_type: Type) -> None: ... + @staticmethod + def get(*args, **kwargs) -> FloatTF32Type: ... + @staticmethod + def isinstance(arg: Any) -> bool: ... + # TODO: Auto-generated. Audit and fix. class F32Type(Type): def __init__(self, cast_from_type: Type) -> None: ... From f75ccf8dcb32901b375ee4127f5997ea777a732f Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Thu, 6 Jul 2023 14:56:44 +0100 Subject: [PATCH 494/915] [MLIR][Linalg] Add unary named ops to linalg Following binary arithmetic in previous commits, this patch adds unary maths ops to linalg. It also fixes a few of the previous tests, and makes the binary ops call BinaryFn. directly instead of relying on Python to recognise the operation. Differential Revision: https://reviews.llvm.org/D154618 --- .../linalg/opdsl/ops/core_named_ops.py | 80 ++++++++++++++++++- 1 file changed, 76 insertions(+), 4 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index ae40290eb..9cc252eb7 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -35,6 +35,78 @@ def elemwise_unary( O[None] = fun(cast(U, I[None])) +@linalg_structured_op +def exp( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies exp(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.exp(I[None]) + + +@linalg_structured_op +def log( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies log(x) elementwise. + + No numeric casting is performed on the input operand. 
+ """ + O[None] = UnaryFn.log(I[None]) + + +@linalg_structured_op +def abs( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies abs(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.abs(I[None]) + + +@linalg_structured_op +def ceil( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies ceil(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.ceil(I[None]) + + +@linalg_structured_op +def floor( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies floor(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.floor(I[None]) + + +@linalg_structured_op +def negf( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies negf(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.negf(I[None]) + + @linalg_structured_op def elemwise_binary( lhs=TensorDef(T1), @@ -67,7 +139,7 @@ def add( a `linalg.broadcast` + `linalg.add` sequence can be lowered to a `linalg.generic` with different affine maps for the two operands. """ - O[None] = lhs[None] + rhs[None] + O[None] = BinaryFn.add(lhs[None], rhs[None]) @linalg_structured_op @@ -86,7 +158,7 @@ def sub( a `linalg.broadcast` + `linalg.sub` sequence can be lowered to a `linalg.generic` with different affine maps for the two operands. """ - O[None] = lhs[None] - rhs[None] + O[None] = BinaryFn.sub(lhs[None], rhs[None]) @linalg_structured_op @@ -105,7 +177,7 @@ def mul( a `linalg.broadcast` + `linalg.mul` sequence can be lowered to a `linalg.generic` with different affine maps for the two operands. """ - O[None] = lhs[None] * rhs[None] + O[None] = BinaryFn.mul(lhs[None], rhs[None]) @linalg_structured_op @@ -124,7 +196,7 @@ def div( a `linalg.broadcast` + `linalg.div` sequence can be lowered to a `linalg.generic` with different affine maps for the two operands. 
""" - O[None] = lhs[None] / rhs[None] + O[None] = BinaryFn.div(lhs[None], rhs[None]) @linalg_structured_op From 25ed3be581af29a4605a394252787980db00c7d5 Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Fri, 7 Jul 2023 12:04:30 +0100 Subject: [PATCH 495/915] [MLIR][Linalg] Add max named op to linalg I've been trying to come up with a simple and clean implementation for ReLU. TOSA uses `clamp` which is probably the goal, but that means table-gen to make it efficient (attributes, only lower `min` or `max`). For now, `max` is a reasonable named op despite ReLU, so we can start using it for tiling and fusion, and upon success, we create a more complete op `clamp` that doesn't need a whole tensor filled with zeroes or ones to implement the different activation functions. As with other named ops, we start "requiring" type casts and broadcasts, and zero filled constant tensors to a more complex pattern-matcher, and can slowly simplify with attributes or structured matchers (ex. PDL) in the future. Differential Revision: https://reviews.llvm.org/D154703 --- .../linalg/opdsl/ops/core_named_ops.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 9cc252eb7..e4512cd1e 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -219,6 +219,25 @@ def div_unsigned( O[None] = lhs[None] / rhs[None] +@linalg_structured_op +def max( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Takes the max (signed) between two inputs, elementwise. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. 
For example, + a `linalg.broadcast` + `linalg.div` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = BinaryFn.max_signed(lhs[None], rhs[None]) + + @linalg_structured_op def matmul( A=TensorDef(T1, S.M, S.K), From bde58286419bfd03dbeb09aee7e8c5298c3f8e75 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Wed, 21 Jun 2023 18:40:51 +0000 Subject: [PATCH 496/915] [mlir][CAPI] Expose the rest of MLIRContext's constructors It's recommended practice that people calling MLIR in a loop pre-create a LLVM ThreadPool and a dialect registry and then explicitly pass those into a MLIRContext for each compilation. However, the C API does not expose the functions needed to follow this recommendation from a project that isn't calling MLIR's C++ dilectly. Add the necessary APIs to mlir-c, including a wrapper around LLVM's ThreadPool struct (so as to avoid having to amend or re-export parts of the LLVM API). Reviewed By: makslevental Differential Revision: https://reviews.llvm.org/D153593 --- mlir/include/mlir-c/IR.h | 18 ++++++++++++++++++ mlir/include/mlir-c/Support.h | 13 +++++++++++++ mlir/include/mlir/CAPI/Support.h | 5 +++++ mlir/lib/CAPI/IR/IR.cpp | 22 ++++++++++++++++++++++ mlir/lib/CAPI/IR/Support.cpp | 12 ++++++++++++ 5 files changed, 70 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 6b5d8cc4b..26f7f0738 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -84,8 +84,19 @@ typedef struct MlirNamedAttribute MlirNamedAttribute; //===----------------------------------------------------------------------===// /// Creates an MLIR context and transfers its ownership to the caller. +/// This sets the default multithreading option (enabled). MLIR_CAPI_EXPORTED MlirContext mlirContextCreate(void); +/// Creates an MLIR context with an explicit setting of the multithreading +/// setting and transfers its ownership to the caller. 
+MLIR_CAPI_EXPORTED MlirContext +mlirContextCreateWithThreading(bool threadingEnabled); + +/// Creates an MLIR context, setting the multithreading setting explicitly and +/// pre-loading the dialects from the provided DialectRegistry. +MLIR_CAPI_EXPORTED MlirContext mlirContextCreateWithRegistry( + MlirDialectRegistry registry, bool threadingEnabled); + /// Checks if two contexts are equal. MLIR_CAPI_EXPORTED bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2); @@ -144,6 +155,13 @@ mlirContextLoadAllAvailableDialects(MlirContext context); MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, MlirStringRef name); +/// Sets the thread pool of the context explicitly, enabling multithreading in +/// the process. This API should be used to avoid re-creating thread pools in +/// long-running applications that perform multiple compilations, see +/// the C++ documentation for MLIRContext for details. +MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context, + MlirLlvmThreadPool threadPool); + //===----------------------------------------------------------------------===// // Dialect API. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir-c/Support.h b/mlir/include/mlir-c/Support.h index 8d0188e31..78fc94f93 100644 --- a/mlir/include/mlir-c/Support.h +++ b/mlir/include/mlir-c/Support.h @@ -56,6 +56,8 @@ extern "C" { }; \ typedef struct name name +/// Re-export llvm::ThreadPool so as to avoid including the LLVM C API directly. +DEFINE_C_API_STRUCT(MlirLlvmThreadPool, void); DEFINE_C_API_STRUCT(MlirTypeID, const void); DEFINE_C_API_STRUCT(MlirTypeIDAllocator, void); @@ -138,6 +140,17 @@ inline static MlirLogicalResult mlirLogicalResultFailure(void) { return res; } +//===----------------------------------------------------------------------===// +// MlirLlvmThreadPool. 
+//===----------------------------------------------------------------------===// + +/// Create an LLVM thread pool. This is reexported here to avoid directly +/// pulling in the LLVM headers directly. +MLIR_CAPI_EXPORTED MlirLlvmThreadPool mlirLlvmThreadPoolCreate(void); + +/// Destroy an LLVM thread pool. +MLIR_CAPI_EXPORTED void mlirLlvmThreadPoolDestroy(MlirLlvmThreadPool pool); + //===----------------------------------------------------------------------===// // TypeID API. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/CAPI/Support.h b/mlir/include/mlir/CAPI/Support.h index f3e8a67e0..82aa05185 100644 --- a/mlir/include/mlir/CAPI/Support.h +++ b/mlir/include/mlir/CAPI/Support.h @@ -21,6 +21,10 @@ #include "mlir/Support/TypeID.h" #include "llvm/ADT/StringRef.h" +namespace llvm { +class ThreadPool; +} // namespace llvm + /// Converts a StringRef into its MLIR C API equivalent. inline MlirStringRef wrap(llvm::StringRef ref) { return mlirStringRefCreate(ref.data(), ref.size()); @@ -41,6 +45,7 @@ inline mlir::LogicalResult unwrap(MlirLogicalResult res) { return mlir::success(mlirLogicalResultIsSuccess(res)); } +DEFINE_C_API_PTR_METHODS(MlirLlvmThreadPool, llvm::ThreadPool) DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID) DEFINE_C_API_PTR_METHODS(MlirTypeIDAllocator, mlir::TypeIDAllocator) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 16b333afc..8c3ea09e9 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -39,6 +39,23 @@ MlirContext mlirContextCreate() { return wrap(context); } +static inline MLIRContext::Threading toThreadingEnum(bool threadingEnabled) { + return threadingEnabled ? 
MLIRContext::Threading::ENABLED + : MLIRContext::Threading::DISABLED; +} + +MlirContext mlirContextCreateWithThreading(bool threadingEnabled) { + auto *context = new MLIRContext(toThreadingEnum(threadingEnabled)); + return wrap(context); +} + +MlirContext mlirContextCreateWithRegistry(MlirDialectRegistry registry, + bool threadingEnabled) { + auto *context = + new MLIRContext(*unwrap(registry), toThreadingEnum(threadingEnabled)); + return wrap(context); +} + bool mlirContextEqual(MlirContext ctx1, MlirContext ctx2) { return unwrap(ctx1) == unwrap(ctx2); } @@ -84,6 +101,11 @@ void mlirContextLoadAllAvailableDialects(MlirContext context) { unwrap(context)->loadAllAvailableDialects(); } +void mlirContextSetThreadPool(MlirContext context, + MlirLlvmThreadPool threadPool) { + unwrap(context)->setThreadPool(*unwrap(threadPool)); +} + //===----------------------------------------------------------------------===// // Dialect API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/IR/Support.cpp b/mlir/lib/CAPI/IR/Support.cpp index ea081b2e9..81c9fc771 100644 --- a/mlir/lib/CAPI/IR/Support.cpp +++ b/mlir/lib/CAPI/IR/Support.cpp @@ -8,6 +8,7 @@ #include "mlir/CAPI/Support.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/ThreadPool.h" #include @@ -20,6 +21,17 @@ bool mlirStringRefEqual(MlirStringRef string, MlirStringRef other) { llvm::StringRef(other.data, other.length); } +//===----------------------------------------------------------------------===// +// LLVM ThreadPool API. +//===----------------------------------------------------------------------===// +MlirLlvmThreadPool mlirLlvmThreadPoolCreate() { + return wrap(new llvm::ThreadPool()); +} + +void mlirLlvmThreadPoolDestroy(MlirLlvmThreadPool threadPool) { + delete unwrap(threadPool); +} + //===----------------------------------------------------------------------===// // TypeID API. 
//===----------------------------------------------------------------------===// From 512b829a39fc99aaa8d299e36f7d048a2d6e5f89 Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Tue, 4 Jul 2023 16:06:17 -0400 Subject: [PATCH 497/915] [mlir][python] Downcast attributes in more places Update remaining `PyAttribute`-returning APIs to return `MlirAttribute` instead, so that they go through the downcasting mechanism. Reviewed By: makslevental Differential Revision: https://reviews.llvm.org/D154462 --- mlir/lib/Bindings/Python/IRAttributes.cpp | 22 +++++++++------------ mlir/lib/Bindings/Python/IRCore.cpp | 18 ++++++++--------- mlir/lib/Bindings/Python/IRModule.h | 6 +++--- mlir/lib/Bindings/Python/IRTypes.cpp | 24 +++++++++++++---------- 4 files changed, 34 insertions(+), 36 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 4ee06fa7a..84a48a890 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -270,12 +270,11 @@ class PyArrayAttribute : public PyConcreteAttribute { PyArrayAttributeIterator &dunderIter() { return *this; } - PyAttribute dunderNext() { + MlirAttribute dunderNext() { // TODO: Throw is an inefficient way to stop iteration. 
if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) throw py::stop_iteration(); - return PyAttribute(attr.getContext(), - mlirArrayAttrGetElement(attr.get(), nextIndex++)); + return mlirArrayAttrGetElement(attr.get(), nextIndex++); } static void bind(py::module &m) { @@ -290,8 +289,8 @@ class PyArrayAttribute : public PyConcreteAttribute { int nextIndex = 0; }; - PyAttribute getItem(intptr_t i) { - return PyAttribute(getContext(), mlirArrayAttrGetElement(*this, i)); + MlirAttribute getItem(intptr_t i) { + return mlirArrayAttrGetElement(*this, i); } static void bindDerived(ClassTy &c) { @@ -843,13 +842,11 @@ class PyDenseElementsAttribute return mlirDenseElementsAttrIsSplat(self); }) .def("get_splat_value", - [](PyDenseElementsAttribute &self) -> PyAttribute { - if (!mlirDenseElementsAttrIsSplat(self)) { + [](PyDenseElementsAttribute &self) { + if (!mlirDenseElementsAttrIsSplat(self)) throw py::value_error( "get_splat_value called on a non-splat attribute"); - } - return PyAttribute(self.getContext(), - mlirDenseElementsAttrGetSplatValue(self)); + return mlirDenseElementsAttrGetSplatValue(self); }) .def_buffer(&PyDenseElementsAttribute::accessBuffer); } @@ -1018,10 +1015,9 @@ class PyDictAttribute : public PyConcreteAttribute { c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { MlirAttribute attr = mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); - if (mlirAttributeIsNull(attr)) { + if (mlirAttributeIsNull(attr)) throw py::key_error("attempt to access a non-existent attribute"); - } - return PyAttribute(self.getContext(), attr); + return attr; }); c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { if (index < 0 || index >= self.dunderLen()) { diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 3ab6d57b4..6c0b4a060 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1908,19 +1908,17 @@ void PySymbolTable::dunderDel(const 
std::string &name) { erase(py::cast(operation)); } -PyAttribute PySymbolTable::insert(PyOperationBase &symbol) { +MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) { operation->checkValid(); symbol.getOperation().checkValid(); MlirAttribute symbolAttr = mlirOperationGetAttributeByName( symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); if (mlirAttributeIsNull(symbolAttr)) throw py::value_error("Expected operation to have a symbol name."); - return PyAttribute( - symbol.getOperation().getContext(), - mlirSymbolTableInsert(symbolTable, symbol.getOperation().get())); + return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()); } -PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { +MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { // Op must already be a symbol. PyOperation &operation = symbol.getOperation(); operation.checkValid(); @@ -1929,7 +1927,7 @@ PyAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingNameAttr)) throw py::value_error("Expected operation to have a symbol name."); - return PyAttribute(symbol.getOperation().getContext(), existingNameAttr); + return existingNameAttr; } void PySymbolTable::setSymbolName(PyOperationBase &symbol, @@ -1947,7 +1945,7 @@ void PySymbolTable::setSymbolName(PyOperationBase &symbol, mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); } -PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { +MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { PyOperation &operation = symbol.getOperation(); operation.checkValid(); MlirStringRef attrName = mlirSymbolTableGetVisibilityAttributeName(); @@ -1955,7 +1953,7 @@ PyAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingVisAttr)) throw 
py::value_error("Expected operation to have a symbol visibility."); - return PyAttribute(symbol.getOperation().getContext(), existingVisAttr); + return existingVisAttr; } void PySymbolTable::setVisibility(PyOperationBase &symbol, @@ -2287,13 +2285,13 @@ class PyOpAttributeMap { PyOpAttributeMap(PyOperationRef operation) : operation(std::move(operation)) {} - PyAttribute dunderGetItemNamed(const std::string &name) { + MlirAttribute dunderGetItemNamed(const std::string &name) { MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name)); if (mlirAttributeIsNull(attr)) { throw py::key_error("attempt to access a non-existent attribute"); } - return PyAttribute(operation->getContext(), attr); + return attr; } PyNamedAttribute dunderGetItemIndexed(intptr_t index) { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 225580f0f..76acfe5e7 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -1174,14 +1174,14 @@ class PySymbolTable { /// Inserts the given operation into the symbol table. The operation must have /// the symbol trait. - PyAttribute insert(PyOperationBase &symbol); + MlirAttribute insert(PyOperationBase &symbol); /// Gets and sets the name of a symbol op. - static PyAttribute getSymbolName(PyOperationBase &symbol); + static MlirAttribute getSymbolName(PyOperationBase &symbol); static void setSymbolName(PyOperationBase &symbol, const std::string &name); /// Gets and sets the visibility of a symbol op. 
- static PyAttribute getVisibility(PyOperationBase &symbol); + static MlirAttribute getVisibility(PyOperationBase &symbol); static void setVisibility(PyOperationBase &symbol, const std::string &visibility); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index caf215be8..a7ccfbea5 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -505,11 +505,12 @@ class PyRankedTensorType py::arg("encoding") = py::none(), py::arg("loc") = py::none(), "Create a ranked tensor type"); c.def_property_readonly( - "encoding", [](PyRankedTensorType &self) -> std::optional { + "encoding", + [](PyRankedTensorType &self) -> std::optional { MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get()); if (mlirAttributeIsNull(encoding)) return std::nullopt; - return PyAttribute(self.getContext(), encoding); + return encoding; }); } }; @@ -570,9 +571,8 @@ class PyMemRefType : public PyConcreteType { py::arg("loc") = py::none(), "Create a memref type") .def_property_readonly( "layout", - [](PyMemRefType &self) -> PyAttribute { - MlirAttribute layout = mlirMemRefTypeGetLayout(self); - return PyAttribute(self.getContext(), layout); + [](PyMemRefType &self) -> MlirAttribute { + return mlirMemRefTypeGetLayout(self); }, "The layout of the MemRef type.") .def_property_readonly( @@ -584,9 +584,11 @@ class PyMemRefType : public PyConcreteType { "The layout of the MemRef type as an affine map.") .def_property_readonly( "memory_space", - [](PyMemRefType &self) -> PyAttribute { + [](PyMemRefType &self) -> std::optional { MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); - return PyAttribute(self.getContext(), a); + if (mlirAttributeIsNull(a)) + return std::nullopt; + return a; }, "Returns the memory space of the given MemRef type."); } @@ -622,9 +624,11 @@ class PyUnrankedMemRefType py::arg("loc") = py::none(), "Create a unranked memref type") .def_property_readonly( "memory_space", - [](PyUnrankedMemRefType 
&self) -> PyAttribute { - MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); - return PyAttribute(self.getContext(), a); + [](PyUnrankedMemRefType &self) -> std::optional { + MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self); + if (mlirAttributeIsNull(a)) + return std::nullopt; + return a; }, "Returns the memory space of the given Unranked MemRef type."); } From b0af1ebc0feae2d99437493a19d5e4451c0967da Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Tue, 4 Jul 2023 22:05:36 -0400 Subject: [PATCH 498/915] [mlir][python] Replace PythonAttr mappings with downcasting Since op `Attribute`s are automatically downcasted on access, these mappings aren't necessary anymore. Instead we just always generate the getters/setters for attributes even if there isn't a `PythonAttr` mapping. depends on D154462 Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D154468 --- mlir/include/mlir/Bindings/Python/Attributes.td | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/Attributes.td b/mlir/include/mlir/Bindings/Python/Attributes.td index f9a7fa703..c5947c62d 100644 --- a/mlir/include/mlir/Bindings/Python/Attributes.td +++ b/mlir/include/mlir/Bindings/Python/Attributes.td @@ -21,14 +21,4 @@ class PythonAttr { string pythonType = p; } -// Mappings between supported builtin attribtues and Python types. 
-def : PythonAttr<"::mlir::Attribute", "_ods_ir.Attribute">; -def : PythonAttr<"::mlir::BoolAttr", "_ods_ir.BoolAttr">; -def : PythonAttr<"::mlir::IntegerAttr", "_ods_ir.IntegerAttr">; -def : PythonAttr<"::mlir::FloatAttr", "_ods_ir.FloatAttr">; -def : PythonAttr<"::mlir::StringAttr", "_ods_ir.StringAttr">; -def : PythonAttr<"::mlir::DenseElementsAttr", "_ods_ir.DenseElementsAttr">; -def : PythonAttr<"::mlir::DenseIntElementsAttr", "_ods_ir.DenseIntElementsAttr">; -def : PythonAttr<"::mlir::DenseFPElementsAttr", "_ods_ir.DenseFPElementsAttr">; - #endif From 0de9a8292d04b581ad21b9c1362baaa053cf91a0 Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Wed, 12 Jul 2023 21:58:51 +0000 Subject: [PATCH 499/915] [mlir][sparse] introduce new 2:4 block sparsity level type. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D155128 --- mlir/include/mlir-c/Dialect/SparseTensor.h | 27 ++++++++++--------- .../Bindings/Python/DialectSparseTensor.cpp | 1 + 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index 0ad1a315e..b2e4b96c6 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -26,19 +26,20 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor); /// If updating, keep them in sync and update the static_assert in the impl /// file. 
enum MlirSparseTensorDimLevelType { - MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE = 4, // 0b0001_00 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED = 8, // 0b0010_00 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU = 9, // 0b0010_01 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO = 10, // 0b0010_10 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO = 11, // 0b0010_11 - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON = 16, // 0b0100_00 - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU = 17, // 0b0100_01 - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO = 18, // 0b0100_10 - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO = 19, // 0b0100_11 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI = 32, // 0b1000_00 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU = 33, // 0b1000_01 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NO = 34, // 0b1000_10 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU_NO = 35, // 0b1000_11 + MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE = 4, // 0b00001_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED = 8, // 0b00010_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU = 9, // 0b00010_01 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO = 10, // 0b00010_10 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO = 11, // 0b00010_11 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON = 16, // 0b00100_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU = 17, // 0b00100_01 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO = 18, // 0b00100_10 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO = 19, // 0b00100_11 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI = 32, // 0b01000_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU = 33, // 0b01000_01 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NO = 34, // 0b01000_10 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU_NO = 35, // 0b01000_11 + MLIR_SPARSE_TENSOR_DIM_LEVEL_TWO_OUT_OF_FOUR = 64, // 0b10000_00 }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp 
b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 2e8d53545..d03088341 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -19,6 +19,7 @@ using namespace mlir::python::adaptors; static void populateDialectSparseTensorSubmodule(const py::module &m) { py::enum_(m, "DimLevelType", py::module_local()) .value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) + .value("compressed24", MLIR_SPARSE_TENSOR_DIM_LEVEL_TWO_OUT_OF_FOUR) .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) .value("compressed-nu", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU) .value("compressed-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO) From 950dcad21d40507f8961ff952cd5ae97f2dd42cc Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 12 Jul 2023 22:09:49 -0700 Subject: [PATCH 500/915] [MLIR] Add a number of methods to the C API Those include: - mlirFuncSetArgAttr - mlirOperationSetOperands - mlirRegionTakeBody - mlirBlockInsertArgument Reviewed By: ftynse, jpienaar Differential Revision: https://reviews.llvm.org/D155091 --- mlir/include/mlir-c/Dialect/Func.h | 9 +++++++++ mlir/include/mlir-c/IR.h | 16 ++++++++++++++++ mlir/lib/CAPI/Dialect/Func.cpp | 8 ++++++++ mlir/lib/CAPI/IR/IR.cpp | 15 +++++++++++++++ 4 files changed, 48 insertions(+) diff --git a/mlir/include/mlir-c/Dialect/Func.h b/mlir/include/mlir-c/Dialect/Func.h index eeb6dfe05..1df759f0e 100644 --- a/mlir/include/mlir-c/Dialect/Func.h +++ b/mlir/include/mlir-c/Dialect/Func.h @@ -18,7 +18,10 @@ #ifndef MLIR_C_DIALECT_FUNC_H #define MLIR_C_DIALECT_FUNC_H +#include + #include "mlir-c/IR.h" +#include "mlir-c/Support.h" #ifdef __cplusplus extern "C" { @@ -26,6 +29,12 @@ extern "C" { MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Func, func); +/// Sets the argument attribute 'name' of an argument at index 'pos'. +/// Asserts that the operation is a FuncOp. 
+MLIR_CAPI_EXPORTED void mlirFuncSetArgAttr(MlirOperation op, intptr_t pos, + MlirStringRef name, + MlirAttribute attr); + #ifdef __cplusplus } #endif diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 26f7f0738..5312db091 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -533,6 +533,11 @@ MLIR_CAPI_EXPORTED MlirValue mlirOperationGetOperand(MlirOperation op, MLIR_CAPI_EXPORTED void mlirOperationSetOperand(MlirOperation op, intptr_t pos, MlirValue newValue); +/// Replaces the operands of the operation. +MLIR_CAPI_EXPORTED void mlirOperationSetOperands(MlirOperation op, + intptr_t nOperands, + MlirValue const *operands); + /// Returns the number of results of the operation. MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumResults(MlirOperation op); @@ -664,6 +669,10 @@ MLIR_CAPI_EXPORTED MlirRegion mlirOperationGetFirstRegion(MlirOperation op); /// operation. MLIR_CAPI_EXPORTED MlirRegion mlirRegionGetNextInOperation(MlirRegion region); +/// Moves the entire content of the source region to the target region. +MLIR_CAPI_EXPORTED void mlirRegionTakeBody(MlirRegion target, + MlirRegion source); + //===----------------------------------------------------------------------===// // Block API. //===----------------------------------------------------------------------===// @@ -737,6 +746,13 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, MlirLocation loc); +/// Inserts an argument of the specified type at a specified index to the block. +/// Returns the newly added argument. +MLIR_CAPI_EXPORTED MlirValue mlirBlockInsertArgument(MlirBlock block, + intptr_t pos, + MlirType type, + MlirLocation loc); + /// Returns `pos`-th argument of the block. 
MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos); diff --git a/mlir/lib/CAPI/Dialect/Func.cpp b/mlir/lib/CAPI/Dialect/Func.cpp index a49d2f425..942e090fd 100644 --- a/mlir/lib/CAPI/Dialect/Func.cpp +++ b/mlir/lib/CAPI/Dialect/Func.cpp @@ -7,7 +7,15 @@ //===----------------------------------------------------------------------===// #include "mlir-c/Dialect/Func.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" #include "mlir/CAPI/Registration.h" #include "mlir/Dialect/Func/IR/FuncOps.h" MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Func, func, mlir::func::FuncDialect) + +void mlirFuncSetArgAttr(MlirOperation op, intptr_t pos, MlirStringRef name, + MlirAttribute attr) { + llvm::cast(unwrap(op)) + .setArgAttr(pos, unwrap(name), unwrap(attr)); +} diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 8c3ea09e9..dedae3ddd 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -497,6 +497,12 @@ void mlirOperationSetOperand(MlirOperation op, intptr_t pos, unwrap(op)->setOperand(static_cast(pos), unwrap(newValue)); } +void mlirOperationSetOperands(MlirOperation op, intptr_t nOperands, + MlirValue const *operands) { + SmallVector ops; + unwrap(op)->setOperands(unwrapList(nOperands, operands, ops)); +} + intptr_t mlirOperationGetNumResults(MlirOperation op) { return static_cast(unwrap(op)->getNumResults()); } @@ -632,6 +638,10 @@ void mlirRegionDestroy(MlirRegion region) { delete static_cast(region.ptr); } +void mlirRegionTakeBody(MlirRegion target, MlirRegion source) { + unwrap(target)->takeBody(*unwrap(source)); +} + //===----------------------------------------------------------------------===// // Block API. 
//===----------------------------------------------------------------------===// @@ -730,6 +740,11 @@ MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, return wrap(unwrap(block)->addArgument(unwrap(type), unwrap(loc))); } +MlirValue mlirBlockInsertArgument(MlirBlock block, intptr_t pos, MlirType type, + MlirLocation loc) { + return wrap(unwrap(block)->insertArgument(pos, unwrap(type), unwrap(loc))); +} + MlirValue mlirBlockGetArgument(MlirBlock block, intptr_t pos) { return wrap(unwrap(block)->getArgument(static_cast(pos))); } From add4395803ebf9823ddc9f4c65bb682764c8ebef Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Wed, 12 Jul 2023 22:11:14 -0700 Subject: [PATCH 501/915] [MLIR][Python] Implement pybind adapters for MlirBlock Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D155092 --- mlir/include/mlir-c/Bindings/Python/Interop.h | 18 ++++++++++++++++++ .../mlir/Bindings/Python/PybindAdaptors.h | 11 +++++++++++ mlir/lib/Bindings/Python/IRCore.cpp | 9 +++++++++ mlir/lib/Bindings/Python/IRModule.h | 3 +++ 4 files changed, 41 insertions(+) diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h index 33332d6a3..f79c10cb9 100644 --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -62,6 +62,7 @@ MAKE_MLIR_PYTHON_QUALNAME("ir.AffineMap._CAPIPtr") #define MLIR_PYTHON_CAPSULE_ATTRIBUTE \ MAKE_MLIR_PYTHON_QUALNAME("ir.Attribute._CAPIPtr") +#define MLIR_PYTHON_CAPSULE_BLOCK MAKE_MLIR_PYTHON_QUALNAME("ir.Block._CAPIPtr") #define MLIR_PYTHON_CAPSULE_CONTEXT \ MAKE_MLIR_PYTHON_QUALNAME("ir.Context._CAPIPtr") #define MLIR_PYTHON_CAPSULE_DIALECT_REGISTRY \ @@ -175,6 +176,23 @@ static inline MlirAttribute mlirPythonCapsuleToAttribute(PyObject *capsule) { return attr; } +/** Creates a capsule object encapsulating the raw C-API MlirBlock. 
+ * The returned capsule does not extend or affect ownership of any Python + * objects that reference the module in any way. */ +static inline PyObject *mlirPythonBlockToCapsule(MlirBlock block) { + return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(block), + MLIR_PYTHON_CAPSULE_BLOCK, NULL); +} + +/** Extracts an MlirBlock from a capsule as produced from + * mlirPythonBlockToCapsule. If the capsule is not of the right type, then + * a null pass manager is returned (as checked via mlirBlockIsNull). */ +static inline MlirBlock mlirPythonCapsuleToBlock(PyObject *capsule) { + void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_BLOCK); + MlirBlock block = {ptr}; + return block; +} + /** Creates a capsule object encapsulating the raw C-API MlirContext. * The returned capsule does not extend or affect ownership of any Python * objects that reference the context in any way. diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 44a10d619..49680c8b7 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -102,6 +102,17 @@ struct type_caster { } }; +/// Casts object -> MlirBlock. +template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirBlock, _("MlirBlock")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToBlock(capsule.ptr()); + return !mlirBlockIsNull(value); + } +}; + /// Casts object -> MlirContext. 
template <> struct type_caster { diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 6c0b4a060..39049f387 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -257,6 +257,14 @@ struct PyAttrBuilderMap { } }; +//------------------------------------------------------------------------------ +// PyBlock +//------------------------------------------------------------------------------ + +py::object PyBlock::getCapsule() { + return py::reinterpret_steal(mlirPythonBlockToCapsule(get())); +} + //------------------------------------------------------------------------------ // Collections. //------------------------------------------------------------------------------ @@ -2968,6 +2976,7 @@ void mlir::python::populateIRCore(py::module &m) { // Mapping of PyBlock. //---------------------------------------------------------------------------- py::class_(m, "Block", py::module_local()) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule) .def_property_readonly( "owner", [](PyBlock &self) { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 76acfe5e7..5da1d7d25 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -763,6 +763,9 @@ class PyBlock { void checkValid() { return parentOperation->checkValid(); } + /// Gets a capsule wrapping the void* within the MlirBlock. 
+ pybind11::object getCapsule(); + private: PyOperationRef parentOperation; MlirBlock block; From 46cb6c8d5ee1f557581c20ca08f9d1e472bf0be9 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 13 Jul 2023 14:06:59 +0200 Subject: [PATCH 502/915] [mlir][Linalg] Fold/erase self-copy linalg.copy on buffers Differential Revision: https://reviews.llvm.org/D155203 --- mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index e4512cd1e..08818b212 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -17,6 +17,7 @@ def copy( Numeric casting is performed on the input operand, promoting it to the same data type as the accumulator/output. """ + defines(Canonicalizer) O[None] = cast(U, I[None]) From a11388129f4e21b86ac67c744c6d6ecb5bf4fb4e Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 14 Jul 2023 16:08:15 -0700 Subject: [PATCH 503/915] [MLIR:Python] Make DenseElementsAttr.get() only request a buffer format if no explicit type was provided. Not every NumPy type (e.g., the `ml_dtypes.bfloat16` NumPy extension type) has a type in the Python buffer protocol, so exporting such a buffer with `PyBUF_FORMAT` may fail. However, we don't care about the self-reported type of a buffer if the user provides an explicit type. In the case that an explicit type is provided, don't request the format from the buffer protocol, which allows arrays whose element types are unknown to the buffer protocol to be passed. 
Reviewed By: jpienaar, ftynse Differential Revision: https://reviews.llvm.org/D155209 --- mlir/lib/Bindings/Python/IRAttributes.cpp | 175 ++++++++++++---------- 1 file changed, 93 insertions(+), 82 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 84a48a890..75d743f3a 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -7,12 +7,15 @@ //===----------------------------------------------------------------------===// #include +#include #include #include "IRModule.h" #include "PybindUtils.h" +#include "llvm/ADT/ScopeExit.h" + #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir/Bindings/Python/PybindAdaptors.h" @@ -612,19 +615,20 @@ class PyDenseElementsAttribute std::optional> explicitShape, DefaultingPyMlirContext contextWrapper) { // Request a contiguous view. In exotic cases, this will cause a copy. - int flags = PyBUF_C_CONTIGUOUS | PyBUF_FORMAT; - Py_buffer *view = new Py_buffer(); - if (PyObject_GetBuffer(array.ptr(), view, flags) != 0) { - delete view; + int flags = PyBUF_ND; + if (!explicitType) { + flags |= PyBUF_FORMAT; + } + Py_buffer view; + if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) { throw py::error_already_set(); } - py::buffer_info arrayInfo(view); + auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); SmallVector shape; if (explicitShape) { shape.append(explicitShape->begin(), explicitShape->end()); } else { - shape.append(arrayInfo.shape.begin(), - arrayInfo.shape.begin() + arrayInfo.ndim); + shape.append(view.shape, view.shape + view.ndim); } MlirAttribute encodingAttr = mlirAttributeGetNull(); @@ -638,85 +642,92 @@ class PyDenseElementsAttribute std::optional bulkLoadElementType; if (explicitType) { bulkLoadElementType = *explicitType; - } else if (arrayInfo.format == "f") { - // f32 - assert(arrayInfo.itemsize == 4 && "mismatched array itemsize"); - 
bulkLoadElementType = mlirF32TypeGet(context); - } else if (arrayInfo.format == "d") { - // f64 - assert(arrayInfo.itemsize == 8 && "mismatched array itemsize"); - bulkLoadElementType = mlirF64TypeGet(context); - } else if (arrayInfo.format == "e") { - // f16 - assert(arrayInfo.itemsize == 2 && "mismatched array itemsize"); - bulkLoadElementType = mlirF16TypeGet(context); - } else if (isSignedIntegerFormat(arrayInfo.format)) { - if (arrayInfo.itemsize == 4) { - // i32 - bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeSignedGet(context, 32); - } else if (arrayInfo.itemsize == 8) { - // i64 - bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeSignedGet(context, 64); - } else if (arrayInfo.itemsize == 1) { - // i8 - bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) - : mlirIntegerTypeSignedGet(context, 8); - } else if (arrayInfo.itemsize == 2) { - // i16 - bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 16) - : mlirIntegerTypeSignedGet(context, 16); - } - } else if (isUnsignedIntegerFormat(arrayInfo.format)) { - if (arrayInfo.itemsize == 4) { - // unsigned i32 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeUnsignedGet(context, 32); - } else if (arrayInfo.itemsize == 8) { - // unsigned i64 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeUnsignedGet(context, 64); - } else if (arrayInfo.itemsize == 1) { - // i8 - bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) - : mlirIntegerTypeUnsignedGet(context, 8); - } else if (arrayInfo.itemsize == 2) { - // i16 - bulkLoadElementType = signless - ? 
mlirIntegerTypeGet(context, 16) - : mlirIntegerTypeUnsignedGet(context, 16); - } - } - if (bulkLoadElementType) { - MlirType shapedType; - if (mlirTypeIsAShaped(*bulkLoadElementType)) { - if (explicitShape) { - throw std::invalid_argument("Shape can only be specified explicitly " - "when the type is not a shaped type."); + } else { + std::string_view format(view.format); + if (format == "f") { + // f32 + assert(view.itemsize == 4 && "mismatched array itemsize"); + bulkLoadElementType = mlirF32TypeGet(context); + } else if (format == "d") { + // f64 + assert(view.itemsize == 8 && "mismatched array itemsize"); + bulkLoadElementType = mlirF64TypeGet(context); + } else if (format == "e") { + // f16 + assert(view.itemsize == 2 && "mismatched array itemsize"); + bulkLoadElementType = mlirF16TypeGet(context); + } else if (isSignedIntegerFormat(format)) { + if (view.itemsize == 4) { + // i32 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeSignedGet(context, 32); + } else if (view.itemsize == 8) { + // i64 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeSignedGet(context, 64); + } else if (view.itemsize == 1) { + // i8 + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) + : mlirIntegerTypeSignedGet(context, 8); + } else if (view.itemsize == 2) { + // i16 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 16) + : mlirIntegerTypeSignedGet(context, 16); + } + } else if (isUnsignedIntegerFormat(format)) { + if (view.itemsize == 4) { + // unsigned i32 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeUnsignedGet(context, 32); + } else if (view.itemsize == 8) { + // unsigned i64 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeUnsignedGet(context, 64); + } else if (view.itemsize == 1) { + // i8 + bulkLoadElementType = signless + ? 
mlirIntegerTypeGet(context, 8) + : mlirIntegerTypeUnsignedGet(context, 8); + } else if (view.itemsize == 2) { + // i16 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 16) + : mlirIntegerTypeUnsignedGet(context, 16); } - shapedType = *bulkLoadElementType; - } else { - shapedType = mlirRankedTensorTypeGet( - shape.size(), shape.data(), *bulkLoadElementType, encodingAttr); } - size_t rawBufferSize = arrayInfo.size * arrayInfo.itemsize; - MlirAttribute attr = mlirDenseElementsAttrRawBufferGet( - shapedType, rawBufferSize, arrayInfo.ptr); - if (mlirAttributeIsNull(attr)) { + if (!bulkLoadElementType) { throw std::invalid_argument( - "DenseElementsAttr could not be constructed from the given buffer. " - "This may mean that the Python buffer layout does not match that " - "MLIR expected layout and is a bug."); + std::string("unimplemented array format conversion from format: ") + + std::string(format)); } - return PyDenseElementsAttribute(contextWrapper->getRef(), attr); } - throw std::invalid_argument( - std::string("unimplemented array format conversion from format: ") + - arrayInfo.format); + MlirType shapedType; + if (mlirTypeIsAShaped(*bulkLoadElementType)) { + if (explicitShape) { + throw std::invalid_argument("Shape can only be specified explicitly " + "when the type is not a shaped type."); + } + shapedType = *bulkLoadElementType; + } else { + shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), + *bulkLoadElementType, encodingAttr); + } + size_t rawBufferSize = view.len; + MlirAttribute attr = + mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf); + if (mlirAttributeIsNull(attr)) { + throw std::invalid_argument( + "DenseElementsAttr could not be constructed from the given buffer. 
" + "This may mean that the Python buffer layout does not match that " + "MLIR expected layout and is a bug."); + } + return PyDenseElementsAttribute(contextWrapper->getRef(), attr); } static PyDenseElementsAttribute getSplat(const PyType &shapedType, @@ -852,7 +863,7 @@ class PyDenseElementsAttribute } private: - static bool isUnsignedIntegerFormat(const std::string &format) { + static bool isUnsignedIntegerFormat(std::string_view format) { if (format.empty()) return false; char code = format[0]; @@ -860,7 +871,7 @@ class PyDenseElementsAttribute code == 'Q'; } - static bool isSignedIntegerFormat(const std::string &format) { + static bool isSignedIntegerFormat(std::string_view format) { if (format.empty()) return false; char code = format[0]; From 0dc71dee78cb1fe2c7de7d98dd518405646d597f Mon Sep 17 00:00:00 2001 From: Adam Paszke Date: Mon, 17 Jul 2023 14:32:31 +0000 Subject: [PATCH 504/915] [MLIR][CAPI] Add C API dialect registration methods for Arith, Math, MemRef and Vector dialects Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D155450 --- mlir/include/mlir-c/Dialect/Arith.h | 33 +++++++++++++++++++++++++ mlir/include/mlir-c/Dialect/Math.h | 33 +++++++++++++++++++++++++ mlir/include/mlir-c/Dialect/MemRef.h | 33 +++++++++++++++++++++++++ mlir/include/mlir-c/Dialect/Vector.h | 33 +++++++++++++++++++++++++ mlir/lib/CAPI/Dialect/Arith.cpp | 13 ++++++++++ mlir/lib/CAPI/Dialect/CMakeLists.txt | 36 ++++++++++++++++++++++++++++ mlir/lib/CAPI/Dialect/Math.cpp | 13 ++++++++++ mlir/lib/CAPI/Dialect/MemRef.cpp | 14 +++++++++++ mlir/lib/CAPI/Dialect/Vector.cpp | 14 +++++++++++ 9 files changed, 222 insertions(+) create mode 100644 mlir/include/mlir-c/Dialect/Arith.h create mode 100644 mlir/include/mlir-c/Dialect/Math.h create mode 100644 mlir/include/mlir-c/Dialect/MemRef.h create mode 100644 mlir/include/mlir-c/Dialect/Vector.h create mode 100644 mlir/lib/CAPI/Dialect/Arith.cpp create mode 100644 mlir/lib/CAPI/Dialect/Math.cpp create mode 100644 
mlir/lib/CAPI/Dialect/MemRef.cpp create mode 100644 mlir/lib/CAPI/Dialect/Vector.cpp diff --git a/mlir/include/mlir-c/Dialect/Arith.h b/mlir/include/mlir-c/Dialect/Arith.h new file mode 100644 index 000000000..41e7cb2b3 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/Arith.h @@ -0,0 +1,33 @@ +//===-- mlir-c/Dialect/Arith.h - C API for Arith dialect ----------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface for registering and accessing the +// Arith dialect. A dialect should be registered with a context to make it +// available to users of the context. These users must load the dialect +// before using any of its attributes, operations or types. Parser and pass +// manager can load registered dialects automatically. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_ARITH_H +#define MLIR_C_DIALECT_ARITH_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Arith, arith); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_ARITH_H diff --git a/mlir/include/mlir-c/Dialect/Math.h b/mlir/include/mlir-c/Dialect/Math.h new file mode 100644 index 000000000..5269e1a1b --- /dev/null +++ b/mlir/include/mlir-c/Dialect/Math.h @@ -0,0 +1,33 @@ +//===-- mlir-c/Dialect/Math.h - C API for Math dialect ------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface for registering and accessing the +// Math dialect. A dialect should be registered with a context to make it +// available to users of the context. These users must load the dialect +// before using any of its attributes, operations or types. Parser and pass +// manager can load registered dialects automatically. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_MATH_H +#define MLIR_C_DIALECT_MATH_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Math, math); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_MATH_H diff --git a/mlir/include/mlir-c/Dialect/MemRef.h b/mlir/include/mlir-c/Dialect/MemRef.h new file mode 100644 index 000000000..087a4b3f8 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/MemRef.h @@ -0,0 +1,33 @@ +//===-- mlir-c/Dialect/MemRef.h - C API for MemRef dialect --------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface for registering and accessing the +// MemRef dialect. A dialect should be registered with a context to make it +// available to users of the context. These users must load the dialect +// before using any of its attributes, operations or types. Parser and pass +// manager can load registered dialects automatically. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_MEMREF_H +#define MLIR_C_DIALECT_MEMREF_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(MemRef, memref); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_MEMREF_H diff --git a/mlir/include/mlir-c/Dialect/Vector.h b/mlir/include/mlir-c/Dialect/Vector.h new file mode 100644 index 000000000..6256c82d1 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/Vector.h @@ -0,0 +1,33 @@ +//===-- mlir-c/Dialect/Vector.h - C API for Vector dialect --------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface for registering and accessing the +// Vector dialect. A dialect should be registered with a context to make it +// available to users of the context. These users must load the dialect +// before using any of its attributes, operations or types. Parser and pass +// manager can load registered dialects automatically. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_VECTOR_H +#define MLIR_C_DIALECT_VECTOR_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Vector, vector); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_VECTOR_H diff --git a/mlir/lib/CAPI/Dialect/Arith.cpp b/mlir/lib/CAPI/Dialect/Arith.cpp new file mode 100644 index 000000000..993f77e55 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/Arith.cpp @@ -0,0 +1,13 @@ +//===- Arith.cpp - C Interface for Arith dialect --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Arith.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/Arith/IR/Arith.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Arith, arith, mlir::arith::ArithDialect) diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index 6c8454a79..4b4ab74e6 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -1,3 +1,12 @@ +add_mlir_upstream_c_api_library(MLIRCAPIArith + Arith.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRArithDialect +) + add_mlir_upstream_c_api_library(MLIRCAPIAsync Async.cpp AsyncPasses.cpp @@ -22,6 +31,24 @@ add_mlir_upstream_c_api_library(MLIRCAPIControlFlow MLIRControlFlowDialect ) +add_mlir_upstream_c_api_library(MLIRCAPIMath + Math.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRMathDialect +) + +add_mlir_upstream_c_api_library(MLIRCAPIMemRef + MemRef.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRMemRefDialect +) + add_mlir_upstream_c_api_library(MLIRCAPIGPU 
GPU.cpp GPUPasses.cpp @@ -142,3 +169,12 @@ add_mlir_upstream_c_api_library(MLIRCAPIPDL MLIRCAPIIR MLIRPDLDialect ) + +add_mlir_upstream_c_api_library(MLIRCAPIVector + Vector.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRVectorDialect +) diff --git a/mlir/lib/CAPI/Dialect/Math.cpp b/mlir/lib/CAPI/Dialect/Math.cpp new file mode 100644 index 000000000..483e549a3 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/Math.cpp @@ -0,0 +1,13 @@ +//===- Math.cpp - C Interface for Math dialect ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Math.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/Math/IR/Math.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Math, math, mlir::math::MathDialect) diff --git a/mlir/lib/CAPI/Dialect/MemRef.cpp b/mlir/lib/CAPI/Dialect/MemRef.cpp new file mode 100644 index 000000000..cfcdea974 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/MemRef.cpp @@ -0,0 +1,14 @@ +//===- MemRef.cpp - C Interface for MemRef dialect ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/MemRef.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(MemRef, memref, + mlir::memref::MemRefDialect) diff --git a/mlir/lib/CAPI/Dialect/Vector.cpp b/mlir/lib/CAPI/Dialect/Vector.cpp new file mode 100644 index 000000000..c744b83b6 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/Vector.cpp @@ -0,0 +1,14 @@ +//===- Vector.cpp - C Interface for Vector dialect ------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Vector.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Vector, vector, + mlir::vector::VectorDialect) From 9bfe0a7acb2ff85741ae458acdb94014b09b384b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Tue, 18 Jul 2023 08:15:53 +0000 Subject: [PATCH 505/915] [mlir][transform][bufferization][python] Add .td file for bindings. 
Reviewed By: springerm, ftynse Differential Revision: https://reviews.llvm.org/D155564 --- mlir/python/CMakeLists.txt | 9 ++++++++ .../dialects/BufferizationTransformOps.td | 21 +++++++++++++++++++ .../mlir/dialects/transform/bufferization.py | 5 +++++ 3 files changed, 35 insertions(+) create mode 100644 mlir/python/mlir/dialects/BufferizationTransformOps.td create mode 100644 mlir/python/mlir/dialects/transform/bufferization.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 39dd7b006..29152b5c5 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -134,6 +134,15 @@ declare_mlir_dialect_python_bindings( _mlir_libs/_mlir/dialects/transform/__init__.pyi DIALECT_NAME transform) +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/BufferizationTransformOps.td + SOURCES + dialects/transform/bufferization.py + DIALECT_NAME transform + EXTENSION_NAME bufferization_transform) + declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/BufferizationTransformOps.td b/mlir/python/mlir/dialects/BufferizationTransformOps.td new file mode 100644 index 000000000..cf2ed661f --- /dev/null +++ b/mlir/python/mlir/dialects/BufferizationTransformOps.td @@ -0,0 +1,21 @@ +//===-- BufferizationTransformOps.td -----------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the transform ops provided +// by the bufferization dialect. 
+// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_BUFFERIZATION_TRANSFORM_OPS +#define PYTHON_BINDINGS_BUFFERIZATION_TRANSFORM_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td" + +#endif // PYTHON_BINDINGS_BUFFERIZATION_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/transform/bufferization.py b/mlir/python/mlir/dialects/transform/bufferization.py new file mode 100644 index 000000000..eb77b746c --- /dev/null +++ b/mlir/python/mlir/dialects/transform/bufferization.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._bufferization_transform_ops_gen import * From c98b38818c0bfc096fd79198605339713e0f1dbd Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Tue, 4 Jul 2023 22:21:26 -0400 Subject: [PATCH 506/915] [mlir][python] Remove PythonAttr mapping functionality This functionality has been replaced by TypeCasters (see D151840) depends on D154468 Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D154469 --- .../mlir/Bindings/Python/Attributes.td | 24 ------------------- mlir/python/mlir/dialects/ArithOps.td | 1 - mlir/python/mlir/dialects/AsyncOps.td | 1 - mlir/python/mlir/dialects/BufferizationOps.td | 1 - .../dialects/BufferizationTransformOps.td | 1 - mlir/python/mlir/dialects/BuiltinOps.td | 1 - mlir/python/mlir/dialects/ComplexOps.td | 1 - mlir/python/mlir/dialects/ControlFlowOps.td | 1 - mlir/python/mlir/dialects/FuncOps.td | 1 - mlir/python/mlir/dialects/GPUOps.td | 1 - mlir/python/mlir/dialects/LinalgOps.td | 1 - .../dialects/LinalgStructuredTransformOps.td | 1 - mlir/python/mlir/dialects/MLProgramOps.td | 1 - mlir/python/mlir/dialects/MathOps.td | 1 - mlir/python/mlir/dialects/MemRefOps.td | 1 - 
mlir/python/mlir/dialects/PDLOps.td | 1 - .../mlir/dialects/SCFLoopTransformOps.td | 1 - mlir/python/mlir/dialects/SCFOps.td | 1 - mlir/python/mlir/dialects/ShapeOps.td | 1 - mlir/python/mlir/dialects/SparseTensorOps.td | 1 - mlir/python/mlir/dialects/TensorOps.td | 1 - mlir/python/mlir/dialects/TosaOps.td | 1 - mlir/python/mlir/dialects/TransformOps.td | 1 - .../mlir/dialects/TransformPDLExtensionOps.td | 1 - mlir/python/mlir/dialects/VectorOps.td | 1 - 25 files changed, 48 deletions(-) delete mode 100644 mlir/include/mlir/Bindings/Python/Attributes.td diff --git a/mlir/include/mlir/Bindings/Python/Attributes.td b/mlir/include/mlir/Bindings/Python/Attributes.td deleted file mode 100644 index c5947c62d..000000000 --- a/mlir/include/mlir/Bindings/Python/Attributes.td +++ /dev/null @@ -1,24 +0,0 @@ -//===-- Attributes.td - Attribute mapping for Python -------*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This defines the mapping between MLIR ODS attributes and the corresponding -// Python binding classes. -// -//===----------------------------------------------------------------------===// - -#ifndef PYTHON_BINDINGS_ATTRIBUTES -#define PYTHON_BINDINGS_ATTRIBUTES - -// A mapping between the attribute storage type and the corresponding Python -// type. There is not necessarily a 1-1 match for non-builtin attributes. 
-class PythonAttr { - string cppStorageType = c; - string pythonType = p; -} - -#endif diff --git a/mlir/python/mlir/dialects/ArithOps.td b/mlir/python/mlir/dialects/ArithOps.td index aaa9fad21..60dbb08a0 100644 --- a/mlir/python/mlir/dialects/ArithOps.td +++ b/mlir/python/mlir/dialects/ArithOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_ARITH_OPS #define PYTHON_BINDINGS_ARITH_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/Arith/IR/ArithOps.td" #endif diff --git a/mlir/python/mlir/dialects/AsyncOps.td b/mlir/python/mlir/dialects/AsyncOps.td index b65b9bafd..2b05045cf 100644 --- a/mlir/python/mlir/dialects/AsyncOps.td +++ b/mlir/python/mlir/dialects/AsyncOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_ASYNC_OPS #define PYTHON_BINDINGS_ASYNC_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/Async/IR/AsyncOps.td" #endif diff --git a/mlir/python/mlir/dialects/BufferizationOps.td b/mlir/python/mlir/dialects/BufferizationOps.td index c5170cee3..b2ac7e281 100644 --- a/mlir/python/mlir/dialects/BufferizationOps.td +++ b/mlir/python/mlir/dialects/BufferizationOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_BUFFERIZATION_OPS #define PYTHON_BINDINGS_BUFFERIZATION_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/Bufferization/IR/BufferizationOps.td" #endif diff --git a/mlir/python/mlir/dialects/BufferizationTransformOps.td b/mlir/python/mlir/dialects/BufferizationTransformOps.td index cf2ed661f..34213be22 100644 --- a/mlir/python/mlir/dialects/BufferizationTransformOps.td +++ b/mlir/python/mlir/dialects/BufferizationTransformOps.td @@ -15,7 +15,6 @@ #ifndef PYTHON_BINDINGS_BUFFERIZATION_TRANSFORM_OPS #define PYTHON_BINDINGS_BUFFERIZATION_TRANSFORM_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td" #endif // PYTHON_BINDINGS_BUFFERIZATION_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/BuiltinOps.td 
b/mlir/python/mlir/dialects/BuiltinOps.td index ecbb8227d..d1c595283 100644 --- a/mlir/python/mlir/dialects/BuiltinOps.td +++ b/mlir/python/mlir/dialects/BuiltinOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_BUILTIN_OPS #define PYTHON_BINDINGS_BUILTIN_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/IR/BuiltinOps.td" #endif diff --git a/mlir/python/mlir/dialects/ComplexOps.td b/mlir/python/mlir/dialects/ComplexOps.td index 6fd846ba6..17825b6be 100644 --- a/mlir/python/mlir/dialects/ComplexOps.td +++ b/mlir/python/mlir/dialects/ComplexOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_COMPLEX_OPS #define PYTHON_BINDINGS_COMPLEX_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/Complex/IR/ComplexOps.td" #endif diff --git a/mlir/python/mlir/dialects/ControlFlowOps.td b/mlir/python/mlir/dialects/ControlFlowOps.td index 1bb4d41f2..c9610a3c6 100644 --- a/mlir/python/mlir/dialects/ControlFlowOps.td +++ b/mlir/python/mlir/dialects/ControlFlowOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_CONTROL_FLOW_OPS #define PYTHON_BINDINGS_CONTROL_FLOW_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.td" #endif diff --git a/mlir/python/mlir/dialects/FuncOps.td b/mlir/python/mlir/dialects/FuncOps.td index 1728091f4..0816d6a3f 100644 --- a/mlir/python/mlir/dialects/FuncOps.td +++ b/mlir/python/mlir/dialects/FuncOps.td @@ -14,7 +14,6 @@ #ifndef PYTHON_BINDINGS_FUNC #define PYTHON_BINDINGS_FUNC -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/Func/IR/FuncOps.td" #endif diff --git a/mlir/python/mlir/dialects/GPUOps.td b/mlir/python/mlir/dialects/GPUOps.td index 4e23d322f..83b1f6cd4 100644 --- a/mlir/python/mlir/dialects/GPUOps.td +++ b/mlir/python/mlir/dialects/GPUOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_GPU_OPS #define PYTHON_BINDINGS_GPU_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/GPU/IR/GPUOps.td" #endif diff --git 
a/mlir/python/mlir/dialects/LinalgOps.td b/mlir/python/mlir/dialects/LinalgOps.td index 7650e954d..b7658c85a 100644 --- a/mlir/python/mlir/dialects/LinalgOps.td +++ b/mlir/python/mlir/dialects/LinalgOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_LINALG_OPS #define PYTHON_BINDINGS_LINALG_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/Linalg/IR/LinalgOps.td" include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td" diff --git a/mlir/python/mlir/dialects/LinalgStructuredTransformOps.td b/mlir/python/mlir/dialects/LinalgStructuredTransformOps.td index a9a53fe6d..e11065bf8 100644 --- a/mlir/python/mlir/dialects/LinalgStructuredTransformOps.td +++ b/mlir/python/mlir/dialects/LinalgStructuredTransformOps.td @@ -15,7 +15,6 @@ #ifndef PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS #define PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td" #endif // PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/MLProgramOps.td b/mlir/python/mlir/dialects/MLProgramOps.td index 5ac45ca1b..35b348d5f 100644 --- a/mlir/python/mlir/dialects/MLProgramOps.td +++ b/mlir/python/mlir/dialects/MLProgramOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_MLPROGRAM_OPS #define PYTHON_BINDINGS_MLPROGRAM_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/MLProgram/IR/MLProgramOps.td" #endif diff --git a/mlir/python/mlir/dialects/MathOps.td b/mlir/python/mlir/dialects/MathOps.td index 03d1fdef0..8f68467ea 100644 --- a/mlir/python/mlir/dialects/MathOps.td +++ b/mlir/python/mlir/dialects/MathOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_MATH_OPS #define PYTHON_BINDINGS_MATH_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/Math/IR/MathOps.td" #endif diff --git a/mlir/python/mlir/dialects/MemRefOps.td b/mlir/python/mlir/dialects/MemRefOps.td index 8dd976479..ed346d5a2 100644 --- 
a/mlir/python/mlir/dialects/MemRefOps.td +++ b/mlir/python/mlir/dialects/MemRefOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_MEMREF_OPS #define PYTHON_BINDINGS_MEMREF_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/MemRef/IR/MemRefOps.td" #endif diff --git a/mlir/python/mlir/dialects/PDLOps.td b/mlir/python/mlir/dialects/PDLOps.td index e4e6a83cd..a8c2d6bdb 100644 --- a/mlir/python/mlir/dialects/PDLOps.td +++ b/mlir/python/mlir/dialects/PDLOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_PDL_OPS #define PYTHON_BINDINGS_PDL_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/PDL/IR/PDLOps.td" #endif diff --git a/mlir/python/mlir/dialects/SCFLoopTransformOps.td b/mlir/python/mlir/dialects/SCFLoopTransformOps.td index 5ef07fc7a..7b09fc14b 100644 --- a/mlir/python/mlir/dialects/SCFLoopTransformOps.td +++ b/mlir/python/mlir/dialects/SCFLoopTransformOps.td @@ -15,7 +15,6 @@ #ifndef PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS #define PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.td" #endif // PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/SCFOps.td b/mlir/python/mlir/dialects/SCFOps.td index 58f337e23..f1fc8a8db 100644 --- a/mlir/python/mlir/dialects/SCFOps.td +++ b/mlir/python/mlir/dialects/SCFOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_SCF_OPS #define PYTHON_BINDINGS_SCF_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/SCF/IR/SCFOps.td" #endif diff --git a/mlir/python/mlir/dialects/ShapeOps.td b/mlir/python/mlir/dialects/ShapeOps.td index c469a586b..e217b2edc 100644 --- a/mlir/python/mlir/dialects/ShapeOps.td +++ b/mlir/python/mlir/dialects/ShapeOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_SHAPE_OPS #define PYTHON_BINDINGS_SHAPE_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/Shape/IR/ShapeOps.td" #endif diff --git 
a/mlir/python/mlir/dialects/SparseTensorOps.td b/mlir/python/mlir/dialects/SparseTensorOps.td index b3b4846db..3f0d522f3 100644 --- a/mlir/python/mlir/dialects/SparseTensorOps.td +++ b/mlir/python/mlir/dialects/SparseTensorOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_SPARSE_TENSOR_OPS #define PYTHON_BINDINGS_SPARSE_TENSOR_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/SparseTensor/IR/SparseTensorOps.td" #endif diff --git a/mlir/python/mlir/dialects/TensorOps.td b/mlir/python/mlir/dialects/TensorOps.td index 40ecea7bf..d68cd2447 100644 --- a/mlir/python/mlir/dialects/TensorOps.td +++ b/mlir/python/mlir/dialects/TensorOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_TENSOR_OPS #define PYTHON_BINDINGS_TENSOR_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/Tensor/IR/TensorOps.td" #endif diff --git a/mlir/python/mlir/dialects/TosaOps.td b/mlir/python/mlir/dialects/TosaOps.td index d906bad7c..b429780bc 100644 --- a/mlir/python/mlir/dialects/TosaOps.td +++ b/mlir/python/mlir/dialects/TosaOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_TOSA_OPS #define PYTHON_BINDINGS_TOSA_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/Tosa/IR/TosaOps.td" #endif diff --git a/mlir/python/mlir/dialects/TransformOps.td b/mlir/python/mlir/dialects/TransformOps.td index 7f0d80ead..e2f6cf932 100644 --- a/mlir/python/mlir/dialects/TransformOps.td +++ b/mlir/python/mlir/dialects/TransformOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_TRANSFORM_OPS #define PYTHON_BINDINGS_TRANSFORM_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/Transform/IR/TransformOps.td" #endif // PYTHON_BINDINGS_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/TransformPDLExtensionOps.td b/mlir/python/mlir/dialects/TransformPDLExtensionOps.td index e3e5daf18..56fadd029 100644 --- a/mlir/python/mlir/dialects/TransformPDLExtensionOps.td +++ b/mlir/python/mlir/dialects/TransformPDLExtensionOps.td @@ -14,7 +14,6 @@ #ifndef 
PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS #define PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td" #endif // PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS diff --git a/mlir/python/mlir/dialects/VectorOps.td b/mlir/python/mlir/dialects/VectorOps.td index 267c2b2a0..69a1028c9 100644 --- a/mlir/python/mlir/dialects/VectorOps.td +++ b/mlir/python/mlir/dialects/VectorOps.td @@ -9,7 +9,6 @@ #ifndef PYTHON_BINDINGS_VECTOR_OPS #define PYTHON_BINDINGS_VECTOR_OPS -include "mlir/Bindings/Python/Attributes.td" include "mlir/Dialect/Vector/IR/VectorOps.td" #endif From 0ec71a14ed038cb85095840d7508efa43f65a8cb Mon Sep 17 00:00:00 2001 From: Jack Wolfard Date: Sat, 15 Jul 2023 02:44:42 -0700 Subject: [PATCH 507/915] [mlir][python] Add install target for MLIR Python sources. Differential Revision: https://reviews.llvm.org/D155362 --- mlir/python/CMakeLists.txt | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 29152b5c5..22a55dbb6 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -517,6 +517,19 @@ add_mlir_python_common_capi_library(MLIRPythonCAPI ${_ADDL_TEST_SOURCES} ) +################################################################################ +# Custom targets. +################################################################################ + +_flatten_mlir_python_targets(mlir_python_sources_deps MLIRPythonSources) +add_custom_target("mlir-python-sources" DEPENDS ${mlir_python_sources_deps}) +if(NOT LLVM_ENABLE_IDE) + add_llvm_install_targets(install-mlir-python-sources + DEPENDS mlir-python-sources + COMPONENT mlir-python-sources + ) +endif() + ################################################################################ # The fully assembled package of modules. # This must come last. 
From b35c9476fc89c7bf28797a0898ad337140472168 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Tue, 18 Jul 2023 09:07:17 +0000 Subject: [PATCH 508/915] [mlir][linalg][transform][python] Add type arg to MatchOp extension. The extension class to MatchOp has a class method called match_op_names. The previous version of that function did not allow to specify the result type. This, however, may be useful/necessary if the op consuming the resulting handle requires a particular type (such as the bufferization.EmptyTensorToAllocTensorOp). This patch adds an overload to match_op_names that allows to specify the result type. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D155567 --- .../dialects/_structured_transform_ops_ext.py | 43 +++++++++++++++++-- 1 file changed, 39 insertions(+), 4 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index b754034c8..640730997 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -85,17 +85,52 @@ def __init__( class MatchOp: """Specialization for MatchOp class.""" + @overload @classmethod def match_op_names( - MatchOp, + cls, target: Union[Operation, Value], names: Sequence[str], + *, loc=None, ip=None, ): - pdl_operation_type = pdl.OperationType.get() - return MatchOp( - pdl_operation_type, + ... + + @overload + @classmethod + def match_op_names( + cls, + result_type: Type, + target: Union[Operation, Value], + names: Sequence[str], + *, + loc=None, + ip=None, + ): + ... 
+ + @classmethod + def match_op_names( + cls, + result_type_or_target: Union[Type, Operation, Value], + target_or_names: Union[Operation, Value, Sequence[str]], + names_or_none: Optional[Sequence[str]] = None, + *, + loc=None, + ip=None, + ): + if isinstance(result_type_or_target, Type): + result_type = result_type_or_target + target = target_or_names + names = names_or_none + else: + result_type = transform.AnyOpType.get() + target = result_type_or_target + names = target_or_names + + return cls( + result_type, _get_op_result_or_value(target), ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))), loc=loc, From 8e51a25e7aac087a4f7b78c6546a2a16a260ee8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 12 Jul 2023 10:38:24 +0000 Subject: [PATCH 509/915] [mlir][transform][linalg][python] Add extended TileToForallOp. This patch adds a mixin for TileToForallOp to _structured_transform_ops_ext.py with syntactic sugar for construction such ops. First, the types of the results are made optional and filled with common default values if omitted. Second, for num_threads and tile_sizes, the three possible forms (static, dynamic, or packed), can now all be given through the same respective argument, which gets dispatched to the correct form-specific argument automatically. 
Reviewed By: nicolasvasilache, ftynse Differential Revision: https://reviews.llvm.org/D155090 --- .../dialects/_structured_transform_ops_ext.py | 135 +++++++++++++++++- 1 file changed, 134 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 640730997..7f90a4647 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -9,7 +9,7 @@ except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e -from typing import List, Optional, Sequence, Union, overload +from typing import List, Optional, Sequence, Tuple, Union, overload IntOrAttrList = Sequence[Union[IntegerAttr, int]] OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]] @@ -17,6 +17,47 @@ BoolOrAttrList = Sequence[Union[BoolAttr, bool]] OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]] +MixedValues = Union[ + Sequence[Union[int, IntegerAttr, Operation, Value, OpView]], + ArrayAttr, + Operation, + Value, + OpView, +] + + +# Dispatches `MixedValues` that all represents integers in various forms into +# the following three categories: +# - `dynamic_values`: a list of `Value`s, potentially from op results; +# - `packed_values`: a value handle, potentially from an op result, associated +# to one or more payload operations of integer type; +# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python +# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`. +# The input is in the form for `packed_values`, only that result is set and the +# other two are empty. Otherwise, the input can be a mix of the other two forms, +# and for each dynamic value, a special value is added to the `static_values`. 
+def _dispatch_mixed_values( + values: MixedValues, +) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]: + dynamic_values = [] + packed_values = None + static_values = None + if isinstance(values, ArrayAttr): + static_values = values + elif isinstance(values, (Operation, Value, OpView)): + packed_values = values + else: + static_values = [] + for size in values or []: + if isinstance(size, int): + static_values.append(size) + else: + static_values.append(ShapedType.get_dynamic_size()) + dynamic_values.append(_get_op_result_or_value(size)) + static_values = DenseI64ArrayAttr.get(static_values) + + return (dynamic_values, packed_values, static_values) + def _get_int_int_array_attr( values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]] @@ -354,6 +395,98 @@ def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]: return [element for element in attr] +class TileToForallOp: + """Specialization for TileToForallOp class.""" + + @overload + def __init__( + self, + loops_type: Type, + tiled_op_type: Type, + target: Union[Operation, Value, OpView], + *, + num_threads: Optional[MixedValues] = None, + tile_sizes: MixedValues = None, + mapping=None, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, Value, OpView], + *, + num_threads: Optional[MixedValues] = None, + tile_sizes: MixedValues = None, + mapping=None, + loc=None, + ip=None, + ): + ... + + def __init__( + self, + loops_type_or_target: Union[ + Type, Union[Operation, Value, OpView] # loops_type + ], # target + tiled_op_type_or_none: Optional[Type] = None, + target_or_none: Optional[Union[Operation, Value, OpView]] = None, + *, + num_threads: MixedValues = None, + tile_sizes: MixedValues = None, + mapping=None, + loc=None, + ip=None, + ): + # `Type` arguments in the front are optional: add default values to front. + if isinstance(loops_type_or_target, Type): + # First overload: type arguments provided. 
+ if not isinstance(tiled_op_type_or_none, Type): + raise TypeError( + "If 'loops_type_or_target' is a type, then " + "'tiled_op_type_or_none' is expected to be one as well." + ) + loops_type = loops_type_or_target + tiled_op_type = tiled_op_type_or_none + target = target_or_none + else: + # Last overload: type arguments missing. + loops_type = transform.AnyOpType.get() + tiled_op_type = transform.AnyOpType.get() + target = loops_type_or_target + + # Unpack mixed num_threads. + ( + dynamic_num_threads, + packed_num_threads, + num_threads_attr, + ) = _dispatch_mixed_values(num_threads) + + # Unpack mixed tile_sizes. + ( + dynamic_tile_sizes, + packed_tile_sizes, + tile_sizes_attr, + ) = _dispatch_mixed_values(tile_sizes) + + super().__init__( + loops_type, + tiled_op_type, + target=target, + tile_sizes=dynamic_tile_sizes, + packed_tile_sizes=packed_tile_sizes, + static_tile_sizes=tile_sizes_attr, + num_threads=dynamic_num_threads, + packed_num_threads=packed_num_threads, + static_num_threads=num_threads_attr, + mapping=mapping, + loc=loc, + ip=ip, + ) + + class VectorizeOp: """Specialization for VectorizeOp class.""" From c6cc1ecc21a28f3917c18a7dc10e127540a486df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 19 Jul 2023 11:55:00 +0000 Subject: [PATCH 510/915] [mlir][transform][linalg][python] Add mix-in for FuseIntoContainingOp. The class did not have any mix-in until now. The new mix-in has two overloads for the constructor of the class: one with all arguments and one without the result types, which are defaulted to `AnyOpType`. 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D155695 --- .../dialects/_structured_transform_ops_ext.py | 64 +++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 7f90a4647..1936f4b0e 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -93,6 +93,70 @@ def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): ) +class FuseIntoContainingOp: + """Specialization for FuseIntoContainingOp class.""" + + @overload + def __init__( + self, + fused_op_type: Type, + new_containing_op_type: Type, + producer_op: Union[Operation, OpView, Value], + containing_op: Union[Operation, OpView, Value], + *, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + producer_op: Union[Operation, OpView, Value], + containing_op: Union[Operation, OpView, Value], + *, + loc=None, + ip=None, + ): + ... + + def __init__( + self, + fused_op_type_or_producer_op: Union[Operation, OpView, Type, Value], + new_containing_op_type_or_containing_op: Union[Operation, OpView, Type, Value], + producer_op_or_none: Optional[Union[Operation, OpView, Value]] = None, + containing_op_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + loc=None, + ip=None, + ): + if isinstance(fused_op_type_or_producer_op, Type): + if not isinstance(new_containing_op_type_or_containing_op, Type): + raise TypeError( + "If 'fused_op_type_or_producer_op' is a type, then " + "'new_containing_op_type_or_containing_op' is expected " + "to be one as well." 
+ ) + fused_op_type = fused_op_type_or_producer_op + new_containing_op_type = new_containing_op_type_or_containing_op + producer_op = producer_op_or_none + containing_op = containing_op_or_none + else: + fused_op_type = transform.AnyOpType.get() + new_containing_op_type = transform.AnyOpType.get() + producer_op = fused_op_type_or_producer_op + containing_op = new_containing_op_type_or_containing_op + + super().__init__( + fused_op_type, + new_containing_op_type, + producer_op, + containing_op, + loc=loc, + ip=ip, + ) + + class GeneralizeOp: """Specialization for GeneralizeOp class.""" From 4dba457854dc292c8098f688624881e145ab8e5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Tue, 18 Jul 2023 15:02:32 +0000 Subject: [PATCH 511/915] [mlir][transform][gpu][python] Add .td file for bindings. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D155602 --- mlir/python/CMakeLists.txt | 9 +++++++++ mlir/python/mlir/dialects/GPUTransformOps.td | 20 ++++++++++++++++++++ mlir/python/mlir/dialects/transform/gpu.py | 5 +++++ 3 files changed, 34 insertions(+) create mode 100644 mlir/python/mlir/dialects/GPUTransformOps.td create mode 100644 mlir/python/mlir/dialects/transform/gpu.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 22a55dbb6..e5d37b228 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -143,6 +143,15 @@ declare_mlir_dialect_extension_python_bindings( DIALECT_NAME transform EXTENSION_NAME bufferization_transform) +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/GPUTransformOps.td + SOURCES + dialects/transform/gpu.py + DIALECT_NAME transform + EXTENSION_NAME gpu_transform) + declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/GPUTransformOps.td 
b/mlir/python/mlir/dialects/GPUTransformOps.td new file mode 100644 index 000000000..08bd9537b --- /dev/null +++ b/mlir/python/mlir/dialects/GPUTransformOps.td @@ -0,0 +1,20 @@ +//===-- GPUTransformOps.td ---------------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the transform ops provided +// by the GPU dialect. +// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_GPU_TRANSFORM_OPS +#define PYTHON_BINDINGS_GPU_TRANSFORM_OPS + +include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.td" + +#endif // PYTHON_BINDINGS_GPU_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/transform/gpu.py b/mlir/python/mlir/dialects/transform/gpu.py new file mode 100644 index 000000000..8c3de0de7 --- /dev/null +++ b/mlir/python/mlir/dialects/transform/gpu.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._gpu_transform_ops_gen import * From 4960fede190d62ad560b119e65fa2baad1c77437 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 19 Jul 2023 15:31:23 +0000 Subject: [PATCH 512/915] [mlir][transform][gpu][python] Add MapForallToBlocks mix-in. This patch adds a mix-in class for MapForallToBlocks with overloaded constructors. This makes it optional to provide the return type of the op, which is defaulte to `AnyOpType`. 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D155717 --- mlir/python/CMakeLists.txt | 1 + .../mlir/dialects/_gpu_transform_ops_ext.py | 69 +++++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 mlir/python/mlir/dialects/_gpu_transform_ops_ext.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index e5d37b228..50fbca38a 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -148,6 +148,7 @@ declare_mlir_dialect_extension_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/GPUTransformOps.td SOURCES + dialects/_gpu_transform_ops_ext.py dialects/transform/gpu.py DIALECT_NAME transform EXTENSION_NAME gpu_transform) diff --git a/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py b/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py new file mode 100644 index 000000000..087606e3d --- /dev/null +++ b/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py @@ -0,0 +1,69 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * + from ..dialects import transform +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Sequence, Union, overload + + +class MapForallToBlocks: + """Specialization for MapForallToBlocks class.""" + + @overload + def __init__( + self, + result_type: Type, + target: Union[Operation, OpView, Value], + *, + grid_dims: Optional[Sequence[int]] = None, + generate_gpu_launch: Optional[bool] = None, + loc=None, + ip=None + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + grid_dims: Optional[Sequence[int]] = None, + generate_gpu_launch: Optional[bool] = None, + loc=None, + ip=None + ): + ... 
+ + def __init__( + self, + result_type_or_target: Union[Operation, OpView, Type, Value], + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + grid_dims: Optional[Sequence[int]] = None, + generate_gpu_launch: Optional[bool] = None, + loc=None, + ip=None + ): + if isinstance(result_type_or_target, Type): + result_type = result_type_or_target + target = target_or_none + else: + result_type = transform.AnyOpType.get() + target = result_type_or_target + + if grid_dims is not None and not isinstance(grid_dims, ArrayAttr): + grid_dims = DenseI64ArrayAttr.get(grid_dims) + + super().__init__( + result_type, + target, + grid_dims=grid_dims, + generate_gpu_launch=generate_gpu_launch, + loc=loc, + ip=ip, + ) From 3f11964625d5fd0ba4e1238153dafce7ed0bb369 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 17 Jul 2023 10:26:33 +0000 Subject: [PATCH 513/915] [mlir][transform][python] Add extended ApplyPatternsOp. This patch adds a mixin for ApplyPatternsOp to _transform_ops_ext.py with syntactic sugar for construction such ops. Curiously, the op did not have any constructors yet, probably because its tablegen definition said to skip the default builders. The new constructor is thus quite straightforward. The commit also adds a refined `region` property which returns the first block of the single region. 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D155435 --- .../mlir/dialects/_transform_ops_ext.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py index 87f8d398c..0db2e3bd9 100644 --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -29,6 +29,32 @@ def __init__( ) +class ApplyPatternsOp: + + def __init__( + self, + target: Union[Operation, Value, OpView], + *, + loc=None, + ip=None, + ): + operands = [] + operands.append(_get_op_result_or_value(target)) + super().__init__( + self.build_generic(attributes={}, + results=[], + operands=operands, + successors=None, + regions=None, + loc=loc, + ip=ip)) + self.regions[0].blocks.append() + + @property + def patterns(self) -> Block: + return self.regions[0].blocks[0] + + class testGetParentOp: def __init__( From ac83d7ff761611cfce3865ada9c98a72bacc4498 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Thu, 20 Jul 2023 09:58:41 +0000 Subject: [PATCH 514/915] [mlir][transform][structured][python] Allow str arg in match_op_names. Allow the `names` argument in `MatchOp.match_op_names` to be of type `str` in addition to `Sequence[str]`. In this case, the argument is treated as a list with one name, i.e., it is possible to write `MatchOp.match_op_names(..., "test.dummy")` instead of `MatchOp.match_op_names(..., ["test.dummy"])`. 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D155807 --- .../mlir/dialects/_structured_transform_ops_ext.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 1936f4b0e..9f623efb5 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -195,7 +195,7 @@ class MatchOp: def match_op_names( cls, target: Union[Operation, Value], - names: Sequence[str], + names: Union[str, Sequence[str]], *, loc=None, ip=None, @@ -208,7 +208,7 @@ def match_op_names( cls, result_type: Type, target: Union[Operation, Value], - names: Sequence[str], + names: Union[str, Sequence[str]], *, loc=None, ip=None, @@ -219,8 +219,8 @@ def match_op_names( def match_op_names( cls, result_type_or_target: Union[Type, Operation, Value], - target_or_names: Union[Operation, Value, Sequence[str]], - names_or_none: Optional[Sequence[str]] = None, + target_or_names: Union[Operation, Value, Sequence[str], str], + names_or_none: Optional[Union[Sequence[str], str]] = None, *, loc=None, ip=None, @@ -234,6 +234,9 @@ def match_op_names( target = result_type_or_target names = target_or_names + if isinstance(names, str): + names = [names] + return cls( result_type, _get_op_result_or_value(target), From 5637a25365ccd65b7b3fb2c91b0eacc4413bef04 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Sun, 23 Jul 2023 21:26:52 -0700 Subject: [PATCH 515/915] [mlir][py] Reuse more of CAPI build time inference. This reduces code generated for type inference and instead reuses facilities CAPI side that performed same role. 
Differential Revision: https://reviews.llvm.org/D156041t --- mlir/lib/Bindings/Python/IRCore.cpp | 110 ++++++++++++++++------------ mlir/lib/Bindings/Python/IRModule.h | 16 ++-- 2 files changed, 70 insertions(+), 56 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 39049f387..971d2819a 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -78,6 +78,7 @@ static const char kOperationCreateDocstring[] = ip: An InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager). + infer_type: Whether to infer result types. Returns: A new "detached" Operation object. Detached operations can be added to blocks, which causes them to become "attached." @@ -1288,7 +1289,7 @@ py::object PyOperation::create(const std::string &name, std::optional attributes, std::optional> successors, int regions, DefaultingPyLocation location, - const py::object &maybeIp) { + const py::object &maybeIp, bool inferType) { llvm::SmallVector mlirOperands; llvm::SmallVector mlirResults; llvm::SmallVector mlirSuccessors; @@ -1367,6 +1368,7 @@ py::object PyOperation::create(const std::string &name, if (!mlirOperands.empty()) mlirOperationStateAddOperands(&state, mlirOperands.size(), mlirOperands.data()); + state.enableResultTypeInference = inferType; if (!mlirResults.empty()) mlirOperationStateAddResults(&state, mlirResults.size(), mlirResults.data()); @@ -1398,6 +1400,8 @@ py::object PyOperation::create(const std::string &name, // Construct the operation. 
MlirOperation operation = mlirOperationCreate(&state); + if (!operation.ptr) + throw py::value_error("Operation creation failed"); PyOperationRef created = PyOperation::createDetached(location->getContext(), operation); maybeInsertOperation(created, maybeIp); @@ -1441,51 +1445,10 @@ void PyOperation::erase() { // PyOpView //------------------------------------------------------------------------------ -py::object -PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList, - py::list operandList, std::optional attributes, - std::optional> successors, - std::optional regions, - DefaultingPyLocation location, - const py::object &maybeIp) { - PyMlirContextRef context = location->getContext(); - // Class level operation construction metadata. - std::string name = py::cast(cls.attr("OPERATION_NAME")); - // Operand and result segment specs are either none, which does no - // variadic unpacking, or a list of ints with segment sizes, where each - // element is either a positive number (typically 1 for a scalar) or -1 to - // indicate that it is derived from the length of the same-indexed operand - // or result (implying that it is a list at that position). - py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); - py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); - - std::vector operandSegmentLengths; - std::vector resultSegmentLengths; - - // Validate/determine region count. 
- auto opRegionSpec = py::cast>(cls.attr("_ODS_REGIONS")); - int opMinRegionCount = std::get<0>(opRegionSpec); - bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); - if (!regions) { - regions = opMinRegionCount; - } - if (*regions < opMinRegionCount) { - throw py::value_error( - (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + - llvm::Twine(opMinRegionCount) + - " regions but was built with regions=" + llvm::Twine(*regions)) - .str()); - } - if (opHasNoVariadicRegions && *regions > opMinRegionCount) { - throw py::value_error( - (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + - llvm::Twine(opMinRegionCount) + - " regions but was built with regions=" + llvm::Twine(*regions)) - .str()); - } - - // Unpack results. - std::vector resultTypes; +static void populateResultTypes(StringRef name, py::list resultTypeList, + const py::object &resultSegmentSpecObj, + std::vector &resultSegmentLengths, + std::vector &resultTypes) { resultTypes.reserve(resultTypeList.size()); if (resultSegmentSpecObj.is_none()) { // Non-variadic result unpacking. @@ -1568,6 +1531,56 @@ PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList, } } } +} + +py::object PyOpView::buildGeneric( + const py::object &cls, std::optional resultTypeList, + py::list operandList, std::optional attributes, + std::optional> successors, + std::optional regions, DefaultingPyLocation location, + const py::object &maybeIp) { + PyMlirContextRef context = location->getContext(); + // Class level operation construction metadata. + std::string name = py::cast(cls.attr("OPERATION_NAME")); + // Operand and result segment specs are either none, which does no + // variadic unpacking, or a list of ints with segment sizes, where each + // element is either a positive number (typically 1 for a scalar) or -1 to + // indicate that it is derived from the length of the same-indexed operand + // or result (implying that it is a list at that position). 
+ py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); + py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); + + std::vector operandSegmentLengths; + std::vector resultSegmentLengths; + + // Validate/determine region count. + auto opRegionSpec = py::cast>(cls.attr("_ODS_REGIONS")); + int opMinRegionCount = std::get<0>(opRegionSpec); + bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); + if (!regions) { + regions = opMinRegionCount; + } + if (*regions < opMinRegionCount) { + throw py::value_error( + (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + + llvm::Twine(opMinRegionCount) + + " regions but was built with regions=" + llvm::Twine(*regions)) + .str()); + } + if (opHasNoVariadicRegions && *regions > opMinRegionCount) { + throw py::value_error( + (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + + llvm::Twine(opMinRegionCount) + + " regions but was built with regions=" + llvm::Twine(*regions)) + .str()); + } + + // Unpack results. + std::vector resultTypes; + if (resultTypeList.has_value()) { + populateResultTypes(name, *resultTypeList, resultSegmentSpecObj, + resultSegmentLengths, resultTypes); + } // Unpack operands. 
std::vector operands; @@ -1694,7 +1707,8 @@ PyOpView::buildGeneric(const py::object &cls, py::list resultTypeList, /*operands=*/std::move(operands), /*attributes=*/std::move(attributes), /*successors=*/std::move(successors), - /*regions=*/*regions, location, maybeIp); + /*regions=*/*regions, location, maybeIp, + !resultTypeList); } pybind11::object PyOpView::constructDerived(const pybind11::object &cls, @@ -2854,7 +2868,7 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("attributes") = py::none(), py::arg("successors") = py::none(), py::arg("regions") = 0, py::arg("loc") = py::none(), py::arg("ip") = py::none(), - kOperationCreateDocstring) + py::arg("infer_type") = false, kOperationCreateDocstring) .def_static( "parse", [](const std::string &sourceStr, const std::string &sourceName, diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 5da1d7d25..d1911730c 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -655,7 +655,8 @@ class PyOperation : public PyOperationBase, public BaseContextObject { std::optional> operands, std::optional attributes, std::optional> successors, int regions, - DefaultingPyLocation location, const pybind11::object &ip); + DefaultingPyLocation location, const pybind11::object &ip, + bool inferType); /// Creates an OpView suitable for this operation. 
pybind11::object createOpView(); @@ -704,13 +705,12 @@ class PyOpView : public PyOperationBase { pybind11::object getOperationObject() { return operationObject; } - static pybind11::object - buildGeneric(const pybind11::object &cls, pybind11::list resultTypeList, - pybind11::list operandList, - std::optional attributes, - std::optional> successors, - std::optional regions, DefaultingPyLocation location, - const pybind11::object &maybeIp); + static pybind11::object buildGeneric( + const pybind11::object &cls, std::optional resultTypeList, + pybind11::list operandList, std::optional attributes, + std::optional> successors, + std::optional regions, DefaultingPyLocation location, + const pybind11::object &maybeIp); /// Construct an instance of a class deriving from OpView, bypassing its /// `__init__` method. The derived class will typically define a constructor From db0b47d55a95fb19bb9042ae6d73ff1f132aaf69 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Sun, 23 Jul 2023 21:40:12 -0700 Subject: [PATCH 516/915] [mlir] Enable converting properties during C create This enables querying properties passed as attributes during construction time. In particular needed for type inference where the Operation has not been created at this point. This allows Python construction of operations whose type inference depends on properties. 
Differential Revision: https://reviews.llvm.org/D156070 --- mlir/lib/CAPI/IR/IR.cpp | 36 ++++++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index dedae3ddd..b140a4639 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -15,10 +15,12 @@ #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Utils.h" #include "mlir/IR/Attributes.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Location.h" #include "mlir/IR/Operation.h" +#include "mlir/IR/OperationSupport.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" @@ -26,6 +28,7 @@ #include "mlir/Parser/Parser.h" #include +#include #include using namespace mlir; @@ -345,25 +348,43 @@ static LogicalResult inferOperationTypes(OperationState &state) { if (!info) { emitError(state.location) << "type inference was requested for the operation " << state.name - << ", but the operation was not registered. Ensure that the dialect " + << ", but the operation was not registered; ensure that the dialect " "containing the operation is linked into MLIR and registered with " "the context"; return failure(); } - // Fallback to inference via an op interface. auto *inferInterface = info->getInterface(); if (!inferInterface) { emitError(state.location) << "type inference was requested for the operation " << state.name - << ", but the operation does not support type inference. 
Result " - "types must be specified explicitly."; + << ", but the operation does not support type inference; result " + "types must be specified explicitly"; + return failure(); + } + + DictionaryAttr attributes = state.attributes.getDictionary(context); + OpaqueProperties properties = state.getRawProperties(); + + if (!properties && info->getOpPropertyByteSize() > 0 && !attributes.empty()) { + auto prop = std::make_unique(info->getOpPropertyByteSize()); + properties = OpaqueProperties(prop.get()); + if (failed(info->setOpPropertiesFromAttribute(state.name, properties, + attributes, nullptr))) { + return failure(); + } + + if (succeeded(inferInterface->inferReturnTypes( + context, state.location, state.operands, attributes, properties, + state.regions, state.types))) { + return success(); + } + // Diagnostic emitted by interface. return failure(); } if (succeeded(inferInterface->inferReturnTypes( - context, state.location, state.operands, - state.attributes.getDictionary(context), state.getRawProperties(), + context, state.location, state.operands, attributes, properties, state.regions, state.types))) return success(); @@ -405,8 +426,7 @@ MlirOperation mlirOperationCreate(MlirOperationState *state) { return {nullptr}; } - MlirOperation result = wrap(Operation::create(cppState)); - return result; + return wrap(Operation::create(cppState)); } MlirOperation mlirOperationCreateParse(MlirContext context, From 5c71c6179098330137c93d6b77f7616e2c0a2381 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Thu, 20 Jul 2023 22:51:35 -0700 Subject: [PATCH 517/915] Update ODS variadic segments "magic" attributes to use native Properties The operand_segment_sizes and result_segment_sizes Attributes are now inlined in the operation as native propertie. We continue to support building an Attribute on the fly for `getAttr("operand_segment_sizes")` and setting the property from an attribute with `setAttr("operand_segment_sizes", attr)`. 
A new bytecode version is introduced to support backward compatibility and backdeployments. Differential Revision: https://reviews.llvm.org/D155919 --- mlir/lib/CAPI/IR/IR.cpp | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index b140a4639..5231fe5f9 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -17,6 +17,7 @@ #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/Location.h" #include "mlir/IR/Operation.h" @@ -369,10 +370,15 @@ static LogicalResult inferOperationTypes(OperationState &state) { if (!properties && info->getOpPropertyByteSize() > 0 && !attributes.empty()) { auto prop = std::make_unique(info->getOpPropertyByteSize()); properties = OpaqueProperties(prop.get()); + InFlightDiagnostic diag = emitError(state.location) + << " failed properties conversion while building " + << state.name.getStringRef() << " with `" + << attributes << "`: "; if (failed(info->setOpPropertiesFromAttribute(state.name, properties, - attributes, nullptr))) { + attributes, &diag))) { return failure(); } + diag.abandon(); if (succeeded(inferInterface->inferReturnTypes( context, state.location, state.operands, attributes, properties, From 7e0cf96a32b1b6381ffa4b4dd74a621de79656c2 Mon Sep 17 00:00:00 2001 From: max Date: Tue, 25 Jul 2023 23:17:00 -0500 Subject: [PATCH 518/915] add set_type to ir.Value Differential Revision: https://reviews.llvm.org/D156289 --- mlir/include/mlir-c/IR.h | 3 +++ mlir/lib/Bindings/Python/IRCore.cpp | 6 ++++++ mlir/lib/CAPI/IR/IR.cpp | 4 ++++ 3 files changed, 13 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 5312db091..b5c6a3094 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -801,6 +801,9 @@ MLIR_CAPI_EXPORTED intptr_t mlirOpResultGetResultNumber(MlirValue value); /// 
Returns the type of the value. MLIR_CAPI_EXPORTED MlirType mlirValueGetType(MlirValue value); +/// Set the type of the value. +MLIR_CAPI_EXPORTED void mlirValueSetType(MlirValue value, MlirType type); + /// Prints the value to the standard error stream. MLIR_CAPI_EXPORTED void mlirValueDump(MlirValue value); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 971d2819a..6b9de9a8c 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3431,6 +3431,12 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("use_local_scope") = false, kGetNameAsOperand) .def_property_readonly( "type", [](PyValue &self) { return mlirValueGetType(self.get()); }) + .def( + "set_type", + [](PyValue &self, const PyType &type) { + return mlirValueSetType(self.get(), type); + }, + py::arg("type")) .def( "replace_all_uses_with", [](PyValue &self, PyValue &with) { diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 5231fe5f9..ccdae1424 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -823,6 +823,10 @@ MlirType mlirValueGetType(MlirValue value) { return wrap(unwrap(value).getType()); } +void mlirValueSetType(MlirValue value, MlirType type) { + unwrap(value).setType(unwrap(type)); +} + void mlirValueDump(MlirValue value) { unwrap(value).dump(); } void mlirValuePrint(MlirValue value, MlirStringCallback callback, From 5fcf2914b085f79ae2026ab1265f7e4ca5b5b1e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Thu, 20 Jul 2023 08:27:11 +0000 Subject: [PATCH 519/915] [mlir][transform][bufferization][python] Add mix-in classes for two ops. This patch adds mix-in classes for the Python bindings of `EmptyTensorToAllocTensorOp` and `OneShotBufferizeOp`. For both classes, the mix-in add overloads to the `__init__` functions that allow to construct them without providing the return type, which is defaulted to the only allowed type and `AnyOpType`, respectively. 
Note that the mix-in do not expose the `function_boundary_type_conversion` attribute. The attribute has a custom type from the bufferization dialect that is currently not exposed in the Python bindings. Handling of that attribute can be added easily to the mix-in class when the need arises. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D155799 --- mlir/python/CMakeLists.txt | 1 + .../_bufferization_transform_ops_ext.py | 114 ++++++++++++++++++ 2 files changed, 115 insertions(+) create mode 100644 mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 50fbca38a..05a36cafc 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -139,6 +139,7 @@ declare_mlir_dialect_extension_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/BufferizationTransformOps.td SOURCES + dialects/_bufferization_transform_ops_ext.py dialects/transform/bufferization.py DIALECT_NAME transform EXTENSION_NAME bufferization_transform) diff --git a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py new file mode 100644 index 000000000..77f4d1e16 --- /dev/null +++ b/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py @@ -0,0 +1,114 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * + from ..dialects import transform +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, overload, Union + + +class EmptyTensorToAllocTensorOp: + """Specialization for EmptyTensorToAllocTensorOp class.""" + + @overload + def __init__( + self, + transformed_type: Type, + target: Union[Operation, OpView, Value], + *, + loc=None, + ip=None + ): + ... + + @overload + def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None): + ... + + def __init__( + self, + transformed_type_or_target: Type, + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + loc=None, + ip=None + ): + if isinstance(transformed_type_or_target, Type): + transformed_type = transformed_type_or_target + target = target_or_none + else: + transformed_type = transform.OperationType.get("bufferization.alloc_tensor") + target = transformed_type_or_target + + super().__init__( + transformed_type, + target, + loc=loc, + ip=ip, + ) + + +class OneShotBufferizeOp: + """Specialization for OneShotBufferizeOp class.""" + + @overload + def __init__( + self, + transformed_type: Type, + target: Union[Operation, OpView, Value], + *, + allow_return_allocs: Optional[bool] = None, + allow_unknown_ops: Optional[bool] = None, + bufferize_function_boundaries: Optional[bool] = None, + create_deallocs: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, + print_conflicts: Optional[bool] = None, + memcpy_op: Optional[str] = None, + loc=None, + ip=None + ): + ... + + @overload + def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None): + ... 
+ + def __init__( + self, + transformed_type_or_target: Type, + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + allow_return_allocs: Optional[bool] = None, + allow_unknown_ops: Optional[bool] = None, + bufferize_function_boundaries: Optional[bool] = None, + create_deallocs: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, + print_conflicts: Optional[bool] = None, + memcpy_op: Optional[str] = None, + loc=None, + ip=None + ): + if isinstance(transformed_type_or_target, Type): + transformed_type = transformed_type_or_target + target = target_or_none + else: + transformed_type = transform.AnyOpType.get() + target = transformed_type_or_target + + super().__init__( + transformed_type, + target, + allow_return_allocs=allow_return_allocs, + allow_unknown_ops=allow_unknown_ops, + bufferize_function_boundaries=bufferize_function_boundaries, + create_deallocs=create_deallocs, + test_analysis_only=test_analysis_only, + print_conflicts=print_conflicts, + memcpy_op=memcpy_op, + loc=loc, + ip=ip, + ) From 0cf8357623393c985a0680598bba93563ad5f066 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 27 Jul 2023 12:25:59 +0000 Subject: [PATCH 520/915] [mlir] delete yapf config files, NFC LLVM has converged to using black for Python formatting. Remove the yapf configs MLIR used to rely on before that (the reformatting has already happened). 
--- mlir/python/.style.yapf | 4 ---- 1 file changed, 4 deletions(-) delete mode 100644 mlir/python/.style.yapf diff --git a/mlir/python/.style.yapf b/mlir/python/.style.yapf deleted file mode 100644 index 9ef1dc15b..000000000 --- a/mlir/python/.style.yapf +++ /dev/null @@ -1,4 +0,0 @@ -[style] - based_on_style = google - column_limit = 80 - indent_width = 2 From af406697d184e0df91eb14025f7c86bf13c0e7ec Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 28 Jul 2023 12:58:29 +0000 Subject: [PATCH 521/915] [mlir][python] more python gpu transform mixins Add the Python mix-in for MapNestedForallToThreads. Fix typing annotations in MapForallToBlocks and drop the attribute wrapping rendered unnecessary by attribute builders. Reviewed By: ingomueller-net Differential Revision: https://reviews.llvm.org/D156528 --- .../mlir/dialects/_gpu_transform_ops_ext.py | 73 ++++++++++++++++--- 1 file changed, 64 insertions(+), 9 deletions(-) diff --git a/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py b/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py index 087606e3d..ba72bac3a 100644 --- a/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py @@ -20,8 +20,8 @@ def __init__( result_type: Type, target: Union[Operation, OpView, Value], *, - grid_dims: Optional[Sequence[int]] = None, - generate_gpu_launch: Optional[bool] = None, + grid_dims: Optional[Union[Sequence[int], Attribute]] = None, + generate_gpu_launch: Optional[Union[bool, Attribute]] = None, loc=None, ip=None ): @@ -32,8 +32,8 @@ def __init__( self, target: Union[Operation, OpView, Value], *, - grid_dims: Optional[Sequence[int]] = None, - generate_gpu_launch: Optional[bool] = None, + grid_dims: Optional[Union[Sequence[int], Attribute]] = None, + generate_gpu_launch: Optional[Union[bool, Attribute]] = None, loc=None, ip=None ): @@ -44,8 +44,8 @@ def __init__( result_type_or_target: Union[Operation, OpView, Type, Value], target_or_none: Optional[Union[Operation, 
OpView, Value]] = None, *, - grid_dims: Optional[Sequence[int]] = None, - generate_gpu_launch: Optional[bool] = None, + grid_dims: Optional[Union[Sequence[int], Attribute]] = None, + generate_gpu_launch: Optional[Union[bool, Attribute]] = None, loc=None, ip=None ): @@ -56,9 +56,6 @@ def __init__( result_type = transform.AnyOpType.get() target = result_type_or_target - if grid_dims is not None and not isinstance(grid_dims, ArrayAttr): - grid_dims = DenseI64ArrayAttr.get(grid_dims) - super().__init__( result_type, target, @@ -67,3 +64,61 @@ def __init__( loc=loc, ip=ip, ) + + +class MapNestedForallToThreads: + """Specialization for MapNestedForallToThreads class.""" + + @overload + def __init__( + self, + result_type: Type, + target: Union[Operation, OpView, Value], + *, + block_dims: Optional[Sequence[int]] = None, + warp_size: Optional[Sequence[int]] = None, + sync_after_distribute: Optional[bool] = None, + loc=None, + ip=None + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + block_dims: Optional[Sequence[int]] = None, + warp_size: Optional[Sequence[int]] = None, + sync_after_distribute: Optional[bool] = None, + loc=None, + ip=None + ): + ... 
+ + def __init__( + self, + result_type_or_target: Union[Operation, OpView, Value, Type], + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + block_dims: Optional[Union[Sequence[int], Attribute]] = None, + warp_size: Optional[Union[Sequence[int], Attribute]] = None, + sync_after_distribute: Optional[bool] = None, + loc=None, + ip=None + ): + if isinstance(result_type_or_target, Type): + result_type = result_type_or_target + target = target_or_none + else: + result_type = result_type_or_target.type + target = result_type_or_target + super().__init__( + result_type, + target, + block_dims=block_dims, + warp_size=warp_size, + sync_after_distribute=sync_after_distribute, + loc=loc, + ip=ip, + ) From b065611db71c8bc6acc8d675498a8f26426dda91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Fri, 28 Jul 2023 13:47:23 +0000 Subject: [PATCH 522/915] [mlir][memref][transform][python] Create .td file for bindings. This patch creates the .td files for the Python bindings of the transform ops of the MemRef dialect and integrates them into the build systems (CMake and Bazel). 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D156536 --- mlir/python/CMakeLists.txt | 9 +++++++++ mlir/python/mlir/dialects/MemRefTransformOps.td | 14 ++++++++++++++ mlir/python/mlir/dialects/transform/memref.py | 5 +++++ 3 files changed, 28 insertions(+) create mode 100644 mlir/python/mlir/dialects/MemRefTransformOps.td create mode 100644 mlir/python/mlir/dialects/transform/memref.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 05a36cafc..d233194b1 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -164,6 +164,15 @@ declare_mlir_dialect_extension_python_bindings( DIALECT_NAME transform EXTENSION_NAME loop_transform) +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/MemRefTransformOps.td + SOURCES + dialects/transform/memref.py + DIALECT_NAME transform + EXTENSION_NAME memref_transform) + declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/MemRefTransformOps.td b/mlir/python/mlir/dialects/MemRefTransformOps.td new file mode 100644 index 000000000..a64c2e238 --- /dev/null +++ b/mlir/python/mlir/dialects/MemRefTransformOps.td @@ -0,0 +1,14 @@ +//===-- MemRefTransformOps.td - Memref transform ops -------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_MEMREF_TRANSFORM_OPS +#define PYTHON_BINDINGS_MEMREF_TRANSFORM_OPS + +include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/transform/memref.py b/mlir/python/mlir/dialects/transform/memref.py new file mode 100644 index 000000000..1ff04ef6a --- /dev/null +++ b/mlir/python/mlir/dialects/transform/memref.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._memref_transform_ops_gen import * From 4a4d5e98de51658abbd0cd7cd1e48e1071eaa45b Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 28 Jul 2023 16:03:33 +0000 Subject: [PATCH 523/915] [mlir] python enum bindings generator Add an ODS (tablegen) backend to generate Python enum classes and attribute builders for enum attributes defined in ODS. This will allow us to keep the enum attribute definitions in sync between C++ and Python, as opposed to handwritten enum classes in Python that may end up using mismatching values. This also makes autogenerated bindings more convenient even in absence of mixins. Use this backend for the transform dialect failure propagation mode enum attribute as demonstration. 
Reviewed By: ingomueller-net Differential Revision: https://reviews.llvm.org/D156553 --- mlir/python/CMakeLists.txt | 9 ++ .../mlir/dialects/_transform_ops_ext.py | 116 ++++++++---------- .../mlir/dialects/transform/__init__.py | 18 +-- 3 files changed, 64 insertions(+), 79 deletions(-) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index d233194b1..3263bc1db 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -134,6 +134,15 @@ declare_mlir_dialect_python_bindings( _mlir_libs/_mlir/dialects/transform/__init__.pyi DIALECT_NAME transform) +set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/TransformOps.td") +mlir_tablegen("dialects/_transform_enum_gen.py" -gen-python-enum-bindings) +add_public_tablegen_target(MLIRTransformDialectPyEnumGen) +declare_mlir_python_sources( + MLIRPythonSources.Dialects.transform.enum_gen + ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" + ADD_TO_PARENT MLIRPythonSources.Dialects.transform + SOURCES "dialects/_transform_enum_gen.py") + declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py index 0db2e3bd9..b1e7b8925 100644 --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -15,68 +15,66 @@ class CastOp: - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - loc=None, - ip=None, - ): - super().__init__( - result_type, _get_op_result_or_value(target), loc=loc, ip=ip - ) + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + loc=None, + ip=None, + ): + super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip) class ApplyPatternsOp: + def __init__( + self, + target: Union[Operation, Value, OpView], + *, + loc=None, + ip=None, + ): + operands = [] + 
operands.append(_get_op_result_or_value(target)) + super().__init__( + self.build_generic( + attributes={}, + results=[], + operands=operands, + successors=None, + regions=None, + loc=loc, + ip=ip, + ) + ) + self.regions[0].blocks.append() - def __init__( - self, - target: Union[Operation, Value, OpView], - *, - loc=None, - ip=None, - ): - operands = [] - operands.append(_get_op_result_or_value(target)) - super().__init__( - self.build_generic(attributes={}, - results=[], - operands=operands, - successors=None, - regions=None, - loc=loc, - ip=ip)) - self.regions[0].blocks.append() - - @property - def patterns(self) -> Block: - return self.regions[0].blocks[0] + @property + def patterns(self) -> Block: + return self.regions[0].blocks[0] class testGetParentOp: - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - isolated_from_above: bool = False, - op_name: Optional[str] = None, - deduplicate: bool = False, - loc=None, - ip=None, - ): - super().__init__( - result_type, - _get_op_result_or_value(target), - isolated_from_above=isolated_from_above, - op_name=op_name, - deduplicate=deduplicate, - loc=loc, - ip=ip, - ) + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + isolated_from_above: bool = False, + op_name: Optional[str] = None, + deduplicate: bool = False, + loc=None, + ip=None, + ): + super().__init__( + result_type, + _get_op_result_or_value(target), + isolated_from_above=isolated_from_above, + op_name=op_name, + deduplicate=deduplicate, + loc=loc, + ip=ip, + ) class MergeHandlesOp: @@ -130,12 +128,6 @@ def __init__( else None ) root_type = root.type if not isinstance(target, Type) else target - if not isinstance(failure_propagation_mode, Attribute): - failure_propagation_mode_attr = IntegerAttr.get( - IntegerType.get_signless(32), failure_propagation_mode._as_int() - ) - else: - failure_propagation_mode_attr = failure_propagation_mode if extra_bindings is None: extra_bindings = [] @@ -152,7 
+144,7 @@ def __init__( super().__init__( results_=results, - failure_propagation_mode=failure_propagation_mode_attr, + failure_propagation_mode=failure_propagation_mode, root=root, extra_bindings=extra_bindings, ) diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index b505a490a..b020ad35f 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -2,22 +2,6 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from enum import Enum - - -class FailurePropagationMode(Enum): - """Propagation mode for silenceable errors.""" - - PROPAGATE = 1 - SUPPRESS = 2 - - def _as_int(self): - if self is FailurePropagationMode.PROPAGATE: - return 1 - - assert self is FailurePropagationMode.SUPPRESS - return 2 - - +from .._transform_enum_gen import * from .._transform_ops_gen import * from ..._mlir_libs._mlirDialectsTransform import * From d5078cc941c73fd4d1bd072fa40238a705b4dd4d Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Fri, 28 Jul 2023 16:06:31 +0000 Subject: [PATCH 524/915] [mlir] python bindings for vector transform ops Provide Python bindings for transform ops defined in the vector dialect. All of these ops are sufficiently simple that no mixins are necessary for them to be nicely usable. 
Reviewed By: ingomueller-net Differential Revision: https://reviews.llvm.org/D156554 --- mlir/python/CMakeLists.txt | 18 ++++++++++++++++++ .../mlir/dialects/VectorTransformOps.td | 19 +++++++++++++++++++ mlir/python/mlir/dialects/transform/vector.py | 6 ++++++ 3 files changed, 43 insertions(+) create mode 100644 mlir/python/mlir/dialects/VectorTransformOps.td create mode 100644 mlir/python/mlir/dialects/transform/vector.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 3263bc1db..d9c1a98bc 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -192,6 +192,24 @@ declare_mlir_dialect_extension_python_bindings( DIALECT_NAME transform EXTENSION_NAME structured_transform) +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/VectorTransformOps.td + SOURCES + dialects/transform/vector.py + DIALECT_NAME transform + EXTENSION_NAME vector_transform) + +set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/VectorTransformOps.td") +mlir_tablegen("dialects/_vector_transform_enum_gen.py" -gen-python-enum-bindings) +add_public_tablegen_target(MLIRVectorTransformPyEnumGen) +declare_mlir_python_sources( + MLIRPythonSources.Dialects.vector_transform.enum_gen + ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" + ADD_TO_PARENT MLIRPythonSources.Dialects.vector_transform + SOURCES "dialects/_vector_transform_enum_gen.py" ) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/VectorTransformOps.td b/mlir/python/mlir/dialects/VectorTransformOps.td new file mode 100644 index 000000000..42aa8c006 --- /dev/null +++ b/mlir/python/mlir/dialects/VectorTransformOps.td @@ -0,0 +1,19 @@ +//===-- VectorTransformOps.td ------------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 
with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the vector transform ops. +// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_VECTORTRANSFORMOPS +#define PYTHON_BINDINGS_VECTORTRANSFORMOPS + +include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.td" + +#endif // PYTHON_BINDINGS_VECTORTRANSFORMOPS diff --git a/mlir/python/mlir/dialects/transform/vector.py b/mlir/python/mlir/dialects/transform/vector.py new file mode 100644 index 000000000..af2435cb2 --- /dev/null +++ b/mlir/python/mlir/dialects/transform/vector.py @@ -0,0 +1,6 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._vector_transform_enum_gen import * +from .._vector_transform_ops_gen import * From 4aeced8920b81c267e648f97867fdcddb22ff7c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Fri, 28 Jul 2023 17:20:06 +0000 Subject: [PATCH 525/915] [mlir][memref][transform][python] Create mix-in for MemRefMultiBufferOp. Create a mix-in class with an overloaded constructor that makes the return type optional. 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D156561 --- mlir/python/CMakeLists.txt | 1 + .../dialects/_memref_transform_ops_ext.py | 68 +++++++++++++++++++ 2 files changed, 69 insertions(+) create mode 100644 mlir/python/mlir/dialects/_memref_transform_ops_ext.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index d9c1a98bc..a2aa493e2 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -178,6 +178,7 @@ declare_mlir_dialect_extension_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/MemRefTransformOps.td SOURCES + dialects/_memref_transform_ops_ext.py dialects/transform/memref.py DIALECT_NAME transform EXTENSION_NAME memref_transform) diff --git a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py new file mode 100644 index 000000000..4afe8e7b8 --- /dev/null +++ b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py @@ -0,0 +1,68 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * + from ..dialects import transform +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, overload, Union + + +class MemRefMultiBufferOp: + """Specialization for MemRefMultiBufferOp class.""" + + @overload + def __init__( + self, + transformed_type: Type, + target: Union[Operation, OpView, Value], + factor: Union[int, IntegerAttr], + *, + skip_analysis: Optional[bool] = None, + loc=None, + ip=None + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + factor: Union[int, IntegerAttr], + *, + skip_analysis: Optional[bool] = None, + loc=None, + ip=None + ): + ... 
+ + def __init__( + self, + transformed_type_or_target: Type, + target_or_factor: Union[int, IntegerAttr, Operation, OpView, Value] = None, + factor_or_none: Optional[Union[int, IntegerAttr]] = None, + *, + skip_analysis: Optional[bool] = None, + loc=None, + ip=None + ): + if isinstance(transformed_type_or_target, Type): + transformed_type = transformed_type_or_target + target = target_or_factor + factor = factor_or_none + else: + transformed_type = transform.AnyOpType.get() + target = transformed_type_or_target + factor = target_or_factor + + super().__init__( + transformed_type, + target, + factor, + skip_analysis=skip_analysis, + loc=loc, + ip=ip, + ) From a290c1929f3e162d404b6054bf86a511560fb65a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 31 Jul 2023 10:24:45 +0000 Subject: [PATCH 526/915] [mlir][bufferization][transform][python] Add enums to bindings & mixins. This patch uses the new enum binding generation to add the enums of the dialect to the Python bindings and uses them in the mix-in class where it was still missing (namely, the `LayoutMapOption` for the `function_boundary_type_conversion` of the `OneShotBufferizeOp`. The patch also piggy-backs a few smaller clean-ups: * Order the keyword-only arguments alphabetically. * Add the keyword-only arguments to an overload where they were left out by accident. * Change some of the attribute values used in the tests to non-default values such that they show up in the output IR and check for that output. 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D156664 --- mlir/python/CMakeLists.txt | 9 ++++++ .../_bufferization_transform_ops_ext.py | 32 +++++++++++++++---- .../mlir/dialects/transform/bufferization.py | 1 + 3 files changed, 35 insertions(+), 7 deletions(-) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index a2aa493e2..656e3f895 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -153,6 +153,15 @@ declare_mlir_dialect_extension_python_bindings( DIALECT_NAME transform EXTENSION_NAME bufferization_transform) +set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/BufferizationTransformOps.td") +mlir_tablegen("dialects/_bufferization_transform_enum_gen.py" -gen-python-enum-bindings) +add_public_tablegen_target(MLIRBufferizationTransformDialectPyEnumGen) +declare_mlir_python_sources( + MLIRPythonSources.Dialects.bufferization_transform.enum_gen + ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" + ADD_TO_PARENT MLIRPythonSources.Dialects.bufferization_transform + SOURCES "dialects/_bufferization_transform_enum_gen.py") + declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py index 77f4d1e16..ead337282 100644 --- a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py @@ -8,6 +8,7 @@ except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e +from enum import Enum from typing import Optional, overload, Union @@ -65,16 +66,31 @@ def __init__( allow_unknown_ops: Optional[bool] = None, bufferize_function_boundaries: Optional[bool] = None, create_deallocs: Optional[bool] = None, - test_analysis_only: Optional[bool] = None, - print_conflicts: Optional[bool] = None, + 
function_boundary_type_conversion: Optional[Enum] = None, memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, loc=None, ip=None ): ... @overload - def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None): + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + allow_return_allocs: Optional[bool] = None, + allow_unknown_ops: Optional[bool] = None, + bufferize_function_boundaries: Optional[bool] = None, + create_deallocs: Optional[bool] = None, + function_boundary_type_conversion: Optional[Enum] = None, + memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, + loc=None, + ip=None + ): ... def __init__( @@ -86,9 +102,10 @@ def __init__( allow_unknown_ops: Optional[bool] = None, bufferize_function_boundaries: Optional[bool] = None, create_deallocs: Optional[bool] = None, - test_analysis_only: Optional[bool] = None, - print_conflicts: Optional[bool] = None, + function_boundary_type_conversion: Optional[Enum] = None, memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, loc=None, ip=None ): @@ -106,9 +123,10 @@ def __init__( allow_unknown_ops=allow_unknown_ops, bufferize_function_boundaries=bufferize_function_boundaries, create_deallocs=create_deallocs, - test_analysis_only=test_analysis_only, - print_conflicts=print_conflicts, + function_boundary_type_conversion=function_boundary_type_conversion, memcpy_op=memcpy_op, + print_conflicts=print_conflicts, + test_analysis_only=test_analysis_only, loc=loc, ip=ip, ) diff --git a/mlir/python/mlir/dialects/transform/bufferization.py b/mlir/python/mlir/dialects/transform/bufferization.py index eb77b746c..1891bc0e1 100644 --- a/mlir/python/mlir/dialects/transform/bufferization.py +++ b/mlir/python/mlir/dialects/transform/bufferization.py @@ -2,4 +2,5 @@ # See https://llvm.org/LICENSE.txt for 
license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from .._bufferization_transform_enum_gen import * from .._bufferization_transform_ops_gen import * From 3aaee6b613f800ad879ce64ef661cdd716ac130a Mon Sep 17 00:00:00 2001 From: Jerry Wu Date: Fri, 28 Jul 2023 21:46:58 +0000 Subject: [PATCH 527/915] [mlir] Add linalg.batch_mmt4d named op This op is the batched version of linalg.mmt4d. It performs matrix-matrix-transpose multiplication of batched 4-d (5d) inputs as the following: ``` C[b, m1, n1, m0, n0] = sum_{b, k1, k0}(A[b, m1, k1, m0, k0] * B[b, n1, k1, n0, k0]) ``` The current use is to provide `linalg.batch_matmul` a lowering path similar to `linalg.matmul -> linalg.mmt4d`. Differential Revision: https://reviews.llvm.org/D156912 --- .../linalg/opdsl/ops/core_named_ops.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 08818b212..dee0c3e3f 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -350,6 +350,27 @@ def mmt4d( ) * TypeFn.cast_signed(TV.AccumType, rhs[D.n, D.k, D.n0, D.k0]) +@linalg_structured_op +def batch_mmt4d( + lhs=TensorDef(TV.LhsType, Batch, S.M, S.K, S.M0, S.K0), + rhs=TensorDef(TV.RhsType, Batch, S.N, S.K, S.N0, S.K0), + accum=TensorDef(TV.AccumType, Batch, S.M, S.N, S.M0, S.N0, output=True), +): + """Performs a batched matrix-matrix-transpose multiplication of two + batched-4D (5D) inputs. + + Besides the outermost batch dimension having the same semantic as + linalg.batch_matmul, the differences from linalg.batch_matmul in the + non-batch dimensions are the same as linalg.mmt4d vs. linalg.matmul. See the + description of linalg.mmt4d. 
+ """ + domain(D.b, D.m, D.n, D.k, D.m0, D.n0, D.k0) + implements(ContractionOpInterface) + accum[D.b, D.m, D.n, D.m0, D.n0] += TypeFn.cast_signed( + TV.AccumType, lhs[D.b, D.m, D.k, D.m0, D.k0] + ) * TypeFn.cast_signed(TV.AccumType, rhs[D.b, D.n, D.k, D.n0, D.k0]) + + @linalg_structured_op def batch_matmul( A=TensorDef(T1, Batch, S.M, S.K), From ee245d4ecf383a3f1f8401276b5e0787d298ed6b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 2 Aug 2023 16:58:44 +0000 Subject: [PATCH 528/915] [mlir][transform][tensor][python] Add .td files for bindings. Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D156914 --- mlir/python/CMakeLists.txt | 9 +++++++++ .../mlir/dialects/TensorTransformOps.td | 20 +++++++++++++++++++ mlir/python/mlir/dialects/transform/tensor.py | 5 +++++ 3 files changed, 34 insertions(+) create mode 100644 mlir/python/mlir/dialects/TensorTransformOps.td create mode 100644 mlir/python/mlir/dialects/transform/tensor.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 656e3f895..1c2351736 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -202,6 +202,15 @@ declare_mlir_dialect_extension_python_bindings( DIALECT_NAME transform EXTENSION_NAME structured_transform) +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/TensorTransformOps.td + SOURCES + dialects/transform/tensor.py + DIALECT_NAME transform + EXTENSION_NAME tensor_transform) + declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/TensorTransformOps.td b/mlir/python/mlir/dialects/TensorTransformOps.td new file mode 100644 index 000000000..87c5c7f39 --- /dev/null +++ b/mlir/python/mlir/dialects/TensorTransformOps.td @@ -0,0 +1,20 @@ +//===-- TensorTransformOps.td 
------------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the transform ops provided +// by the tensor dialect. +// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_TENSOR_TRANSFORM_OPS +#define PYTHON_BINDINGS_TENSOR_TRANSFORM_OPS + +include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.td" + +#endif // PYTHON_BINDINGS_TENSOR_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/transform/tensor.py b/mlir/python/mlir/dialects/transform/tensor.py new file mode 100644 index 000000000..bf52255b3 --- /dev/null +++ b/mlir/python/mlir/dialects/transform/tensor.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._tensor_transform_ops_gen import * From fcfb017889236a5734eae2693ecd5bee78b16ef8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 2 Aug 2023 17:24:53 +0000 Subject: [PATCH 529/915] [mlir][tensor][transform][python] Add mix-in class. This patch adds a mix-in class for the only transform op of the tensor dialect that can benefit from one: the MakeLoopIndependentOp. It adds an overload that makes providing the return type optional. 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D156918 --- mlir/python/CMakeLists.txt | 1 + .../dialects/_tensor_transform_ops_ext.py | 64 +++++++++++++++++++ 2 files changed, 65 insertions(+) create mode 100644 mlir/python/mlir/dialects/_tensor_transform_ops_ext.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 1c2351736..0fae5dbb8 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -207,6 +207,7 @@ declare_mlir_dialect_extension_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/TensorTransformOps.td SOURCES + dialects/_tensor_transform_ops_ext.py dialects/transform/tensor.py DIALECT_NAME transform EXTENSION_NAME tensor_transform) diff --git a/mlir/python/mlir/dialects/_tensor_transform_ops_ext.py b/mlir/python/mlir/dialects/_tensor_transform_ops_ext.py new file mode 100644 index 000000000..996093fbc --- /dev/null +++ b/mlir/python/mlir/dialects/_tensor_transform_ops_ext.py @@ -0,0 +1,64 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * + from ..dialects import transform +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, overload, Union + + +class MakeLoopIndependentOp: + """Specialization for MakeLoopIndependentOp class.""" + + @overload + def __init__( + self, + transformed_type: Type, + target: Union[Operation, OpView, Value], + num_loops: Union[int, IntegerAttr], + *, + loc=None, + ip=None + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + num_loops: Union[int, IntegerAttr], + *, + loc=None, + ip=None + ): + ... 
+ + def __init__( + self, + transformed_type_or_target: Type, + target_or_num_loops: Union[int, IntegerAttr, Operation, OpView, Value] = None, + num_loops_or_none: Optional[Union[int, IntegerAttr]] = None, + *, + loc=None, + ip=None + ): + if isinstance(transformed_type_or_target, Type): + transformed_type = transformed_type_or_target + target = target_or_num_loops + num_loops = num_loops_or_none + else: + transformed_type = transform.AnyOpType.get() + target = transformed_type_or_target + num_loops = target_or_num_loops + + super().__init__( + transformed_type, + target, + num_loops, + loc=loc, + ip=ip, + ) From c73c2474dd3cf71b13abca266bd54e4aa382c80b Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Wed, 9 Aug 2023 12:47:13 -0700 Subject: [PATCH 530/915] Finish renaming getOperandSegmentSizeAttr() from `operand_segment_sizes` to `operandSegmentSizes` This renaming started with the native ODS support for properties, this is completing it. A mass automated textual rename seems safe for most codebases. Drop also the ods prefix to keep the accessors the same as they were before this change: properties.odsOperandSegmentSizes reverts back to: properties.operandSegementSizes The ODS prefix was creating divergence between all the places and make it harder to be consistent. 
Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D157173 --- mlir/lib/Bindings/Python/IRCore.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 6b9de9a8c..d348175b5 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1675,28 +1675,28 @@ py::object PyOpView::buildGeneric( } else { attributes = py::dict(); } - if (attributes->contains("result_segment_sizes") || - attributes->contains("operand_segment_sizes")) { - throw py::value_error("Manually setting a 'result_segment_sizes' or " - "'operand_segment_sizes' attribute is unsupported. " + if (attributes->contains("resultSegmentSizes") || + attributes->contains("operandSegmentSizes")) { + throw py::value_error("Manually setting a 'resultSegmentSizes' or " + "'operandSegmentSizes' attribute is unsupported. " "Use Operation.create for such low-level access."); } - // Add result_segment_sizes attribute. + // Add resultSegmentSizes attribute. if (!resultSegmentLengths.empty()) { MlirAttribute segmentLengthAttr = mlirDenseI32ArrayGet(context->get(), resultSegmentLengths.size(), resultSegmentLengths.data()); - (*attributes)["result_segment_sizes"] = + (*attributes)["resultSegmentSizes"] = PyAttribute(context, segmentLengthAttr); } - // Add operand_segment_sizes attribute. + // Add operandSegmentSizes attribute. if (!operandSegmentLengths.empty()) { MlirAttribute segmentLengthAttr = mlirDenseI32ArrayGet(context->get(), operandSegmentLengths.size(), operandSegmentLengths.data()); - (*attributes)["operand_segment_sizes"] = + (*attributes)["operandSegmentSizes"] = PyAttribute(context, segmentLengthAttr); } } From 1d64b2bec306d84161b4b2603d01f41512d5321c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Thu, 10 Aug 2023 18:17:35 +0000 Subject: [PATCH 531/915] [mlir][transform][python] Add AnyValueType to bindings. 
This patch adds the MLIR C bindings and the corresponding Python bindings of the AnyValueType of the transform dialect. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D157638 --- mlir/include/mlir-c/Dialect/Transform.h | 8 ++++++++ mlir/lib/Bindings/Python/DialectTransform.cpp | 14 ++++++++++++++ mlir/lib/CAPI/Dialect/Transform.cpp | 12 ++++++++++++ 3 files changed, 34 insertions(+) diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h index 0409890b2..954575925 100644 --- a/mlir/include/mlir-c/Dialect/Transform.h +++ b/mlir/include/mlir-c/Dialect/Transform.h @@ -27,6 +27,14 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type); MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx); +//===---------------------------------------------------------------------===// +// AnyValueType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyValueType(MlirType type); + +MLIR_CAPI_EXPORTED MlirType mlirTransformAnyValueTypeGet(MlirContext ctx); + //===---------------------------------------------------------------------===// // OperationType //===---------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp index e4b8cee73..932e40220 100644 --- a/mlir/lib/Bindings/Python/DialectTransform.cpp +++ b/mlir/lib/Bindings/Python/DialectTransform.cpp @@ -31,6 +31,20 @@ void populateDialectTransformSubmodule(const pybind11::module &m) { "Get an instance of AnyOpType in the given context.", py::arg("cls"), py::arg("context") = py::none()); + //===-------------------------------------------------------------------===// + // AnyValueType + //===-------------------------------------------------------------------===// + + auto anyValueType = + mlir_type_subclass(m, "AnyValueType", 
mlirTypeIsATransformAnyValueType); + anyValueType.def_classmethod( + "get", + [](py::object cls, MlirContext ctx) { + return cls(mlirTransformAnyValueTypeGet(ctx)); + }, + "Get an instance of AnyValueType in the given context.", py::arg("cls"), + py::arg("context") = py::none()); + //===-------------------------------------------------------------------===// // OperationType //===-------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp index d3cd4e3d0..5841f6783 100644 --- a/mlir/lib/CAPI/Dialect/Transform.cpp +++ b/mlir/lib/CAPI/Dialect/Transform.cpp @@ -29,6 +29,18 @@ MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) { return wrap(transform::AnyOpType::get(unwrap(ctx))); } +//===---------------------------------------------------------------------===// +// AnyValueType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsATransformAnyValueType(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirTransformAnyValueTypeGet(MlirContext ctx) { + return wrap(transform::AnyValueType::get(unwrap(ctx))); +} + //===---------------------------------------------------------------------===// // OperationType //===---------------------------------------------------------------------===// From 7ec76431ec5fe2b104ec1720d738b9852eac6a28 Mon Sep 17 00:00:00 2001 From: max Date: Fri, 11 Aug 2023 20:39:56 -0500 Subject: [PATCH 532/915] add `owner` to OpResultsList. this is useful for when the list is empty and an element can't be used to fetch the owner. 
Differential Revision: https://reviews.llvm.org/D157769 --- mlir/lib/Bindings/Python/IRCore.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index d348175b5..e1b8d296a 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2277,6 +2277,9 @@ class PyOpResultList : public Sliceable { c.def_property_readonly("types", [](PyOpResultList &self) { return getValueTypes(self, self.operation->getContext()); }); + c.def_property_readonly("owner", [](PyOpResultList &self) { + return self.operation->createOpView(); + }); } private: From e62d6138672c7a40c270c06c9d046aa999be508a Mon Sep 17 00:00:00 2001 From: max Date: Sun, 13 Aug 2023 13:30:28 -0500 Subject: [PATCH 533/915] [MLIR][python bindings] add vendor gpu dialects Differential Revision: https://reviews.llvm.org/D157820 --- mlir/include/mlir-c/Dialect/AMDGPU.h | 25 +++++++++++++++++ mlir/include/mlir-c/Dialect/NVGPU.h | 25 +++++++++++++++++ mlir/include/mlir-c/Dialect/NVVM.h | 25 +++++++++++++++++ mlir/include/mlir-c/Dialect/ROCDL.h | 25 +++++++++++++++++ mlir/lib/CAPI/Dialect/AMDGPU.cpp | 14 ++++++++++ mlir/lib/CAPI/Dialect/CMakeLists.txt | 37 ++++++++++++++++++++++++++ mlir/lib/CAPI/Dialect/NVGPU.cpp | 13 +++++++++ mlir/lib/CAPI/Dialect/NVVM.cpp | 13 +++++++++ mlir/lib/CAPI/Dialect/ROCDL.cpp | 13 +++++++++ mlir/python/CMakeLists.txt | 32 ++++++++++++++++++++++ mlir/python/mlir/dialects/AMDGPUOps.td | 14 ++++++++++ mlir/python/mlir/dialects/NVGPUOps.td | 14 ++++++++++ mlir/python/mlir/dialects/NVVMOps.td | 14 ++++++++++ mlir/python/mlir/dialects/ROCDLOps.td | 14 ++++++++++ mlir/python/mlir/dialects/amdgpu.py | 5 ++++ mlir/python/mlir/dialects/nvgpu.py | 5 ++++ mlir/python/mlir/dialects/nvvm.py | 5 ++++ mlir/python/mlir/dialects/rocdl.py | 5 ++++ 18 files changed, 298 insertions(+) create mode 100644 mlir/include/mlir-c/Dialect/AMDGPU.h create mode 100644 mlir/include/mlir-c/Dialect/NVGPU.h create mode 
100644 mlir/include/mlir-c/Dialect/NVVM.h create mode 100644 mlir/include/mlir-c/Dialect/ROCDL.h create mode 100644 mlir/lib/CAPI/Dialect/AMDGPU.cpp create mode 100644 mlir/lib/CAPI/Dialect/NVGPU.cpp create mode 100644 mlir/lib/CAPI/Dialect/NVVM.cpp create mode 100644 mlir/lib/CAPI/Dialect/ROCDL.cpp create mode 100644 mlir/python/mlir/dialects/AMDGPUOps.td create mode 100644 mlir/python/mlir/dialects/NVGPUOps.td create mode 100644 mlir/python/mlir/dialects/NVVMOps.td create mode 100644 mlir/python/mlir/dialects/ROCDLOps.td create mode 100644 mlir/python/mlir/dialects/amdgpu.py create mode 100644 mlir/python/mlir/dialects/nvgpu.py create mode 100644 mlir/python/mlir/dialects/nvvm.py create mode 100644 mlir/python/mlir/dialects/rocdl.py diff --git a/mlir/include/mlir-c/Dialect/AMDGPU.h b/mlir/include/mlir-c/Dialect/AMDGPU.h new file mode 100644 index 000000000..142044f7f --- /dev/null +++ b/mlir/include/mlir-c/Dialect/AMDGPU.h @@ -0,0 +1,25 @@ +//===-- mlir-c/Dialect/AMDGPU.h - C API for AMDGPU dialect --*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_AMDGPU_H +#define MLIR_C_DIALECT_AMDGPU_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(AMDGPU, amdgpu); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_AMDGPU_H diff --git a/mlir/include/mlir-c/Dialect/NVGPU.h b/mlir/include/mlir-c/Dialect/NVGPU.h new file mode 100644 index 000000000..580d56679 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/NVGPU.h @@ -0,0 +1,25 @@ +//===-- mlir-c/Dialect/NVGPU.h - C API for NVGPU dialect --*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_NVGPU_H +#define MLIR_C_DIALECT_NVGPU_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(NVGPU, nvgpu); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_NVGPU_H diff --git a/mlir/include/mlir-c/Dialect/NVVM.h b/mlir/include/mlir-c/Dialect/NVVM.h new file mode 100644 index 000000000..cf5d9301d --- /dev/null +++ b/mlir/include/mlir-c/Dialect/NVVM.h @@ -0,0 +1,25 @@ +//===-- mlir-c/Dialect/NVVM.h - C API for NVVM dialect --*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_NVVM_H +#define MLIR_C_DIALECT_NVVM_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(NVVM, nvvm); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_NVVM_H diff --git a/mlir/include/mlir-c/Dialect/ROCDL.h b/mlir/include/mlir-c/Dialect/ROCDL.h new file mode 100644 index 000000000..e5dbb55b5 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/ROCDL.h @@ -0,0 +1,25 @@ +//===-- mlir-c/Dialect/ROCDL.h - C API for ROCDL dialect --*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_ROCDL_H +#define MLIR_C_DIALECT_ROCDL_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(ROCDL, rocdl); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_ROCDL_H diff --git a/mlir/lib/CAPI/Dialect/AMDGPU.cpp b/mlir/lib/CAPI/Dialect/AMDGPU.cpp new file mode 100644 index 000000000..28efe6025 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/AMDGPU.cpp @@ -0,0 +1,14 @@ +//===- AMDGPU.cpp - C Interface for AMDGPU dialect ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/AMDGPU.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(AMDGPU, amdgpu, + mlir::amdgpu::AMDGPUDialect) diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index 4b4ab74e6..13e57a29d 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -1,3 +1,12 @@ +add_mlir_upstream_c_api_library(MLIRCAPIAMDGPU + AMDGPU.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRAMDGPUDialect +) + add_mlir_upstream_c_api_library(MLIRCAPIArith Arith.cpp @@ -96,6 +105,34 @@ add_mlir_upstream_c_api_library(MLIRCAPIMLProgram MLIRMLProgramDialect ) +add_mlir_upstream_c_api_library(MLIRCAPINVGPU + NVGPU.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRNVGPUDialect +) + +add_mlir_upstream_c_api_library(MLIRCAPINVVM + NVVM.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRNVVMDialect +) 
+ +add_mlir_upstream_c_api_library(MLIRCAPIROCDL + ROCDL.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRROCDLDialect +) + + add_mlir_upstream_c_api_library(MLIRCAPISCF SCF.cpp diff --git a/mlir/lib/CAPI/Dialect/NVGPU.cpp b/mlir/lib/CAPI/Dialect/NVGPU.cpp new file mode 100644 index 000000000..02d10954a --- /dev/null +++ b/mlir/lib/CAPI/Dialect/NVGPU.cpp @@ -0,0 +1,13 @@ +//===- NVGPU.cpp - C Interface for NVGPU dialect ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/NVGPU.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(NVGPU, nvgpu, mlir::nvgpu::NVGPUDialect) diff --git a/mlir/lib/CAPI/Dialect/NVVM.cpp b/mlir/lib/CAPI/Dialect/NVVM.cpp new file mode 100644 index 000000000..a87581664 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/NVVM.cpp @@ -0,0 +1,13 @@ +//===- NVVM.cpp - C Interface for NVVM dialect ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/NVVM.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(NVVM, nvvm, mlir::NVVM::NVVMDialect) diff --git a/mlir/lib/CAPI/Dialect/ROCDL.cpp b/mlir/lib/CAPI/Dialect/ROCDL.cpp new file mode 100644 index 000000000..63e2fa881 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/ROCDL.cpp @@ -0,0 +1,13 @@ +//===- ROCDL.cpp - C Interface for ROCDL dialect ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/ROCDL.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(ROCDL, rocdl, mlir::ROCDL::ROCDLDialect) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 0fae5dbb8..05d09eaf7 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -46,6 +46,14 @@ declare_mlir_python_sources(MLIRPythonCAPI.HeaderSources # Dialect bindings ################################################################################ +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/AMDGPUOps.td + SOURCES + dialects/amdgpu.py + DIALECT_NAME amdgpu) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" @@ -264,6 +272,30 @@ declare_mlir_dialect_python_bindings( dialects/_ml_program_ops_ext.py DIALECT_NAME ml_program) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + 
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/NVGPUOps.td + SOURCES + dialects/nvgpu.py + DIALECT_NAME nvgpu) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/NVVMOps.td + SOURCES + dialects/nvvm.py + DIALECT_NAME nvvm) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/ROCDLOps.td + SOURCES + dialects/rocdl.py + DIALECT_NAME rocdl) + declare_mlir_python_sources( MLIRPythonSources.Dialects.quant ADD_TO_PARENT MLIRPythonSources.Dialects diff --git a/mlir/python/mlir/dialects/AMDGPUOps.td b/mlir/python/mlir/dialects/AMDGPUOps.td new file mode 100644 index 000000000..fe9371971 --- /dev/null +++ b/mlir/python/mlir/dialects/AMDGPUOps.td @@ -0,0 +1,14 @@ +//===-- AMDGPUOps.td - Entry point for AMDGPUOps -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_AMDGPU_OPS +#define PYTHON_BINDINGS_AMDGPU_OPS + +include "mlir/Dialect/AMDGPU/IR/AMDGPU.td" + +#endif diff --git a/mlir/python/mlir/dialects/NVGPUOps.td b/mlir/python/mlir/dialects/NVGPUOps.td new file mode 100644 index 000000000..ae54822cd --- /dev/null +++ b/mlir/python/mlir/dialects/NVGPUOps.td @@ -0,0 +1,14 @@ +//===-- NVGPUOps.td - Entry point for NVGPUOps -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_NVGPU_OPS +#define PYTHON_BINDINGS_NVGPU_OPS + +include "mlir/Dialect/NVGPU/IR/NVGPU.td" + +#endif diff --git a/mlir/python/mlir/dialects/NVVMOps.td b/mlir/python/mlir/dialects/NVVMOps.td new file mode 100644 index 000000000..f57d204a8 --- /dev/null +++ b/mlir/python/mlir/dialects/NVVMOps.td @@ -0,0 +1,14 @@ +//===-- NVVMOps.td - Entry point for NVVMOps -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_NVVM_OPS +#define PYTHON_BINDINGS_NVVM_OPS + +include "mlir/Dialect/LLVMIR/NVVMOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/ROCDLOps.td b/mlir/python/mlir/dialects/ROCDLOps.td new file mode 100644 index 000000000..fa5c9ebc3 --- /dev/null +++ b/mlir/python/mlir/dialects/ROCDLOps.td @@ -0,0 +1,14 @@ +//===-- ROCDLOps.td - Entry point for ROCDLOps -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_ROCDL_OPS +#define PYTHON_BINDINGS_ROCDL_OPS + +include "mlir/Dialect/LLVMIR/ROCDLOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/amdgpu.py b/mlir/python/mlir/dialects/amdgpu.py new file mode 100644 index 000000000..35283278e --- /dev/null +++ b/mlir/python/mlir/dialects/amdgpu.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._amdgpu_ops_gen import * diff --git a/mlir/python/mlir/dialects/nvgpu.py b/mlir/python/mlir/dialects/nvgpu.py new file mode 100644 index 000000000..afd570cae --- /dev/null +++ b/mlir/python/mlir/dialects/nvgpu.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._nvgpu_ops_gen import * diff --git a/mlir/python/mlir/dialects/nvvm.py b/mlir/python/mlir/dialects/nvvm.py new file mode 100644 index 000000000..87b2a4fd6 --- /dev/null +++ b/mlir/python/mlir/dialects/nvvm.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._nvvm_ops_gen import * diff --git a/mlir/python/mlir/dialects/rocdl.py b/mlir/python/mlir/dialects/rocdl.py new file mode 100644 index 000000000..aa47cb4b5 --- /dev/null +++ b/mlir/python/mlir/dialects/rocdl.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._rocdl_ops_gen import * From 4e2f33be5ab2e33de2165d2cc97cc561426cc3d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Fri, 11 Aug 2023 12:11:24 +0000 Subject: [PATCH 534/915] [mlir][linalg][transform][python] Add mix-in for MapCopyToThreadsOp. 
Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D157706 --- .../dialects/_structured_transform_ops_ext.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 9f623efb5..675d42370 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -187,6 +187,66 @@ def __init__( ) +class MapCopyToThreadsOp: + """Specialization for MapCopyToThreadsOp class.""" + + @overload + def __init__( + self, + forall_op_type: Type, + tiled_op_type: Type, + target: Union[Operation, OpView, Value], + *, + total_num_threads: Union[int, IntegerAttr], + desired_bit_alignment: Union[int, IntegerAttr], + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + total_num_threads: Union[int, IntegerAttr], + desired_bit_alignment: Union[int, IntegerAttr], + loc=None, + ip=None, + ): + ... 
+ + def __init__( + self, + forall_op_type_or_target: Union[Operation, OpView, Type, Value], + tiled_op_type_or_none: Optional[Type] = None, + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + total_num_threads: Union[int, IntegerAttr], + desired_bit_alignment: Union[int, IntegerAttr], + loc=None, + ip=None, + ): + if isinstance(forall_op_type_or_target, Type): + forall_op_type = forall_op_type_or_target + tiled_op_type = tiled_op_type_or_none + target = target_or_none + else: + forall_op_type = transform.AnyOpType.get() + tiled_op_type = transform.AnyOpType.get() + target = forall_op_type_or_target + + super().__init__( + forall_op_type, + tiled_op_type, + target, + total_num_threads=total_num_threads, + desired_bit_alignment=desired_bit_alignment, + loc=loc, + ip=ip, + ) + + class MatchOp: """Specialization for MatchOp class.""" From 05affab0663cbcc09f5b1f8882fdc63e1b47c2ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Fri, 11 Aug 2023 11:00:23 +0000 Subject: [PATCH 535/915] [mlir][linalg][transform][python] Add mix-in for BufferizeToAllocOp. 
Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D157704 --- .../dialects/_structured_transform_ops_ext.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 675d42370..d21a1b434 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -84,6 +84,40 @@ def _get_int_int_array_attr( return ArrayAttr.get(values) +class BufferizeToAllocationOp: + """Specialization for BufferizeToAllocationOp class.""" + + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + memory_space: Optional[int | str | Attribute] = None, + memcpy_op: Optional[str] = None, + alloc_op: Optional[str] = None, + bufferize_destination_only: Optional[bool] = None, + loc=None, + ip=None, + ): + # No other types are allowed, so hard-code those here. + allocated_buffer_type = transform.AnyValueType.get() + new_ops_type = transform.AnyOpType.get() + + if isinstance(memory_space, int): + memory_space = str(memory_space) + if isinstance(memory_space, str): + memory_space = Attribute.parse(memory_space) + + super().__init__( + allocated_buffer_type, + new_ops_type, + target, + memory_space=memory_space, + memcpy_op=memcpy_op, + alloc_op=alloc_op, + bufferize_destination_only=bufferize_destination_only, + ) + + class DecomposeOp: """Specialization for DecomposeOp class.""" From 3d57beb320c14ff0c7e3d3a73cef50762363d8ab Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Mon, 14 Aug 2023 08:56:24 -0700 Subject: [PATCH 536/915] Revert "[mlir][linalg][transform][python] Add mix-in for MapCopyToThreadsOp." This reverts commit 4e2f33be5ab2e33de2165d2cc97cc561426cc3d3. 
The bot is broken: https://lab.llvm.org/buildbot/#/builders/61/builds/47577 --- .../dialects/_structured_transform_ops_ext.py | 60 ------------------- 1 file changed, 60 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index d21a1b434..6953dac63 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -221,66 +221,6 @@ def __init__( ) -class MapCopyToThreadsOp: - """Specialization for MapCopyToThreadsOp class.""" - - @overload - def __init__( - self, - forall_op_type: Type, - tiled_op_type: Type, - target: Union[Operation, OpView, Value], - *, - total_num_threads: Union[int, IntegerAttr], - desired_bit_alignment: Union[int, IntegerAttr], - loc=None, - ip=None, - ): - ... - - @overload - def __init__( - self, - target: Union[Operation, OpView, Value], - *, - total_num_threads: Union[int, IntegerAttr], - desired_bit_alignment: Union[int, IntegerAttr], - loc=None, - ip=None, - ): - ... 
- - def __init__( - self, - forall_op_type_or_target: Union[Operation, OpView, Type, Value], - tiled_op_type_or_none: Optional[Type] = None, - target_or_none: Optional[Union[Operation, OpView, Value]] = None, - *, - total_num_threads: Union[int, IntegerAttr], - desired_bit_alignment: Union[int, IntegerAttr], - loc=None, - ip=None, - ): - if isinstance(forall_op_type_or_target, Type): - forall_op_type = forall_op_type_or_target - tiled_op_type = tiled_op_type_or_none - target = target_or_none - else: - forall_op_type = transform.AnyOpType.get() - tiled_op_type = transform.AnyOpType.get() - target = forall_op_type_or_target - - super().__init__( - forall_op_type, - tiled_op_type, - target, - total_num_threads=total_num_threads, - desired_bit_alignment=desired_bit_alignment, - loc=loc, - ip=ip, - ) - - class MatchOp: """Specialization for MatchOp class.""" From dfcff138329303220638df4e6ff8bbe82fcbba82 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Mon, 14 Aug 2023 09:05:13 -0700 Subject: [PATCH 537/915] Revert "[mlir][linalg][transform][python] Add mix-in for BufferizeToAllocOp." This reverts commit 05affab0663cbcc09f5b1f8882fdc63e1b47c2ec. 
Bot is broken https://lab.llvm.org/buildbot/#/builders/61/builds/47577 --- .../dialects/_structured_transform_ops_ext.py | 34 ------------------- 1 file changed, 34 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 6953dac63..9f623efb5 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -84,40 +84,6 @@ def _get_int_int_array_attr( return ArrayAttr.get(values) -class BufferizeToAllocationOp: - """Specialization for BufferizeToAllocationOp class.""" - - def __init__( - self, - target: Union[Operation, OpView, Value], - *, - memory_space: Optional[int | str | Attribute] = None, - memcpy_op: Optional[str] = None, - alloc_op: Optional[str] = None, - bufferize_destination_only: Optional[bool] = None, - loc=None, - ip=None, - ): - # No other types are allowed, so hard-code those here. - allocated_buffer_type = transform.AnyValueType.get() - new_ops_type = transform.AnyOpType.get() - - if isinstance(memory_space, int): - memory_space = str(memory_space) - if isinstance(memory_space, str): - memory_space = Attribute.parse(memory_space) - - super().__init__( - allocated_buffer_type, - new_ops_type, - target, - memory_space=memory_space, - memcpy_op=memcpy_op, - alloc_op=alloc_op, - bufferize_destination_only=bufferize_destination_only, - ) - - class DecomposeOp: """Specialization for DecomposeOp class.""" From 4cbe022b26291f9b940d140660ca7fefefc7add9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Fri, 11 Aug 2023 12:11:24 +0000 Subject: [PATCH 538/915] [mlir][linalg][transform][python] Add mix-in for MapCopyToThreadsOp. Reviewed By: springerm Re-land 4e2f33be5ab2e33de2165d2cc97cc561426cc3d3 which was incorrectly reverted. 
Differential Revision: https://reviews.llvm.org/D157706 --- .../dialects/_structured_transform_ops_ext.py | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 9f623efb5..675d42370 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -187,6 +187,66 @@ def __init__( ) +class MapCopyToThreadsOp: + """Specialization for MapCopyToThreadsOp class.""" + + @overload + def __init__( + self, + forall_op_type: Type, + tiled_op_type: Type, + target: Union[Operation, OpView, Value], + *, + total_num_threads: Union[int, IntegerAttr], + desired_bit_alignment: Union[int, IntegerAttr], + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + total_num_threads: Union[int, IntegerAttr], + desired_bit_alignment: Union[int, IntegerAttr], + loc=None, + ip=None, + ): + ... 
+ + def __init__( + self, + forall_op_type_or_target: Union[Operation, OpView, Type, Value], + tiled_op_type_or_none: Optional[Type] = None, + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + total_num_threads: Union[int, IntegerAttr], + desired_bit_alignment: Union[int, IntegerAttr], + loc=None, + ip=None, + ): + if isinstance(forall_op_type_or_target, Type): + forall_op_type = forall_op_type_or_target + tiled_op_type = tiled_op_type_or_none + target = target_or_none + else: + forall_op_type = transform.AnyOpType.get() + tiled_op_type = transform.AnyOpType.get() + target = forall_op_type_or_target + + super().__init__( + forall_op_type, + tiled_op_type, + target, + total_num_threads=total_num_threads, + desired_bit_alignment=desired_bit_alignment, + loc=loc, + ip=ip, + ) + + class MatchOp: """Specialization for MatchOp class.""" From 87672a2501f5bf464c21a8a4c69cbb6787db9bcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Fri, 11 Aug 2023 11:00:23 +0000 Subject: [PATCH 539/915] [mlir][linalg][transform][python] Add mix-in for BufferizeToAllocOp. Re-apply https://reviews.llvm.org/D157704. The original patch broke the tests on Python 3.8 and got reverted by dfcff138329303220638df4e6ff8bbe82fcbba82. This patch replaces the usage of the vertical bar operator for type hints with `Union`. 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D158075 --- .../dialects/_structured_transform_ops_ext.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 675d42370..e34451af4 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -84,6 +84,40 @@ def _get_int_int_array_attr( return ArrayAttr.get(values) +class BufferizeToAllocationOp: + """Specialization for BufferizeToAllocationOp class.""" + + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + memory_space: Optional[Union[int, str, Attribute]] = None, + memcpy_op: Optional[str] = None, + alloc_op: Optional[str] = None, + bufferize_destination_only: Optional[bool] = None, + loc=None, + ip=None, + ): + # No other types are allowed, so hard-code those here. + allocated_buffer_type = transform.AnyValueType.get() + new_ops_type = transform.AnyOpType.get() + + if isinstance(memory_space, int): + memory_space = str(memory_space) + if isinstance(memory_space, str): + memory_space = Attribute.parse(memory_space) + + super().__init__( + allocated_buffer_type, + new_ops_type, + target, + memory_space=memory_space, + memcpy_op=memcpy_op, + alloc_op=alloc_op, + bufferize_destination_only=bufferize_destination_only, + ) + + class DecomposeOp: """Specialization for DecomposeOp class.""" From 43d038aae6773792b96e293e9f897ddbe20b21ae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Fri, 11 Aug 2023 17:34:50 +0000 Subject: [PATCH 540/915] [mlir][linalg][transform][python] Add mix-in for MaskedVectorize. 
Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D157735 --- .../dialects/_structured_transform_ops_ext.py | 99 +++++++++++++++++-- 1 file changed, 92 insertions(+), 7 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index e34451af4..de5161eb1 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -11,19 +11,67 @@ from typing import List, Optional, Sequence, Tuple, Union, overload +StaticIntLike = Union[int, IntegerAttr] +ValueLike = Union[Operation, OpView, Value] +MixedInt = Union[StaticIntLike, ValueLike] + IntOrAttrList = Sequence[Union[IntegerAttr, int]] OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]] BoolOrAttrList = Sequence[Union[BoolAttr, bool]] OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]] -MixedValues = Union[ - Sequence[Union[int, IntegerAttr, Operation, Value, OpView]], - ArrayAttr, - Operation, - Value, - OpView, -] +MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike] + +DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]] + + +def _dispatch_dynamic_index_list( + indices: Union[DynamicIndexList, ArrayAttr], +) -> tuple[list[ValueLike], list[int] | ArrayAttr, list[bool]]: + """Dispatches a list of indices to the appropriate form. + + This is similar to the custom `DynamicIndexList` directive upstream: + provided indices may be in the form of dynamic SSA values or static values, + and they may be scalable (i.e., as a singleton list) or not. This function + dispatches each index into its respective form. It also extracts the SSA + values and static indices from various similar structures, respectively. + """ + dynamic_indices = [] + static_indices = [ShapedType.get_dynamic_size()] * len(indices) + scalable_indices = [False] * len(indices) + + # ArrayAttr: Extract index values. 
+ if isinstance(indices, ArrayAttr): + indices = [idx for idx in indices] + + def process_nonscalable_index(i, index): + """Processes any form of non-scalable index. + + Returns False if the given index was scalable and thus remains + unprocessed; True otherwise. + """ + if isinstance(index, int): + static_indices[i] = index + elif isinstance(index, IntegerAttr): + static_indices[i] = index.value # pytype: disable=attribute-error + elif isinstance(index, (Operation, Value, OpView)): + dynamic_indices.append(index) + else: + return False + return True + + # Process each index at a time. + for i, index in enumerate(indices): + if not process_nonscalable_index(i, index): + # If it wasn't processed, it must be a scalable index, which is + # provided as a Sequence of one value, so extract and process that. + scalable_indices[i] = True + assert len(index) == 1 + ret = process_nonscalable_index(i, index[0]) + assert ret + + return dynamic_indices, static_indices, scalable_indices # Dispatches `MixedValues` that all represents integers in various forms into @@ -281,6 +329,43 @@ def __init__( ) +class MaskedVectorizeOp: + """Specialization for MaskedVectorizeOp class.""" + + def __init__( + self, + target: Union[Operation, OpView, Value], + vector_sizes: Union[DynamicIndexList, ArrayAttr], + *, + vectorize_nd_extract: Optional[bool] = None, + scalable_sizes: OptionalBoolList = None, + static_vector_sizes: OptionalIntList = None, + loc=None, + ip=None, + ): + if scalable_sizes is None and static_vector_sizes is None: + ( + dynamic_vector_sizes, + static_vector_sizes, + scalable_sizes, + ) = _dispatch_dynamic_index_list(vector_sizes) + elif scalable_sizes is None or static_vector_sizes is None: + raise TypeError( + "'scalable_sizes' and 'static_vector_sizes' must either both " + "be given explicitly or both be given as part of 'vector_sizes'." 
+ ) + else: + dynamic_vector_sizes = vector_sizes + + super().__init__( + target, + vector_sizes=dynamic_vector_sizes, + static_vector_sizes=static_vector_sizes, + scalable_sizes=scalable_sizes, + vectorize_nd_extract=vectorize_nd_extract, + ) + + class MatchOp: """Specialization for MatchOp class.""" From 5d68bbd66efe19231d35ca7362867f48d8d16c0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 16 Aug 2023 16:24:29 +0000 Subject: [PATCH 541/915] [mlir][linalg][transform][python] Fix mix-in for MaskedVectorize. Fix forward bug in 43d038aae6773792b96e293e9f897ddbe20b21ae, which uses the vertical bar operator for type hints, which is only supported by Python 3.10 and later, and thus breaks the builds on Python 3.8. --- mlir/python/mlir/dialects/_structured_transform_ops_ext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index de5161eb1..b822ba6d7 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -28,7 +28,7 @@ def _dispatch_dynamic_index_list( indices: Union[DynamicIndexList, ArrayAttr], -) -> tuple[list[ValueLike], list[int] | ArrayAttr, list[bool]]: +) -> tuple[list[ValueLike], Union[list[int], ArrayAttr], list[bool]]: """Dispatches a list of indices to the appropriate form. This is similar to the custom `DynamicIndexList` directive upstream: From 7fb8207a1c30f4dc032cc8818c9c2bacfe71dd19 Mon Sep 17 00:00:00 2001 From: Rahul Kayaith Date: Wed, 16 Aug 2023 16:16:47 -0400 Subject: [PATCH 542/915] [mlir][linalg][transform][python] Fix type hints Older python versions (e.g. 3.8) don't accept `tuple[...]` etc. in type hints. 
--- mlir/python/mlir/dialects/_structured_transform_ops_ext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index b822ba6d7..b63652957 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -28,7 +28,7 @@ def _dispatch_dynamic_index_list( indices: Union[DynamicIndexList, ArrayAttr], -) -> tuple[list[ValueLike], Union[list[int], ArrayAttr], list[bool]]: +) -> Tuple[List[ValueLike], Union[List[int], ArrayAttr], List[bool]]: """Dispatches a list of indices to the appropriate form. This is similar to the custom `DynamicIndexList` directive upstream: From 45a5fff2c1abb1b68f1cd159f0ae28598a4b3de6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Sat, 12 Aug 2023 17:05:40 +0000 Subject: [PATCH 543/915] [mlir][linalg][transform][python] Improve mix-in for PadOp. In particular: * Fix and extend the support for constructing possibly nested ArrayAttrs from lists of Python ints. This can probably be generalized further and used in many more places. * Add arguments for `pad_to_multiple_of` and `copy_back_op`. * Format with black and reorder (keyword-only) arguments to match tablegen and (`*_gen.py`) order. * Extend tests for new features. 
Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D157789 --- .../dialects/_structured_transform_ops_ext.py | 147 +++++++++++------- 1 file changed, 95 insertions(+), 52 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index b63652957..48dee5f80 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -107,28 +107,60 @@ def _dispatch_mixed_values( return (dynamic_values, packed_values, static_values) -def _get_int_int_array_attr( +def _get_value_or_attribute_value( + value_or_attr: Union[any, Attribute, ArrayAttr] +) -> any: + if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"): + return value_or_attr.value + if isinstance(value_or_attr, ArrayAttr): + return _get_value_list(value_or_attr) + return value_or_attr + + +def _get_value_list( + sequence_or_array_attr: Union[Sequence[any], ArrayAttr] +) -> Sequence[any]: + return [_get_value_or_attribute_value(v) for v in sequence_or_array_attr] + + +def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr: + if values is None: + return ArrayAttr.get([]) + + # Turn into a Python list of Python ints. + values = _get_value_list(values) + + # Make an ArrayAttr of IntegerAttrs out of it. + return ArrayAttr.get( + [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values] + ) + + +def _get_int_array_array_attr( values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]] ) -> ArrayAttr: - """Creates an array attribute containing array attributes of integers. + """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs. - If the operand is already an array attribute, forwards it. Otherwise treats - the operand as a list of attributes or integers, potentially interpserced, to - create a new array-of-array attribute. 
Expects the thread-local MLIR context - to have been set by the context manager. + The input has to be a collection of collection of integers, where any + Python Sequence and ArrayAttr are admissible collections and Python ints and + any IntegerAttr are admissible integers. Both levels of collections are + turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s. + If the input is None, an empty ArrayAttr is returned. """ if values is None: return ArrayAttr.get([]) - if isinstance(values, ArrayAttr): - return values - if isinstance(values, list): - values = [ - ArrayAttr.get( - [IntegerAttr.get(IntegerType.get_signless(64), v) for v in value] - ) - for value in values - ] + # Make sure the outer level is a list. + values = _get_value_list(values) + + # The inner level is now either invalid or a mixed sequence of ArrayAttrs and + # Sequences. Make sure the nested values are all lists. + values = [_get_value_list(nested) for nested in values] + + # Turn each nested list into an ArrayAttr. + values = [_get_int_array_attr(nested) for nested in values] + + # Turn the outer list into an ArrayAttr. 
return ArrayAttr.get(values) @@ -455,44 +487,55 @@ def __init__( class PadOp: - """Specialization for PadOp class.""" + """Specialization for PadOp class.""" - def __init__( - self, - target: Union[Operation, Value], - *, - padding_values: Optional[ - Optional[Union[ArrayAttr, Sequence[Attribute]]] - ] = None, - padding_dimensions: OptionalIntList = None, - pack_paddings: OptionalIntList = None, - transpose_paddings: Optional[ - Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]] - ] = None, - loc=None, - ip=None, - ): - if transpose_paddings is None: - transpose_paddings = [] - if pack_paddings is None: - pack_paddings = [] - if padding_dimensions is None: - padding_dimensions = [] - if padding_values is None: - padding_values = [] - pdl_operation_type = pdl.OperationType.get() - transpose_paddings_attr = _get_int_int_array_attr(transpose_paddings) - super().__init__( - pdl_operation_type, - pdl_operation_type, - _get_op_result_or_value(target), - padding_values=padding_values, - padding_dimensions=padding_dimensions, - pack_paddings=pack_paddings, - transpose_paddings=transpose_paddings_attr, - loc=loc, - ip=ip, - ) + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + padding_values: Optional[ + Union[ArrayAttr, Sequence[Union[bool, int, float, Attribute]]] + ] = None, + padding_dimensions: OptionalIntList = None, + pad_to_multiple_of: OptionalIntList = None, + pack_paddings: OptionalIntList = None, + transpose_paddings: Optional[ + Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]] + ] = None, + copy_back_op: Optional[Union[str, StringAttr]] = None, + loc=None, + ip=None, + ): + if padding_values is None: + padding_values = [] + if padding_dimensions is None: + padding_dimensions = [] + if pad_to_multiple_of is None: + pad_to_multiple_of = [] + if pack_paddings is None: + pack_paddings = [] + if transpose_paddings is None: + transpose_paddings = [] + + padding_dimensions = _get_int_array_attr(padding_dimensions) + 
pad_to_multiple_of = _get_int_array_attr(pad_to_multiple_of) + pack_paddings = _get_int_array_attr(pack_paddings) + transpose_paddings = _get_int_array_array_attr(transpose_paddings) + + pdl_operation_type = pdl.OperationType.get() + super().__init__( + pdl_operation_type, + pdl_operation_type, + target, + padding_values=padding_values, + padding_dimensions=padding_dimensions, + pad_to_multiple_of=pad_to_multiple_of, + pack_paddings=pack_paddings, + transpose_paddings=transpose_paddings, + copy_back_op=copy_back_op, + loc=loc, + ip=ip, + ) class ScalarizeOp: From b5c50da7e81c021f93a67b42cb5dafc2d01ddcdc Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Tue, 22 Aug 2023 00:38:58 +0000 Subject: [PATCH 544/915] [mlir] Disentangle dialect and extension registrations. This revision avoids the registration of dialect extensions in Pass::getDependentDialects. Such registration of extensions can be dangerous because `DialectRegistry::isSubsetOf` is always guaranteed to return false for extensions (i.e. there is no mechanism to track whether a lambda is already in the list of already registered extensions). When the context is already in a multi-threaded mode, this is guaranteed to assert. Arguably a more structured registration mechanism for extensions with a unique ExtensionID could be envisioned in the future. In the process of cleaning this up, multiple usage inconsistencies surfaced around the registration of translation extensions that this revision also cleans up. 
Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D157703 --- mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp index b63899bd5..c1c4a418b 100644 --- a/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp +++ b/mlir/lib/CAPI/RegisterEverything/RegisterEverything.cpp @@ -9,9 +9,11 @@ #include "mlir-c/RegisterEverything.h" #include "mlir/CAPI/IR.h" +#include "mlir/IR/MLIRContext.h" #include "mlir/InitAllDialects.h" #include "mlir/InitAllExtensions.h" #include "mlir/InitAllPasses.h" +#include "mlir/Target/LLVMIR/Dialect/All.h" #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" @@ -22,8 +24,9 @@ void mlirRegisterAllDialects(MlirDialectRegistry registry) { void mlirRegisterAllLLVMTranslations(MlirContext context) { auto &ctx = *unwrap(context); - mlir::registerBuiltinDialectTranslation(ctx); - mlir::registerLLVMDialectTranslation(ctx); + mlir::DialectRegistry registry; + mlir::registerAllToLLVMIRTranslations(registry); + ctx.appendDialectRegistry(registry); } void mlirRegisterAllPasses() { mlir::registerAllPasses(); } From 4d4b6df3c01dff179de0f0c3548b25070a6ed556 Mon Sep 17 00:00:00 2001 From: Yinying Li Date: Tue, 22 Aug 2023 23:48:03 +0000 Subject: [PATCH 545/915] [mlir][sparse] Changed sparsity properties to use _ instead of - Example: compressed-no -> compressed_no Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D158567 --- .../Bindings/Python/DialectSparseTensor.cpp | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index d03088341..70805005f 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp 
+++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -21,19 +21,19 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { .value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) .value("compressed24", MLIR_SPARSE_TENSOR_DIM_LEVEL_TWO_OUT_OF_FOUR) .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) - .value("compressed-nu", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU) - .value("compressed-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO) - .value("compressed-nu-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO) + .value("compressed_nu", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU) + .value("compressed_no", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO) + .value("compressed_nu_no", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO) .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) - .value("singleton-nu", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU) - .value("singleton-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO) - .value("singleton-nu-no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO) - .value("compressed-hi", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI) - .value("compressed-hi-nu", + .value("singleton_nu", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU) + .value("singleton_no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO) + .value("singleton_nu_no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO) + .value("compressed_hi", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI) + .value("compressed_hi_nu", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU) - .value("compressed-hi-no", + .value("compressed_hi_no", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NO) - .value("compressed-hi-nu-no", + .value("compressed_hi_nu_no", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU_NO); mlir_attribute_subclass(m, "EncodingAttr", From af3539156a6d514c3123a24636f58524fa536938 Mon Sep 17 00:00:00 2001 From: max Date: Wed, 23 Aug 2023 13:27:08 -0500 Subject: [PATCH 546/915] [mlir][python bindings] generate all the enums This PR implements python enum 
bindings for *all* the enums - this includes `I*Attrs` (including positional/bit) and `Dialect/EnumAttr`. There are a few parts to this: 1. CMake: a small addition to `declare_mlir_dialect_python_bindings` and `declare_mlir_dialect_extension_python_bindings` to generate the enum, a boolean arg `GEN_ENUM_BINDINGS` to make it opt-in (even though it works for basically all of the dialects), and an optional `GEN_ENUM_BINDINGS_TD_FILE` for handling corner cases. 2. EnumPythonBindingGen.cpp: there are two weedy aspects here that took investigation: 1. If an enum attribute is not a `Dialect/EnumAttr` then the `EnumAttrInfo` record is canonical, as far as both the cases of the enum **and the `AttrDefName`**. On the other hand, if an enum is a `Dialect/EnumAttr` then the `EnumAttr` record has the correct `AttrDefName` ("load bearing", i.e., populates `ods.ir.AttributeBuilder('')`) but its `enum` field contains the cases, which is an instance of `EnumAttrInfo`. The solution is to generate one enum class for both `Dialect/EnumAttr` and "independent" `EnumAttrInfo` but to make that class interoperable with two builder registrations that both do the right thing (see next sub-bullet). 2. Because we don't have a good connection to cpp `EnumAttr`, i.e., only the `enum class` getters are exposed (like `DimensionAttr::get(Dimension value)`), we have to resort to parsing e.g., `Attribute.parse(f'#gpu')`. This means that the set of supported `assemblyFormat`s (for the enum) is fixed at compile time of MLIR (currently 2, the only 2 I saw). There might be some things that could be done here but they would require quite a bit more C API work to support generically (e.g., casting ints to enum cases and binding all the getters or going generically through the `symbolize*` methods, like `symbolizeDimension(uint32_t)` or `symbolizeDimension(StringRef)`). A few small changes: 1. 
In addition, since this patch registers default builders for attributes where people might've had their own builders already written, I added a `replace` param to `AttributeBuilder.insert` (`False` by default). 2. `makePythonEnumCaseName` can't handle all the different ways in which people write their enum cases, e.g., `llvm.CConv.Intel_OCL_BI`, which gets turned into `INTEL_O_C_L_B_I` (because `llvm::convertToSnakeFromCamelCase` doesn't look for runs of caps). So I dropped it. On the other hand, regularization does need to be done because some enums have `None` as a case (and others might have other python keywords). 3. I turned on `llvm` dialect generation here in order to test `nvvm.WGMMAScaleIn`, which is an enum with [[ https://github.com/llvm/llvm-project/blob/5d68bbd66efe19231d35ca7362867f48d8d16c0a/mlir/include/mlir/IR/EnumAttr.td#L22-L25 | no explicit discriminator ]] for the `neg` case. Note, dialects that didn't get a `GEN_ENUM_BINDINGS` don't have any enums to generate. Let me know if I should add more tests (the three trivial ones I added exercise both the supported `assemblyFormat`s and `replace=True`). 
Reviewed By: stellaraccident Differential Revision: https://reviews.llvm.org/D157934 --- mlir/lib/Bindings/Python/Globals.h | 5 +- mlir/lib/Bindings/Python/IRCore.cpp | 12 ++- mlir/lib/Bindings/Python/IRModule.cpp | 8 +- mlir/python/CMakeLists.txt | 86 ++++++++++--------- mlir/python/mlir/dialects/LLVMOps.td | 14 +++ mlir/python/mlir/dialects/amdgpu.py | 1 + mlir/python/mlir/dialects/arith.py | 1 + mlir/python/mlir/dialects/bufferization.py | 1 + mlir/python/mlir/dialects/gpu/__init__.py | 1 + mlir/python/mlir/dialects/linalg/__init__.py | 1 + mlir/python/mlir/dialects/llvm.py | 6 ++ mlir/python/mlir/dialects/nvgpu.py | 1 + mlir/python/mlir/dialects/nvvm.py | 1 + mlir/python/mlir/dialects/sparse_tensor.py | 1 + .../mlir/dialects/transform/bufferization.py | 1 - .../mlir/dialects/transform/structured.py | 1 + mlir/python/mlir/dialects/vector.py | 1 + mlir/python/mlir/ir.py | 4 +- 18 files changed, 94 insertions(+), 52 deletions(-) create mode 100644 mlir/python/mlir/dialects/LLVMOps.td create mode 100644 mlir/python/mlir/dialects/llvm.py diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 0fc7614cc..97cd70089 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -58,10 +58,11 @@ class PyGlobals { void loadDialectModule(llvm::StringRef dialectNamespace); /// Adds a user-friendly Attribute builder. - /// Raises an exception if the mapping already exists. + /// Raises an exception if the mapping already exists and replace == false. /// This is intended to be called by implementation code. void registerAttributeBuilder(const std::string &attributeKind, - pybind11::function pyFunc); + pybind11::function pyFunc, + bool replace = false); /// Adds a user-friendly type caster. Raises an exception if the mapping /// already exists and replace == false. 
This is intended to be called by diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index e1b8d296a..b06937bc2 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -242,19 +242,23 @@ struct PyAttrBuilderMap { static py::function dundeGetItemNamed(const std::string &attributeKind) { auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind); if (!builder) - throw py::key_error(); + throw py::key_error(attributeKind); return *builder; } static void dundeSetItemNamed(const std::string &attributeKind, - py::function func) { - PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func)); + py::function func, bool replace) { + PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), + replace); } static void bind(py::module &m) { py::class_(m, "AttrBuilder", py::module_local()) .def_static("contains", &PyAttrBuilderMap::dunderContains) .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed) - .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed); + .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed, + "attribute_kind"_a, "attr_builder"_a, "replace"_a = false, + "Register an attribute builder for building MLIR " + "attributes from python values."); } }; diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index d9a66bce0..2cc66277a 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -63,11 +63,13 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { } void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, - py::function pyFunc) { + py::function pyFunc, bool replace) { py::object &found = attributeBuilderMap[attributeKind]; - if (found) { + if (found && !found.is_none() && !replace) { throw std::runtime_error((llvm::Twine("Attribute builder for '") + - attributeKind + "' is already registered") + attributeKind + + "' is 
already registered with func: " + + py::str(found).operator std::string()) .str()); } found = std::move(pyFunc); diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 05d09eaf7..225da778c 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -52,7 +52,8 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/AMDGPUOps.td SOURCES dialects/amdgpu.py - DIALECT_NAME amdgpu) + DIALECT_NAME amdgpu + GEN_ENUM_BINDINGS) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -68,7 +69,10 @@ declare_mlir_dialect_python_bindings( SOURCES dialects/bufferization.py dialects/_bufferization_ops_ext.py - DIALECT_NAME bufferization) + DIALECT_NAME bufferization + GEN_ENUM_BINDINGS_TD_FILE + "../../include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td" +) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -109,7 +113,8 @@ declare_mlir_dialect_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/GPUOps.td SOURCES_GLOB dialects/gpu/*.py - DIALECT_NAME gpu) + DIALECT_NAME gpu + GEN_ENUM_BINDINGS) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -120,7 +125,17 @@ declare_mlir_dialect_python_bindings( SOURCES_GLOB dialects/linalg/*.py DIALECT_NAME linalg - DEPENDS LinalgOdsGen) + DEPENDS LinalgOdsGen + GEN_ENUM_BINDINGS) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/LLVMOps.td + SOURCES + dialects/llvm.py + DIALECT_NAME llvm + GEN_ENUM_BINDINGS) declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -140,16 +155,10 @@ declare_mlir_dialect_python_bindings( dialects/_transform_ops_ext.py dialects/transform/__init__.py _mlir_libs/_mlir/dialects/transform/__init__.pyi - DIALECT_NAME transform) - -set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/TransformOps.td") 
-mlir_tablegen("dialects/_transform_enum_gen.py" -gen-python-enum-bindings) -add_public_tablegen_target(MLIRTransformDialectPyEnumGen) -declare_mlir_python_sources( - MLIRPythonSources.Dialects.transform.enum_gen - ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" - ADD_TO_PARENT MLIRPythonSources.Dialects.transform - SOURCES "dialects/_transform_enum_gen.py") + DIALECT_NAME transform + GEN_ENUM_BINDINGS_TD_FILE + "../../include/mlir/Dialect/Transform/IR/TransformAttrs.td" +) declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -161,15 +170,6 @@ declare_mlir_dialect_extension_python_bindings( DIALECT_NAME transform EXTENSION_NAME bufferization_transform) -set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/BufferizationTransformOps.td") -mlir_tablegen("dialects/_bufferization_transform_enum_gen.py" -gen-python-enum-bindings) -add_public_tablegen_target(MLIRBufferizationTransformDialectPyEnumGen) -declare_mlir_python_sources( - MLIRPythonSources.Dialects.bufferization_transform.enum_gen - ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" - ADD_TO_PARENT MLIRPythonSources.Dialects.bufferization_transform - SOURCES "dialects/_bufferization_transform_enum_gen.py") - declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" @@ -208,7 +208,10 @@ declare_mlir_dialect_extension_python_bindings( dialects/_structured_transform_ops_ext.py dialects/transform/structured.py DIALECT_NAME transform - EXTENSION_NAME structured_transform) + EXTENSION_NAME structured_transform + GEN_ENUM_BINDINGS_TD_FILE + "../../include/mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td" +) declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -227,16 +230,10 @@ declare_mlir_dialect_extension_python_bindings( SOURCES dialects/transform/vector.py DIALECT_NAME transform - EXTENSION_NAME vector_transform) - -set(LLVM_TARGET_DEFINITIONS 
"${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/VectorTransformOps.td") -mlir_tablegen("dialects/_vector_transform_enum_gen.py" -gen-python-enum-bindings) -add_public_tablegen_target(MLIRVectorTransformPyEnumGen) -declare_mlir_python_sources( - MLIRPythonSources.Dialects.vector_transform.enum_gen - ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}" - ADD_TO_PARENT MLIRPythonSources.Dialects.vector_transform - SOURCES "dialects/_vector_transform_enum_gen.py" ) + EXTENSION_NAME vector_transform + GEN_ENUM_BINDINGS_TD_FILE + "../../include/mlir/Dialect/Vector/Transforms/VectorTransformsBase.td" +) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -252,7 +249,8 @@ declare_mlir_dialect_python_bindings( SOURCES dialects/arith.py dialects/_arith_ops_ext.py - DIALECT_NAME arith) + DIALECT_NAME arith + GEN_ENUM_BINDINGS) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -278,7 +276,8 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/NVGPUOps.td SOURCES dialects/nvgpu.py - DIALECT_NAME nvgpu) + DIALECT_NAME nvgpu + GEN_ENUM_BINDINGS) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -286,7 +285,8 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/NVVMOps.td SOURCES dialects/nvvm.py - DIALECT_NAME nvvm) + DIALECT_NAME nvvm + GEN_ENUM_BINDINGS) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -300,6 +300,7 @@ declare_mlir_python_sources( MLIRPythonSources.Dialects.quant ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + GEN_ENUM_BINDINGS SOURCES dialects/quant.py _mlir_libs/_mlir/dialects/quant.pyi) @@ -335,7 +336,10 @@ declare_mlir_dialect_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/SparseTensorOps.td SOURCES dialects/sparse_tensor.py - DIALECT_NAME sparse_tensor) + DIALECT_NAME sparse_tensor + GEN_ENUM_BINDINGS_TD_FILE + 
"../../include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td" +) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -351,14 +355,16 @@ declare_mlir_dialect_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/TosaOps.td SOURCES dialects/tosa.py - DIALECT_NAME tosa) + DIALECT_NAME tosa +) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/VectorOps.td SOURCES dialects/vector.py - DIALECT_NAME vector) + DIALECT_NAME vector + GEN_ENUM_BINDINGS) ################################################################################ # Python extensions. diff --git a/mlir/python/mlir/dialects/LLVMOps.td b/mlir/python/mlir/dialects/LLVMOps.td new file mode 100644 index 000000000..dcf2f4245 --- /dev/null +++ b/mlir/python/mlir/dialects/LLVMOps.td @@ -0,0 +1,14 @@ +//===-- LlvmOps.td - Entry point for llvm bind ---------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_LLVM_OPS +#define PYTHON_BINDINGS_LLVM_OPS + +include "mlir/Dialect/LLVMIR/LLVMOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/amdgpu.py b/mlir/python/mlir/dialects/amdgpu.py index 35283278e..43d905d0c 100644 --- a/mlir/python/mlir/dialects/amdgpu.py +++ b/mlir/python/mlir/dialects/amdgpu.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._amdgpu_ops_gen import * +from ._amdgpu_enum_gen import * diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py index 77318b286..fb13beb63 100644 --- a/mlir/python/mlir/dialects/arith.py +++ b/mlir/python/mlir/dialects/arith.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._arith_ops_gen import * +from ._arith_enum_gen import * diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py index 2121122f1..759b6aa24 100644 --- a/mlir/python/mlir/dialects/bufferization.py +++ b/mlir/python/mlir/dialects/bufferization.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._bufferization_ops_gen import * +from ._bufferization_enum_gen import * diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py index 67bf7bd85..033386b0f 100644 --- a/mlir/python/mlir/dialects/gpu/__init__.py +++ b/mlir/python/mlir/dialects/gpu/__init__.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._gpu_ops_gen import * +from .._gpu_enum_gen import * diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index eadb8420c..1353870ec 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -9,6 +9,7 @@ # definitions following these steps: # DSL 
-> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py. from .._linalg_ops_gen import * +from .._linalg_enum_gen import * # These are the ground truth functions defined as: # ``` diff --git a/mlir/python/mlir/dialects/llvm.py b/mlir/python/mlir/dialects/llvm.py new file mode 100644 index 000000000..77025438c --- /dev/null +++ b/mlir/python/mlir/dialects/llvm.py @@ -0,0 +1,6 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._llvm_ops_gen import * +from ._llvm_enum_gen import * diff --git a/mlir/python/mlir/dialects/nvgpu.py b/mlir/python/mlir/dialects/nvgpu.py index afd570cae..2f6993b76 100644 --- a/mlir/python/mlir/dialects/nvgpu.py +++ b/mlir/python/mlir/dialects/nvgpu.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._nvgpu_ops_gen import * +from ._nvgpu_enum_gen import * diff --git a/mlir/python/mlir/dialects/nvvm.py b/mlir/python/mlir/dialects/nvvm.py index 87b2a4fd6..9477de39c 100644 --- a/mlir/python/mlir/dialects/nvvm.py +++ b/mlir/python/mlir/dialects/nvvm.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._nvvm_ops_gen import * +from ._nvvm_enum_gen import * diff --git a/mlir/python/mlir/dialects/sparse_tensor.py b/mlir/python/mlir/dialects/sparse_tensor.py index 769418e04..209ecc95f 100644 --- a/mlir/python/mlir/dialects/sparse_tensor.py +++ b/mlir/python/mlir/dialects/sparse_tensor.py @@ -3,5 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._sparse_tensor_ops_gen import * +from ._sparse_tensor_enum_gen import * from .._mlir_libs._mlirDialectsSparseTensor import * from .._mlir_libs import _mlirSparseTensorPasses as _cextSparseTensorPasses diff --git a/mlir/python/mlir/dialects/transform/bufferization.py b/mlir/python/mlir/dialects/transform/bufferization.py index 1891bc0e1..eb77b746c 100644 --- 
a/mlir/python/mlir/dialects/transform/bufferization.py +++ b/mlir/python/mlir/dialects/transform/bufferization.py @@ -2,5 +2,4 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from .._bufferization_transform_enum_gen import * from .._bufferization_transform_ops_gen import * diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py index b8ee48c42..cb3812301 100644 --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._structured_transform_ops_gen import * +from .._structured_transform_enum_gen import * diff --git a/mlir/python/mlir/dialects/vector.py b/mlir/python/mlir/dialects/vector.py index 610c0b204..7384e9a5a 100644 --- a/mlir/python/mlir/dialects/vector.py +++ b/mlir/python/mlir/dialects/vector.py @@ -3,3 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._vector_ops_gen import * +from ._vector_enum_gen import * diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index e36736f29..36c49fe6f 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -8,9 +8,9 @@ # Convenience decorator for registering user-friendly Attribute builders. -def register_attribute_builder(kind): +def register_attribute_builder(kind, replace=False): def decorator_builder(func): - AttrBuilder.insert(kind, func) + AttrBuilder.insert(kind, func, replace=replace) return func return decorator_builder From ba95c608134408fdc1267708423d72f66f6f3f8e Mon Sep 17 00:00:00 2001 From: max Date: Wed, 23 Aug 2023 17:41:04 -0500 Subject: [PATCH 547/915] [mlir][python bindings] turn on openmp Just as in https://reviews.llvm.org/D157820, dialect registration is independent of any vendor specific libs having been linked/built/etc. 
Reviewed By: rkayaith Differential Revision: https://reviews.llvm.org/D158670 --- mlir/include/mlir-c/Dialect/OpenMP.h | 25 +++++++++++++++++++++++++ mlir/lib/CAPI/Dialect/CMakeLists.txt | 9 +++++++++ mlir/lib/CAPI/Dialect/OpenMP.cpp | 16 ++++++++++++++++ mlir/python/CMakeLists.txt | 8 ++++++++ mlir/python/mlir/dialects/OpenMPOps.td | 14 ++++++++++++++ mlir/python/mlir/dialects/openmp.py | 5 +++++ 6 files changed, 77 insertions(+) create mode 100644 mlir/include/mlir-c/Dialect/OpenMP.h create mode 100644 mlir/lib/CAPI/Dialect/OpenMP.cpp create mode 100644 mlir/python/mlir/dialects/OpenMPOps.td create mode 100644 mlir/python/mlir/dialects/openmp.py diff --git a/mlir/include/mlir-c/Dialect/OpenMP.h b/mlir/include/mlir-c/Dialect/OpenMP.h new file mode 100644 index 000000000..719ed702a --- /dev/null +++ b/mlir/include/mlir-c/Dialect/OpenMP.h @@ -0,0 +1,25 @@ +//===-- mlir-c/Dialect/OpenMP.h - C API for OpenMP Dialect --------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_OPENMP_H +#define MLIR_C_DIALECT_OPENMP_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(OpenMP, omp); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_OPENMP_H diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index 13e57a29d..d815eba48 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -198,6 +198,15 @@ add_mlir_upstream_c_api_library(MLIRCAPIQuant MLIRQuantDialect ) +add_mlir_upstream_c_api_library(MLIRCAPIOpenMP + OpenMP.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIROpenMPDialect +) + add_mlir_upstream_c_api_library(MLIRCAPIPDL PDL.cpp diff --git a/mlir/lib/CAPI/Dialect/OpenMP.cpp b/mlir/lib/CAPI/Dialect/OpenMP.cpp new file mode 100644 index 000000000..3ffa57ab5 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/OpenMP.cpp @@ -0,0 +1,16 @@ +//===- OPENMP.cpp - C Interface for OPENMP dialect +//------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/OpenMP.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/OpenMP/OpenMPDialect.h" + +using namespace mlir; + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(OpenMP, omp, omp::OpenMPDialect) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 225da778c..5d2f233ca 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -315,6 +315,14 @@ declare_mlir_dialect_python_bindings( _mlir_libs/_mlir/dialects/pdl.pyi DIALECT_NAME pdl) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/OpenMPOps.td + SOURCES + dialects/openmp.py + DIALECT_NAME omp) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/OpenMPOps.td b/mlir/python/mlir/dialects/OpenMPOps.td new file mode 100644 index 000000000..b91179b0d --- /dev/null +++ b/mlir/python/mlir/dialects/OpenMPOps.td @@ -0,0 +1,14 @@ +//===-- OpenMPOps.td - Entry point for OpenMPOps bind ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_OPENMP_OPS +#define PYTHON_BINDINGS_OPENMP_OPS + +include "mlir/Dialect/OpenMP/OpenMPOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/openmp.py b/mlir/python/mlir/dialects/openmp.py new file mode 100644 index 000000000..604f0bd03 --- /dev/null +++ b/mlir/python/mlir/dialects/openmp.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._omp_ops_gen import * From 98e803a6e66b73e9f7202e71ca5ccef77a8cc879 Mon Sep 17 00:00:00 2001 From: Yijia Gu Date: Wed, 23 Aug 2023 17:50:29 -0700 Subject: [PATCH 548/915] update bazel for python binding --- .../mlir/dialects/BufferizationEnums.td | 14 +++++++++++++ .../LinalgStructuredTransformEnums.td | 20 +++++++++++++++++++ .../mlir/dialects/SparseTensorAttrDefs.td | 14 +++++++++++++ mlir/python/mlir/dialects/TransformAttrs.td | 14 +++++++++++++ .../mlir/dialects/VectorTransformsBase.td | 19 ++++++++++++++++++ 5 files changed, 81 insertions(+) create mode 100644 mlir/python/mlir/dialects/BufferizationEnums.td create mode 100644 mlir/python/mlir/dialects/LinalgStructuredTransformEnums.td create mode 100644 mlir/python/mlir/dialects/SparseTensorAttrDefs.td create mode 100644 mlir/python/mlir/dialects/TransformAttrs.td create mode 100644 mlir/python/mlir/dialects/VectorTransformsBase.td diff --git a/mlir/python/mlir/dialects/BufferizationEnums.td b/mlir/python/mlir/dialects/BufferizationEnums.td new file mode 100644 index 000000000..dc67b12ff --- /dev/null +++ b/mlir/python/mlir/dialects/BufferizationEnums.td @@ -0,0 +1,14 @@ +//===-- BufferizationEnums.td - Entry point for BufferizationEnums bindings ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_BUFFERIZATION_ENUMS +#define PYTHON_BINDINGS_BUFFERIZATION_ENUMS + +include "mlir/Dialect/Bufferization/IR/BufferizationEnums.td" + +#endif // PYTHON_BINDINGS_BUFFERIZATION_ENUMS \ No newline at end of file diff --git a/mlir/python/mlir/dialects/LinalgStructuredTransformEnums.td b/mlir/python/mlir/dialects/LinalgStructuredTransformEnums.td new file mode 100644 index 000000000..ecc8d91ab --- /dev/null +++ b/mlir/python/mlir/dialects/LinalgStructuredTransformEnums.td @@ -0,0 +1,20 @@ +//===-- LinalgStructuredTransformEnums.td --------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the structured transform ops +// provided by Linalg (and other dialects). 
+// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_ENUMS +#define PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_ENUMS + +include "mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td" + +#endif // PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_ENUMS \ No newline at end of file diff --git a/mlir/python/mlir/dialects/SparseTensorAttrDefs.td b/mlir/python/mlir/dialects/SparseTensorAttrDefs.td new file mode 100644 index 000000000..95920cca9 --- /dev/null +++ b/mlir/python/mlir/dialects/SparseTensorAttrDefs.td @@ -0,0 +1,14 @@ +//===-- SparseTensorAttrDefs.td - Entry point for bindings ------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_SPARSE_TENSOR_ATTR_DEFS +#define PYTHON_BINDINGS_SPARSE_TENSOR_ATTR_DEFS + +include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td" + +#endif // PYTHON_BINDINGS_SPARSE_TENSOR_ATTR_DEFS \ No newline at end of file diff --git a/mlir/python/mlir/dialects/TransformAttrs.td b/mlir/python/mlir/dialects/TransformAttrs.td new file mode 100644 index 000000000..4314ac14d --- /dev/null +++ b/mlir/python/mlir/dialects/TransformAttrs.td @@ -0,0 +1,14 @@ +//===-- TransformAttrs.td - Transform attrs bind entry point ---*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_TRANSFORM_ATTRS +#define PYTHON_BINDINGS_TRANSFORM_ATTRS + +include "mlir/Dialect/Transform/IR/TransformAttrs.td" + +#endif // PYTHON_BINDINGS_TRANSFORM_ATTRS \ No newline at end of file diff --git a/mlir/python/mlir/dialects/VectorTransformsBase.td b/mlir/python/mlir/dialects/VectorTransformsBase.td new file mode 100644 index 000000000..ced470300 --- /dev/null +++ b/mlir/python/mlir/dialects/VectorTransformsBase.td @@ -0,0 +1,19 @@ +//===-- VectorTransformsBase.td ------------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the vector transform ops. 
+// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_VECTORTRANSFORMBASE +#define PYTHON_BINDINGS_VECTORTRANSFORMBASE + +include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td" + +#endif // PYTHON_BINDINGS_VECTORTRANSFORMBASE \ No newline at end of file From ef8b4542db3a483e521580d1db236b620e3cd338 Mon Sep 17 00:00:00 2001 From: Yijia Gu Date: Wed, 23 Aug 2023 17:59:36 -0700 Subject: [PATCH 549/915] add empty line in the end of the td files --- mlir/python/mlir/dialects/BufferizationEnums.td | 2 +- mlir/python/mlir/dialects/LinalgStructuredTransformEnums.td | 2 +- mlir/python/mlir/dialects/SparseTensorAttrDefs.td | 2 +- mlir/python/mlir/dialects/TransformAttrs.td | 2 +- mlir/python/mlir/dialects/VectorTransformsBase.td | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/python/mlir/dialects/BufferizationEnums.td b/mlir/python/mlir/dialects/BufferizationEnums.td index dc67b12ff..f676ce082 100644 --- a/mlir/python/mlir/dialects/BufferizationEnums.td +++ b/mlir/python/mlir/dialects/BufferizationEnums.td @@ -11,4 +11,4 @@ include "mlir/Dialect/Bufferization/IR/BufferizationEnums.td" -#endif // PYTHON_BINDINGS_BUFFERIZATION_ENUMS \ No newline at end of file +#endif // PYTHON_BINDINGS_BUFFERIZATION_ENUMS diff --git a/mlir/python/mlir/dialects/LinalgStructuredTransformEnums.td b/mlir/python/mlir/dialects/LinalgStructuredTransformEnums.td index ecc8d91ab..e86c9b7dd 100644 --- a/mlir/python/mlir/dialects/LinalgStructuredTransformEnums.td +++ b/mlir/python/mlir/dialects/LinalgStructuredTransformEnums.td @@ -17,4 +17,4 @@ include "mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td" -#endif // PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_ENUMS \ No newline at end of file +#endif // PYTHON_BINDINGS_LINALG_STRUCTURED_TRANSFORM_ENUMS diff --git a/mlir/python/mlir/dialects/SparseTensorAttrDefs.td b/mlir/python/mlir/dialects/SparseTensorAttrDefs.td index 95920cca9..5a86f55df 100644 --- 
a/mlir/python/mlir/dialects/SparseTensorAttrDefs.td +++ b/mlir/python/mlir/dialects/SparseTensorAttrDefs.td @@ -11,4 +11,4 @@ include "mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td" -#endif // PYTHON_BINDINGS_SPARSE_TENSOR_ATTR_DEFS \ No newline at end of file +#endif // PYTHON_BINDINGS_SPARSE_TENSOR_ATTR_DEFS diff --git a/mlir/python/mlir/dialects/TransformAttrs.td b/mlir/python/mlir/dialects/TransformAttrs.td index 4314ac14d..451118a5d 100644 --- a/mlir/python/mlir/dialects/TransformAttrs.td +++ b/mlir/python/mlir/dialects/TransformAttrs.td @@ -11,4 +11,4 @@ include "mlir/Dialect/Transform/IR/TransformAttrs.td" -#endif // PYTHON_BINDINGS_TRANSFORM_ATTRS \ No newline at end of file +#endif // PYTHON_BINDINGS_TRANSFORM_ATTRS diff --git a/mlir/python/mlir/dialects/VectorTransformsBase.td b/mlir/python/mlir/dialects/VectorTransformsBase.td index ced470300..acb4aeced 100644 --- a/mlir/python/mlir/dialects/VectorTransformsBase.td +++ b/mlir/python/mlir/dialects/VectorTransformsBase.td @@ -16,4 +16,4 @@ include "mlir/Dialect/Vector/Transforms/VectorTransformsBase.td" -#endif // PYTHON_BINDINGS_VECTORTRANSFORMBASE \ No newline at end of file +#endif // PYTHON_BINDINGS_VECTORTRANSFORMBASE From 2daeaff5071f6af6982b4fa01b03fe5a38892f23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 28 Aug 2023 08:04:28 +0000 Subject: [PATCH 550/915] [mlir][linalg][transform][python] Extend mix-in for Vectorize Extends the existing mix-in for VectorizeOp with support for the missing unit attributes. Also fixes the unintuitive implementation where `structured.VectorizeOp(target=target, vectorize_padding=False)` still resulted in the creation of the UnitAttr `vectorize_padding`. 
Reviewed By: ingomueller-net Differential Revision: https://reviews.llvm.org/D158726 --- .../mlir/dialects/_structured_transform_ops_ext.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 48dee5f80..9e039ffa6 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -783,16 +783,20 @@ def __init__( self, target: Union[Operation, Value], *, - vectorize_padding: Union[bool, BoolAttr] = False, + disable_multi_reduction_to_contract_patterns: bool = False, + disable_transfer_permutation_map_lowering_patterns: bool = False, + vectorize_nd_extract: bool = False, + vectorize_padding: bool = False, loc=None, ip=None, ): pdl_operation_type = pdl.OperationType.get() - if isinstance(vectorize_padding, bool): - vectorize_padding = UnitAttr.get() super().__init__( pdl_operation_type, _get_op_result_or_value(target), + disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns, + disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns, + vectorize_nd_extract=vectorize_nd_extract, vectorize_padding=vectorize_padding, loc=loc, ip=ip, From f21952657fce8e6ffa9f2ef91e1545c6a1277b07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 28 Aug 2023 09:32:49 +0000 Subject: [PATCH 551/915] [mlir][python] Make DenseBoolArrayAttr.get work with list of bools. This patch makes the getter function of `DenseBoolArrayAttr` work more intuitively. Until now, it was implemented with a `std::vector` argument, which works in the typical situation where you call the pybind function with a list of Python bools (like `[True, False]`). 
However, it does *not* work if the elements of the list have to be cast to Bool before (and that is the default behavior for lists of all other types). The patch thus changes the signature to `std::vector`, which helps pybind to make the function behave as expected for bools. The tests now also contain a case where such a cast is happening. This also makes the conversion of `DenseBoolArrayAttr` back to Python more intuitive: instead of converting to `0` and `1`, the elements are now converted to `False` and `True`. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D158973 --- mlir/lib/Bindings/Python/IRAttributes.cpp | 25 ++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 75d743f3a..50cfc0624 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -162,9 +162,7 @@ class PyDenseArrayAttribute : public PyConcreteAttribute { c.def_static( "get", [](const std::vector &values, DefaultingPyMlirContext ctx) { - MlirAttribute attr = - DerivedT::getAttribute(ctx->get(), values.size(), values.data()); - return DerivedT(ctx->getRef(), attr); + return getAttribute(values, ctx->getRef()); }, py::arg("values"), py::arg("context") = py::none(), "Gets a uniqued dense array attribute"); @@ -187,16 +185,29 @@ class PyDenseArrayAttribute : public PyConcreteAttribute { values.push_back(arr.getItem(i)); for (py::handle attr : extras) values.push_back(pyTryCast(attr)); - MlirAttribute attr = DerivedT::getAttribute(arr.getContext()->get(), - values.size(), values.data()); - return DerivedT(arr.getContext(), attr); + return getAttribute(values, arr.getContext()); }); } + +private: + static DerivedT getAttribute(const std::vector &values, + PyMlirContextRef ctx) { + if constexpr (std::is_same_v) { + std::vector intValues(values.begin(), values.end()); + MlirAttribute attr = 
DerivedT::getAttribute(ctx->get(), intValues.size(), + intValues.data()); + return DerivedT(ctx, attr); + } else { + MlirAttribute attr = + DerivedT::getAttribute(ctx->get(), values.size(), values.data()); + return DerivedT(ctx, attr); + } + } }; /// Instantiate the python dense array classes. struct PyDenseBoolArrayAttribute - : public PyDenseArrayAttribute { + : public PyDenseArrayAttribute { static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADenseBoolArray; static constexpr auto getAttribute = mlirDenseBoolArrayGet; static constexpr auto getElement = mlirDenseBoolArrayGetElement; From 43211c28cd1cca853357c3b7a3a130f2d3209d04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Fri, 25 Aug 2023 14:53:18 +0000 Subject: [PATCH 552/915] [mlir][python] Add __{bool,float,int,str}__ to bindings of attributes. This allows to use Python's `bool(.)`, `float(.)`, `int(.)`, and `str(.)` to convert pybound attributes to the corresponding native Python types. In particular, pybind11 uses these functions to automatically cast objects to the corresponding primitive types wherever they are required by pybound functions, e.g., arguments are converted to Python's `int` if the C++ signature requires a C++ `int`. With this patch, pybound attributes can by used wherever the corresponding native types are expected. New tests show-case this behavior in the constructors of `Dense*ArrayAttr`. Note that this changes the output of Python's `str` on `StringAttr` from `"hello"` to `hello`. Arguably, this is still in line with `str`s goal of producing a readable interpretation of the value, even if it is now not unambiously a string anymore (`print(ir.Attribute.parse('"42"'))` now outputs `42`). However, this is consistent with instances of Python's `str` (`print("42")` outputs `42`), and `repr` still provides an unambigous representation if one is required. 
Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D158974 --- mlir/lib/Bindings/Python/IRAttributes.cpp | 60 ++++++++++++----------- 1 file changed, 32 insertions(+), 28 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 50cfc0624..6531d6276 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -389,12 +389,10 @@ class PyFloatAttribute : public PyConcreteAttribute { }, py::arg("value"), py::arg("context") = py::none(), "Gets an uniqued float point attribute associated to a f64 type"); - c.def_property_readonly( - "value", - [](PyFloatAttribute &self) { - return mlirFloatAttrGetValueDouble(self); - }, - "Returns the value of the float point attribute"); + c.def_property_readonly("value", mlirFloatAttrGetValueDouble, + "Returns the value of the float attribute"); + c.def("__float__", mlirFloatAttrGetValueDouble, + "Converts the value of the float attribute to a Python float"); } }; @@ -414,22 +412,25 @@ class PyIntegerAttribute : public PyConcreteAttribute { }, py::arg("type"), py::arg("value"), "Gets an uniqued integer attribute associated to a type"); - c.def_property_readonly( - "value", - [](PyIntegerAttribute &self) -> py::int_ { - MlirType type = mlirAttributeGetType(self); - if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) - return mlirIntegerAttrGetValueInt(self); - if (mlirIntegerTypeIsSigned(type)) - return mlirIntegerAttrGetValueSInt(self); - return mlirIntegerAttrGetValueUInt(self); - }, - "Returns the value of the integer attribute"); + c.def_property_readonly("value", toPyInt, + "Returns the value of the integer attribute"); + c.def("__int__", toPyInt, + "Converts the value of the integer attribute to a Python int"); c.def_property_readonly_static("static_typeid", [](py::object & /*class*/) -> MlirTypeID { return mlirIntegerAttrGetTypeID(); }); } + +private: + static py::int_ toPyInt(PyIntegerAttribute 
&self) { + MlirType type = mlirAttributeGetType(self); + if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) + return mlirIntegerAttrGetValueInt(self); + if (mlirIntegerTypeIsSigned(type)) + return mlirIntegerAttrGetValueSInt(self); + return mlirIntegerAttrGetValueUInt(self); + } }; /// Bool Attribute subclass - BoolAttr. @@ -448,10 +449,10 @@ class PyBoolAttribute : public PyConcreteAttribute { }, py::arg("value"), py::arg("context") = py::none(), "Gets an uniqued bool attribute"); - c.def_property_readonly( - "value", - [](PyBoolAttribute &self) { return mlirBoolAttrGetValue(self); }, - "Returns the value of the bool attribute"); + c.def_property_readonly("value", mlirBoolAttrGetValue, + "Returns the value of the bool attribute"); + c.def("__bool__", mlirBoolAttrGetValue, + "Converts the value of the bool attribute to a Python bool"); } }; @@ -595,13 +596,8 @@ class PyStringAttribute : public PyConcreteAttribute { }, py::arg("type"), py::arg("value"), "Gets a uniqued string attribute associated to a type"); - c.def_property_readonly( - "value", - [](PyStringAttribute &self) { - MlirStringRef stringRef = mlirStringAttrGetValue(self); - return py::str(stringRef.data, stringRef.length); - }, - "Returns the value of the string attribute"); + c.def_property_readonly("value", toPyStr, + "Returns the value of the string attribute"); c.def_property_readonly( "value_bytes", [](PyStringAttribute &self) { @@ -609,6 +605,14 @@ class PyStringAttribute : public PyConcreteAttribute { return py::bytes(stringRef.data, stringRef.length); }, "Returns the value of the string attribute as `bytes`"); + c.def("__str__", toPyStr, + "Converts the value of the string attribute to a Python str"); + } + +private: + static py::str toPyStr(PyStringAttribute &self) { + MlirStringRef stringRef = mlirStringAttrGetValue(self); + return py::str(stringRef.data, stringRef.length); } }; From 75d1c163520bb122d638a59d86a51394e9a6c2fb Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: 
Tue, 29 Aug 2023 10:27:23 +0000 Subject: [PATCH 553/915] [mlir][linalg][transform] Return copy_back op from PadOp. This patch makes the `transform.structured.pad` op return also a handle to the copy op that it inserts. This allows to continue transformation on that op, such as mapping it to a GPU thread. The patch was mainly authored by @springerm as part of the WIP patch https://reviews.llvm.org/D156371, which also has an example usage of this change. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D159088 --- mlir/python/mlir/dialects/_structured_transform_ops_ext.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 9e039ffa6..a5b4e52d5 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -524,6 +524,7 @@ def __init__( pdl_operation_type = pdl.OperationType.get() super().__init__( + pdl_operation_type, pdl_operation_type, pdl_operation_type, target, From d197a5370e982533a433334f764984135f4c65bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Thu, 31 Aug 2023 08:03:54 +0000 Subject: [PATCH 554/915] [mlir][python] Remove __str__ from bindings of StringAttr. This reverts a feature introduced in commit 43211c28cd1cca853357c3b7a3a130f2d3209d04. The goal of that commit was to allow `StringAttr`s to by used transparently wherever Python `str`s are expected. But, as the tests in https://reviews.llvm.org/D159182 reveal, pybind11 doesn't do this conversion based on `__str__` automatically, unlike for the other types introduced in the commit above. At the same time, changing `__str__` breaks the symmetry with other attributes of `print(attr)` printing the assembly of the attribute, so the change probably has more disadvantages than advantages. 
Reviewed By: springerm, rkayaith Differential Revision: https://reviews.llvm.org/D159255 --- mlir/lib/Bindings/Python/IRAttributes.cpp | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 6531d6276..105d2cecf 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -596,8 +596,13 @@ class PyStringAttribute : public PyConcreteAttribute { }, py::arg("type"), py::arg("value"), "Gets a uniqued string attribute associated to a type"); - c.def_property_readonly("value", toPyStr, - "Returns the value of the string attribute"); + c.def_property_readonly( + "value", + [](PyStringAttribute &self) { + MlirStringRef stringRef = mlirStringAttrGetValue(self); + return py::str(stringRef.data, stringRef.length); + }, + "Returns the value of the string attribute"); c.def_property_readonly( "value_bytes", [](PyStringAttribute &self) { @@ -605,14 +610,6 @@ class PyStringAttribute : public PyConcreteAttribute { return py::bytes(stringRef.data, stringRef.length); }, "Returns the value of the string attribute as `bytes`"); - c.def("__str__", toPyStr, - "Converts the value of the string attribute to a Python str"); - } - -private: - static py::str toPyStr(PyStringAttribute &self) { - MlirStringRef stringRef = mlirStringAttrGetValue(self); - return py::str(stringRef.data, stringRef.length); } }; From 0c66d118e2bb286bd37dac75adaf58d6397fcd6e Mon Sep 17 00:00:00 2001 From: Fangrui Song Date: Fri, 1 Sep 2023 20:53:08 -0700 Subject: [PATCH 555/915] [mlir] Fix duplicate word typos; NFC Those fixes were taken from https://reviews.llvm.org/D137338 --- mlir/include/mlir-c/AffineMap.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/AffineMap.h b/mlir/include/mlir-c/AffineMap.h index 7359b9691..c24c1ced3 100644 --- a/mlir/include/mlir-c/AffineMap.h +++ b/mlir/include/mlir-c/AffineMap.h @@ 
-100,7 +100,7 @@ mlirAffineMapMinorIdentityGet(MlirContext ctx, intptr_t dims, intptr_t results); /// context. The permutation expression is a non-empty vector of integers. /// The elements of the permutation vector must be continuous from 0 and cannot /// be repeated (i.e. `[1,2,0]` is a valid permutation. `[2,0]` or `[1,1,2]` is -/// an invalid invalid permutation.) The affine map is owned by the context. +/// an invalid permutation.) The affine map is owned by the context. MLIR_CAPI_EXPORTED MlirAffineMap mlirAffineMapPermutationGet( MlirContext ctx, intptr_t size, unsigned *permutation); From 4df1c2bdc2d2549ebb34a6c389cd1be463508663 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Fri, 25 Aug 2023 13:21:13 +0000 Subject: [PATCH 556/915] [mlir][linalg][transform][python] Fix optional args of PadOp mix-in. The mix-in did not allow to *not* set many of the arguments, even though they represent optional attributes. Instead, it set default values, which have different semantics in some cases. In other cases, setting the default values is already done by the C++ layer, in which case they are currently redundant and may be wrong in some potential future change in the TD or C++ files. With this patch, `None` is preserved until the generated binding, which handles them as desired. 
Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D158844 --- .../dialects/_structured_transform_ops_ext.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index a5b4e52d5..544171dc2 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -125,7 +125,7 @@ def _get_value_list( def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr: if values is None: - return ArrayAttr.get([]) + return None # Turn into a Python list of Python ints. values = _get_value_list(values) @@ -148,7 +148,7 @@ def _get_int_array_array_attr( If the input is None, an empty ArrayAttr is returned. """ if values is None: - return ArrayAttr.get([]) + return None # Make sure the outer level is a list. values = _get_value_list(values) @@ -493,9 +493,7 @@ def __init__( self, target: Union[Operation, OpView, Value], *, - padding_values: Optional[ - Union[ArrayAttr, Sequence[Union[bool, int, float, Attribute]]] - ] = None, + padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, padding_dimensions: OptionalIntList = None, pad_to_multiple_of: OptionalIntList = None, pack_paddings: OptionalIntList = None, @@ -506,17 +504,6 @@ def __init__( loc=None, ip=None, ): - if padding_values is None: - padding_values = [] - if padding_dimensions is None: - padding_dimensions = [] - if pad_to_multiple_of is None: - pad_to_multiple_of = [] - if pack_paddings is None: - pack_paddings = [] - if transpose_paddings is None: - transpose_paddings = [] - padding_dimensions = _get_int_array_attr(padding_dimensions) pad_to_multiple_of = _get_int_array_attr(pad_to_multiple_of) pack_paddings = _get_int_array_attr(pack_paddings) From 0366c51912d081fb506330b3ca0a7d602255d72f Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 4 Sep 2023 07:59:03 +0000 Subject: [PATCH 557/915] [mlir][python][linalg][transform] Forward missing params in mix-ins. --- mlir/python/mlir/dialects/_structured_transform_ops_ext.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 544171dc2..2d0eeb772 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -195,6 +195,8 @@ def __init__( memcpy_op=memcpy_op, alloc_op=alloc_op, bufferize_destination_only=bufferize_destination_only, + loc=loc, + ip=ip, ) @@ -395,6 +397,8 @@ def __init__( static_vector_sizes=static_vector_sizes, scalable_sizes=scalable_sizes, vectorize_nd_extract=vectorize_nd_extract, + loc=loc, + ip=ip, ) From 68623ea69ad32575146c224447ecc5441426fd51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 4 Sep 2023 08:35:43 +0000 Subject: [PATCH 558/915] [mlir][linalg][transform][python] Simplify mix-in of PadOp. This patch removes some manual conversion of mixed Python/attribute arguments to `I64ArrayAttr`s, which turned out to be unnecessary. Interestingly, this change does not depend on the additional attribute builders added in the (currently pending) https://reviews.llvm.org/D159403 patch. 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D159419 --- mlir/python/mlir/dialects/_structured_transform_ops_ext.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 2d0eeb772..792629457 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -508,9 +508,6 @@ def __init__( loc=None, ip=None, ): - padding_dimensions = _get_int_array_attr(padding_dimensions) - pad_to_multiple_of = _get_int_array_attr(pad_to_multiple_of) - pack_paddings = _get_int_array_attr(pack_paddings) transpose_paddings = _get_int_array_array_attr(transpose_paddings) pdl_operation_type = pdl.OperationType.get() From f91e343527294947c520765b8120dcde70623ae1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 4 Sep 2023 08:28:57 +0000 Subject: [PATCH 559/915] [mlir][linalg][transform][python] Refactor TileOp mix-in. This patch simplifies and improves the mix-in of the `TileOp`. In particular: * Accept all types of sizes (static, dynamic, scalable) in a single argument `sizes`. * Use the existing convenience function to dispatch different types of sizes instead of repeating the implementation in the mix-in. * Pass on `None` values as is of optional arguments to the init function of the super class. * Reformat with default indentation width (4 spaces vs 2 spaces). * Add a a test for providing scalable sizes. 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D159417 --- .../dialects/_structured_transform_ops_ext.py | 92 +++++++------------ 1 file changed, 31 insertions(+), 61 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 792629457..44d9c1406 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -571,107 +571,77 @@ def __init__( class TileOp: - """Specialization for TileOp class.""" + """Specialization for TileOp class.""" - @overload - def __init__( + @overload + def __init__( self, loop_types: Union[Type, List[Type]], target: Union[Operation, Value], *, - sizes: Optional[ - Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr] - ] = None, + sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, interchange: OptionalIntList = None, - scalable_sizes: OptionalBoolList = None, loc=None, ip=None, ): - ... + ... - @overload - def __init__( + @overload + def __init__( self, target: Union[Operation, Value, OpView], *, - sizes: Optional[ - Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr] - ] = None, + sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, interchange: OptionalIntList = None, - scalable_sizes: OptionalBoolList = None, loc=None, ip=None, ): - ... + ... 
- def __init__( + def __init__( self, loop_types_or_target: Union[Type, List[Type], Operation, Value], target_or_none: Optional[Union[Operation, Value, OpView]] = None, *, - sizes: Optional[ - Union[Sequence[Union[int, IntegerAttr, Operation, Value]], ArrayAttr] - ] = None, + sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, interchange: OptionalIntList = None, - scalable_sizes: OptionalBoolList = None, loc=None, ip=None, ): - if interchange is None: - interchange = [] - if sizes is None: - sizes = [] - - static_sizes = [] - dynamic_sizes = [] - if isinstance(sizes, ArrayAttr): - sizes_attr = sizes - else: - for size in sizes: - if isinstance(size, int): - static_sizes.append(size) - else: - static_sizes.append(ShapedType.get_dynamic_size()) - dynamic_sizes.append(_get_op_result_or_value(size)) - sizes_attr = DenseI64ArrayAttr.get(static_sizes) + ( + dynamic_sizes, + static_sizes, + scalable_sizes, + ) = _dispatch_dynamic_index_list(sizes) - num_loops = sum( - v if v == 0 else 1 for v in self.__extract_values(sizes_attr) - ) - if scalable_sizes is None: - scalable_sizes = [False] * len(self.__extract_values(sizes_attr)) + num_loops = sum(v if v == 0 else 1 for v in static_sizes) - if isinstance(loop_types_or_target, (Operation, Value, OpView)): - loop_types = [transform.AnyOpType.get()] * num_loops - target = loop_types_or_target - assert target_or_none is None, "Cannot construct TileOp with two targets." - else: - loop_types = ( - ([loop_types_or_target] * num_loops) - if isinstance(loop_types_or_target, Type) - else loop_types_or_target - ) - target = target_or_none + if isinstance(loop_types_or_target, (Operation, Value, OpView)): + loop_types = [transform.AnyOpType.get()] * num_loops + target = loop_types_or_target + assert target_or_none is None, "Cannot construct TileOp with two targets." 
+ else: + loop_types = ( + ([loop_types_or_target] * num_loops) + if isinstance(loop_types_or_target, Type) + else loop_types_or_target + ) + target = target_or_none - target = _get_op_result_or_value(target) + target = _get_op_result_or_value(target) - super().__init__( + super().__init__( target.type, loop_types, target, dynamic_sizes=dynamic_sizes, - static_sizes=sizes_attr, + static_sizes=static_sizes, interchange=interchange, scalable_sizes=scalable_sizes, loc=loc, ip=ip, ) - def __extract_values(self, attr: Optional[DenseI64ArrayAttr]) -> List[int]: - if not attr: - return [] - return [element for element in attr] - class TileToForallOp: """Specialization for TileToForallOp class.""" From 3731a938db7aab4d00d1fe9c0d42689830ac189f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 4 Sep 2023 08:07:53 +0000 Subject: [PATCH 560/915] [mlir][linalg][transform][python] Make divisor arg to Multitile optional The mix-in of the `MultiTileSizesOp` set the default value of its `divisor` argument. This repeats information from the tablegen definition, is not necessary (since the generic code deals with `None` and default values), and has the risk of running out of sync without people noticing. This patch removes the setting of the value and forwards `None` to the generic constructor instead.
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D159416 --- mlir/python/mlir/dialects/_structured_transform_ops_ext.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 44d9c1406..f368e56f9 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -462,7 +462,7 @@ def match_op_names( class MultiTileSizesOp: - """Specialization for MultitileSizesOp class.""" + """Specialization for MultiTileSizesOp class.""" def __init__( self, @@ -475,8 +475,6 @@ def __init__( loc=None, ip=None, ): - if divisor is None: - divisor = 1 super().__init__( result_type, result_type, From 745692d6b99279cd3a3aabbcbe7f881ed80363b6 Mon Sep 17 00:00:00 2001 From: Felix Schneider Date: Tue, 5 Sep 2023 08:02:21 +0000 Subject: [PATCH 561/915] [mlir][Python] Fix conversion of non-zero offset memrefs to np.arrays Memref descriptors contain an `offset` field that denotes the start of the content of the memref relative to the `alignedPtr`. This offset is not considered when converting a memref descriptor to a np.array in the Python runtime library, essentially treating all memrefs as if they had an offset of zero. This patch introduces the necessary pointer arithmetic to find the actual beginning of the memref contents to the memref->numpy conversion functions. There is an ongoing discussion about whether the `offset` field is needed at all in the memref descriptor. Until that is decided, the Python runtime and CRunnerUtils should still correctly implement the offset handling. 
Related: https://reviews.llvm.org/D157008 Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D158494 --- mlir/python/mlir/runtime/np_to_memref.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py index 51433d75a..0a3b41104 100644 --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -114,13 +114,21 @@ def get_unranked_memref_descriptor(nparray): d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p) return d +def move_aligned_ptr_by_offset(aligned_ptr, offset): + """Moves the supplied ctypes pointer ahead by `offset` elements.""" + aligned_addr = ctypes.addressof(aligned_ptr.contents) + elem_size = ctypes.sizeof(aligned_ptr.contents) + shift = offset * elem_size + content_ptr = ctypes.cast(aligned_addr + shift, type(aligned_ptr)) + return content_ptr def unranked_memref_to_numpy(unranked_memref, np_dtype): """Converts unranked memrefs to numpy arrays.""" ctp = as_ctype(np_dtype) descriptor = make_nd_memref_descriptor(unranked_memref[0].rank, ctp) val = ctypes.cast(unranked_memref[0].descriptor, ctypes.POINTER(descriptor)) - np_arr = np.ctypeslib.as_array(val[0].aligned, shape=val[0].shape) + content_ptr = move_aligned_ptr_by_offset(val[0].aligned, val[0].offset) + np_arr = np.ctypeslib.as_array(content_ptr, shape=val[0].shape) strided_arr = np.lib.stride_tricks.as_strided( np_arr, np.ctypeslib.as_array(val[0].shape), @@ -131,8 +139,9 @@ def unranked_memref_to_numpy(unranked_memref, np_dtype): def ranked_memref_to_numpy(ranked_memref): """Converts ranked memrefs to numpy arrays.""" + content_ptr = move_aligned_ptr_by_offset(ranked_memref[0].aligned, ranked_memref[0].offset) np_arr = np.ctypeslib.as_array( - ranked_memref[0].aligned, shape=ranked_memref[0].shape + content_ptr, shape=ranked_memref[0].shape ) strided_arr = np.lib.stride_tricks.as_strided( np_arr, From 
e3f102988e912253ca5701d43ecb50aa71172255 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Fri, 1 Sep 2023 09:11:35 +0000 Subject: [PATCH 562/915] [mlir][python] Create all missing attribute builders. This patch adds attribute builders for all buildable attributes from the builtin dialect that did not previously have any. These builders can be used to construct attributes of a particular type identified by a string from a Python argument without knowing the details of how to pass that Python argument to the attribute constructor. This is used, for example, in the generated code of the Python bindings of ops. The list of "all" attributes was produced with: ( grep -h "ods_ir.AttrBuilder.get" $(find ../build/ -name "*_ops_gen.py") \ | cut -f2 -d"'" git grep -ho "^def [a-zA-Z0-9_]*" -- include/mlir/IR/CommonAttrConstraints.td \ | cut -f2 -d" " ) | sort -u Then, I only retained those that had an occurrence in `mlir/include/mlir/IR`. In particular, this drops many dialect-specific attributes; registering those builders is something that those dialects should do. Finally, I removed those attributes that had a match in `mlir/python/mlir/ir.py` already and implemented the remaining ones. The only ones that still miss a builder now are the following: * Represent more than one possible attribute type: - `Any.*Attr` (9x) - `IntNonNegative` - `IntPositive` - `IsNullAttr` - `ElementsAttr` * I am not sure what "constant attributes" are: - `ConstBoolAttrFalse` - `ConstBoolAttrTrue` - `ConstUnitAttr` * `Location` not exposed by Python bindings: - `LocationArrayAttr` - `LocationAttr` * `get` function not implemented in Python bindings: - `StringElementsAttr` This patch also fixes a compilation problem with `I64SmallVectorArrayAttr`.
Reviewed By: makslevental, rkayaith Differential Revision: https://reviews.llvm.org/D159403 --- mlir/python/mlir/ir.py | 157 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 36c49fe6f..43553f311 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -16,16 +16,36 @@ def decorator_builder(func): return decorator_builder +@register_attribute_builder("AffineMapAttr") +def _affineMapAttr(x, context): + return AffineMapAttr.get(x) + + @register_attribute_builder("BoolAttr") def _boolAttr(x, context): return BoolAttr.get(x, context=context) +@register_attribute_builder("DictionaryAttr") +def _dictAttr(x, context): + return DictAttr.get(x, context=context) + + @register_attribute_builder("IndexAttr") def _indexAttr(x, context): return IntegerAttr.get(IndexType.get(context=context), x) +@register_attribute_builder("I1Attr") +def _i1Attr(x, context): + return IntegerAttr.get(IntegerType.get_signless(1, context=context), x) + + +@register_attribute_builder("I8Attr") +def _i8Attr(x, context): + return IntegerAttr.get(IntegerType.get_signless(8, context=context), x) + + @register_attribute_builder("I16Attr") def _i16Attr(x, context): return IntegerAttr.get(IntegerType.get_signless(16, context=context), x) @@ -41,6 +61,16 @@ def _i64Attr(x, context): return IntegerAttr.get(IntegerType.get_signless(64, context=context), x) +@register_attribute_builder("SI1Attr") +def _si1Attr(x, context): + return IntegerAttr.get(IntegerType.get_signed(1, context=context), x) + + +@register_attribute_builder("SI8Attr") +def _i8Attr(x, context): + return IntegerAttr.get(IntegerType.get_signed(8, context=context), x) + + @register_attribute_builder("SI16Attr") def _si16Attr(x, context): return IntegerAttr.get(IntegerType.get_signed(16, context=context), x) @@ -51,6 +81,36 @@ def _si32Attr(x, context): return IntegerAttr.get(IntegerType.get_signed(32, context=context), x) 
+@register_attribute_builder("SI64Attr") +def _si64Attr(x, context): + return IntegerAttr.get(IntegerType.get_signed(64, context=context), x) + + +@register_attribute_builder("UI1Attr") +def _ui1Attr(x, context): + return IntegerAttr.get(IntegerType.get_unsigned(1, context=context), x) + + +@register_attribute_builder("UI8Attr") +def _i8Attr(x, context): + return IntegerAttr.get(IntegerType.get_unsigned(8, context=context), x) + + +@register_attribute_builder("UI16Attr") +def _ui16Attr(x, context): + return IntegerAttr.get(IntegerType.get_unsigned(16, context=context), x) + + +@register_attribute_builder("UI32Attr") +def _ui32Attr(x, context): + return IntegerAttr.get(IntegerType.get_unsigned(32, context=context), x) + + +@register_attribute_builder("UI64Attr") +def _ui64Attr(x, context): + return IntegerAttr.get(IntegerType.get_unsigned(64, context=context), x) + + @register_attribute_builder("F32Attr") def _f32Attr(x, context): return FloatAttr.get_f32(x, context=context) @@ -84,11 +144,39 @@ def _flatSymbolRefAttr(x, context): return FlatSymbolRefAttr.get(x, context=context) +@register_attribute_builder("UnitAttr") +def _unitAttr(x, context): + if x: + return UnitAttr.get(context=context) + else: + return None + + @register_attribute_builder("ArrayAttr") def _arrayAttr(x, context): return ArrayAttr.get(x, context=context) +@register_attribute_builder("AffineMapArrayAttr") +def _affineMapArrayAttr(x, context): + return ArrayAttr.get([_affineMapAttr(v, context) for v in x]) + + +@register_attribute_builder("BoolArrayAttr") +def _boolArrayAttr(x, context): + return ArrayAttr.get([_boolAttr(v, context) for v in x]) + + +@register_attribute_builder("DictArrayAttr") +def _dictArrayAttr(x, context): + return ArrayAttr.get([_dictAttr(v, context) for v in x]) + + +@register_attribute_builder("FlatSymbolRefArrayAttr") +def _flatSymbolRefArrayAttr(x, context): + return ArrayAttr.get([_flatSymbolRefAttr(v, context) for v in x]) + + 
@register_attribute_builder("I32ArrayAttr") def _i32ArrayAttr(x, context): return ArrayAttr.get([_i32Attr(v, context) for v in x]) @@ -99,6 +187,16 @@ def _i64ArrayAttr(x, context): return ArrayAttr.get([_i64Attr(v, context) for v in x]) +@register_attribute_builder("I64SmallVectorArrayAttr") +def _i64SmallVectorArrayAttr(x, context): + return _i64ArrayAttr(x, context=context) + + +@register_attribute_builder("IndexListArrayAttr") +def _indexListArrayAttr(x, context): + return ArrayAttr.get([_i64ArrayAttr(v, context) for v in x]) + + @register_attribute_builder("F32ArrayAttr") def _f32ArrayAttr(x, context): return ArrayAttr.get([_f32Attr(v, context) for v in x]) @@ -109,6 +207,41 @@ def _f64ArrayAttr(x, context): return ArrayAttr.get([_f64Attr(v, context) for v in x]) +@register_attribute_builder("StrArrayAttr") +def _strArrayAttr(x, context): + return ArrayAttr.get([_stringAttr(v, context) for v in x]) + + +@register_attribute_builder("SymbolRefArrayAttr") +def _symbolRefArrayAttr(x, context): + return ArrayAttr.get([_symbolRefAttr(v, context) for v in x]) + + +@register_attribute_builder("DenseF32ArrayAttr") +def _denseF32ArrayAttr(x, context): + return DenseF32ArrayAttr.get(x, context=context) + + +@register_attribute_builder("DenseF64ArrayAttr") +def _denseF64ArrayAttr(x, context): + return DenseF64ArrayAttr.get(x, context=context) + + +@register_attribute_builder("DenseI8ArrayAttr") +def _denseI8ArrayAttr(x, context): + return DenseI8ArrayAttr.get(x, context=context) + + +@register_attribute_builder("DenseI16ArrayAttr") +def _denseI16ArrayAttr(x, context): + return DenseI16ArrayAttr.get(x, context=context) + + +@register_attribute_builder("DenseI32ArrayAttr") +def _denseI32ArrayAttr(x, context): + return DenseI32ArrayAttr.get(x, context=context) + + @register_attribute_builder("DenseI64ArrayAttr") def _denseI64ArrayAttr(x, context): return DenseI64ArrayAttr.get(x, context=context) @@ -132,6 +265,30 @@ def _typeArrayAttr(x, context): try: import numpy as np + 
@register_attribute_builder("F64ElementsAttr") + def _f64ElementsAttr(x, context): + return DenseElementsAttr.get( + np.array(x, dtype=np.int64), + type=F64Type.get(context=context), + context=context, + ) + + @register_attribute_builder("I32ElementsAttr") + def _i32ElementsAttr(x, context): + return DenseElementsAttr.get( + np.array(x, dtype=np.int32), + type=IntegerType.get_signed(32, context=context), + context=context, + ) + + @register_attribute_builder("I64ElementsAttr") + def _i64ElementsAttr(x, context): + return DenseElementsAttr.get( + np.array(x, dtype=np.int64), + type=IntegerType.get_signed(64, context=context), + context=context, + ) + @register_attribute_builder("IndexElementsAttr") def _indexElementsAttr(x, context): return DenseElementsAttr.get( From 0b42f596259689cc98b69289568ba9b271bf23bf Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Mon, 11 Sep 2023 13:22:35 -0700 Subject: [PATCH 563/915] Lazy initialize diagnostic when handling MLIR properties (#65868) Instead of eagerly creating a diagnostic that will be discarded in the normal case, switch to lazy initialization on error. 
--- mlir/lib/CAPI/IR/IR.cpp | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index ccdae1424..ef234a912 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -370,16 +370,21 @@ static LogicalResult inferOperationTypes(OperationState &state) { if (!properties && info->getOpPropertyByteSize() > 0 && !attributes.empty()) { auto prop = std::make_unique(info->getOpPropertyByteSize()); properties = OpaqueProperties(prop.get()); - InFlightDiagnostic diag = emitError(state.location) - << " failed properties conversion while building " - << state.name.getStringRef() << " with `" - << attributes << "`: "; - if (failed(info->setOpPropertiesFromAttribute(state.name, properties, - attributes, &diag))) { - return failure(); + if (properties) { + std::unique_ptr diagnostic; + auto getDiag = [&]() -> InFlightDiagnostic & { + if (!diagnostic) { + diagnostic = std::make_unique( + emitError(state.location) + << " failed properties conversion while building " + << state.name.getStringRef() << " with `" << attributes << "`: "); + } + return *diagnostic; + }; + if (failed(info->setOpPropertiesFromAttribute(state.name, properties, + attributes, getDiag))) + return failure(); } - diag.abandon(); - if (succeeded(inferInterface->inferReturnTypes( context, state.location, state.operands, attributes, properties, state.regions, state.types))) { From b5f8d5c2639de3be98da1c4c692ba5c4caabf77f Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 11 Sep 2023 14:10:03 -0700 Subject: [PATCH 564/915] [mlir] Make it possible to build a DenseResourceElementsAttr from untyped memory. (#66009) Exposes the existing `get(ShapedType, StringRef, AsmResourceBlob)` builder publicly (was protected) and adds a CAPI `mlirUnmanagedDenseBlobResourceElementsAttrGet`. 
While such a generic construction interface is a big help when it comes to interop, it is also necessary for creating resources that don't have a standard C type (i.e. f16, the f8s, etc). Previously reviewed/approved as part of https://reviews.llvm.org/D157064 --- mlir/include/mlir-c/BuiltinAttributes.h | 7 +++++++ mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 8 ++++++++ 2 files changed, 15 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 631981924..93c4ed569 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -600,6 +600,13 @@ mlirUnmanagedDenseDoubleResourceElementsAttrGet(MlirType shapedType, intptr_t numElements, const double *elements); +/// Unlike the typed accessors above, constructs the attribute with a raw +/// data buffer and no type/alignment checking. Use a more strongly typed +/// accessor if possible. +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBlobResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, const void *data, + size_t dataLength); + /// Returns the pos-th value (flat contiguous indexing) of a specific type /// contained by the given dense resource elements attribute. 
MLIR_CAPI_EXPORTED bool diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index de221ddbf..84a958d01 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -852,6 +852,14 @@ mlirUnmanagedDenseDoubleResourceElementsAttrGet(MlirType shapedType, return getDenseResource(shapedType, name, numElements, elements); } +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBlobResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, const void *data, + size_t dataLength) { + return wrap(DenseResourceElementsAttr::get( + llvm::cast(unwrap(shapedType)), unwrap(name), + UnmanagedAsmResourceBlob::allocateInferAlign( + llvm::ArrayRef(static_cast(data), dataLength)))); +} template static T getDenseResourceVal(MlirAttribute attr, intptr_t pos) { From 294c9140308306266977e45281e782fb35bdee50 Mon Sep 17 00:00:00 2001 From: Daniil Dudkin <39276703+unterumarmung@users.noreply.github.com> Date: Tue, 12 Sep 2023 08:02:19 +0300 Subject: [PATCH 565/915] =?UTF-8?q?[mlir][arith]=20Rename=20operations:=20?= =?UTF-8?q?`maxf`=20=E2=86=92=20`maximumf`,=20`minf`=20=E2=86=92=20`minimu?= =?UTF-8?q?mf`=20(#65800)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671. This commit addresses Task 1.2 of the mentioned RFC. By renaming these operations, we align their names with LLVM intrinsics that have corresponding semantics. 
--- mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 62730d9ca..6f9d72164 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -527,28 +527,28 @@ def _binary_mul(self, lhs: Value, rhs: Value) -> Value: def _binary_max_signed(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return arith.MaxFOp(lhs, rhs).result + return arith.MaximumFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MaxSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max' operands: {lhs}, {rhs}") def _binary_max_unsigned(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return arith.MaxFOp(lhs, rhs).result + return arith.MaximumFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MaxUIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'max_unsigned' operands: {lhs}, {rhs}") def _binary_min_signed(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return arith.MinFOp(lhs, rhs).result + return arith.MinimumFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MinSIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'min' operands: {lhs}, {rhs}") def _binary_min_unsigned(self, lhs: Value, rhs: Value) -> Value: if _is_floating_point_type(lhs.type): - return arith.MinFOp(lhs, rhs).result + return arith.MinimumFOp(lhs, rhs).result if _is_integer_type(lhs.type) or _is_index_type(lhs.type): return arith.MinUIOp(lhs, rhs).result raise NotImplementedError("Unsupported 'min_unsigned' operands: {lhs}, {rhs}") From 0a241dc54a681824fa1db1c4029a70e5063d4689 Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: 
Tue, 12 Sep 2023 14:04:38 +0000 Subject: [PATCH 566/915] [mlir][bufferization] Remove allow-return-allocs and create-deallocs pass options, remove bufferization.escape attribute This is the first commit in a series with the goal to rework the BufferDeallocation pass. Currently, this pass heavily relies on copies to perform correct deallocations, which leads to very slow code and potentially high memory usage. Additionally, there are unsupported cases such as returning memrefs which this series of commits aims to add support for as well. This first commit removes the deallocation capabilities of one-shot-bufferization. One-shot-bufferization should never deallocate any memrefs as this should be entirely handled by the buffer-deallocation pass going forward. This means the allow-return-allocs pass option will default to true now, create-deallocs defaults to false and they, as well as the escape attribute indicating whether a memref escapes the current region, will be removed. The documentation w.r.t. these pass option changes should also be updated in this commit.
Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D156662 --- .../_bufferization_transform_ops_ext.py | 136 +++++++++--------- 1 file changed, 64 insertions(+), 72 deletions(-) diff --git a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py index ead337282..8dbcb3b69 100644 --- a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py @@ -54,79 +54,71 @@ def __init__( class OneShotBufferizeOp: - """Specialization for OneShotBufferizeOp class.""" + """Specialization for OneShotBufferizeOp class.""" - @overload - def __init__( - self, - transformed_type: Type, - target: Union[Operation, OpView, Value], - *, - allow_return_allocs: Optional[bool] = None, - allow_unknown_ops: Optional[bool] = None, - bufferize_function_boundaries: Optional[bool] = None, - create_deallocs: Optional[bool] = None, - function_boundary_type_conversion: Optional[Enum] = None, - memcpy_op: Optional[str] = None, - print_conflicts: Optional[bool] = None, - test_analysis_only: Optional[bool] = None, - loc=None, - ip=None - ): - ... + @overload + def __init__( + self, + transformed_type: Type, + target: Union[Operation, OpView, Value], + *, + allow_unknown_ops: Optional[bool] = None, + bufferize_function_boundaries: Optional[bool] = None, + function_boundary_type_conversion: Optional[Enum] = None, + memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, + loc=None, + ip=None + ): + ... 
- @overload - def __init__( - self, - target: Union[Operation, OpView, Value], - *, - allow_return_allocs: Optional[bool] = None, - allow_unknown_ops: Optional[bool] = None, - bufferize_function_boundaries: Optional[bool] = None, - create_deallocs: Optional[bool] = None, - function_boundary_type_conversion: Optional[Enum] = None, - memcpy_op: Optional[str] = None, - print_conflicts: Optional[bool] = None, - test_analysis_only: Optional[bool] = None, - loc=None, - ip=None - ): - ... + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + allow_unknown_ops: Optional[bool] = None, + bufferize_function_boundaries: Optional[bool] = None, + function_boundary_type_conversion: Optional[Enum] = None, + memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, + loc=None, + ip=None + ): + ... - def __init__( - self, - transformed_type_or_target: Type, - target_or_none: Optional[Union[Operation, OpView, Value]] = None, - *, - allow_return_allocs: Optional[bool] = None, - allow_unknown_ops: Optional[bool] = None, - bufferize_function_boundaries: Optional[bool] = None, - create_deallocs: Optional[bool] = None, - function_boundary_type_conversion: Optional[Enum] = None, - memcpy_op: Optional[str] = None, - print_conflicts: Optional[bool] = None, - test_analysis_only: Optional[bool] = None, - loc=None, - ip=None - ): - if isinstance(transformed_type_or_target, Type): - transformed_type = transformed_type_or_target - target = target_or_none - else: - transformed_type = transform.AnyOpType.get() - target = transformed_type_or_target + def __init__( + self, + transformed_type_or_target: Type, + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + allow_unknown_ops: Optional[bool] = None, + bufferize_function_boundaries: Optional[bool] = None, + function_boundary_type_conversion: Optional[Enum] = None, + memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = 
None, + test_analysis_only: Optional[bool] = None, + loc=None, + ip=None + ): + if isinstance(transformed_type_or_target, Type): + transformed_type = transformed_type_or_target + target = target_or_none + else: + transformed_type = transform.AnyOpType.get() + target = transformed_type_or_target - super().__init__( - transformed_type, - target, - allow_return_allocs=allow_return_allocs, - allow_unknown_ops=allow_unknown_ops, - bufferize_function_boundaries=bufferize_function_boundaries, - create_deallocs=create_deallocs, - function_boundary_type_conversion=function_boundary_type_conversion, - memcpy_op=memcpy_op, - print_conflicts=print_conflicts, - test_analysis_only=test_analysis_only, - loc=loc, - ip=ip, - ) + super().__init__( + transformed_type, + target, + allow_unknown_ops=allow_unknown_ops, + bufferize_function_boundaries=bufferize_function_boundaries, + function_boundary_type_conversion=function_boundary_type_conversion, + memcpy_op=memcpy_op, + print_conflicts=print_conflicts, + test_analysis_only=test_analysis_only, + loc=loc, + ip=ip, + ) From 147495124f91f8d51b20502bb882d60d86f90219 Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Wed, 13 Sep 2023 13:39:47 +0000 Subject: [PATCH 567/915] Revert "[mlir][bufferization] Remove allow-return-allocs and create-deallocs pass options, remove bufferization.escape attribute" This reverts commit 0a241dc54a681824fa1db1c4029a70e5063d4689. This caused problems in downstream projects. We are reverting to give them more time for integration. 
--- .../_bufferization_transform_ops_ext.py | 136 +++++++++--------- 1 file changed, 72 insertions(+), 64 deletions(-) diff --git a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py index 8dbcb3b69..ead337282 100644 --- a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py @@ -54,71 +54,79 @@ def __init__( class OneShotBufferizeOp: - """Specialization for OneShotBufferizeOp class.""" + """Specialization for OneShotBufferizeOp class.""" - @overload - def __init__( - self, - transformed_type: Type, - target: Union[Operation, OpView, Value], - *, - allow_unknown_ops: Optional[bool] = None, - bufferize_function_boundaries: Optional[bool] = None, - function_boundary_type_conversion: Optional[Enum] = None, - memcpy_op: Optional[str] = None, - print_conflicts: Optional[bool] = None, - test_analysis_only: Optional[bool] = None, - loc=None, - ip=None - ): - ... + @overload + def __init__( + self, + transformed_type: Type, + target: Union[Operation, OpView, Value], + *, + allow_return_allocs: Optional[bool] = None, + allow_unknown_ops: Optional[bool] = None, + bufferize_function_boundaries: Optional[bool] = None, + create_deallocs: Optional[bool] = None, + function_boundary_type_conversion: Optional[Enum] = None, + memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, + loc=None, + ip=None + ): + ... - @overload - def __init__( - self, - target: Union[Operation, OpView, Value], - *, - allow_unknown_ops: Optional[bool] = None, - bufferize_function_boundaries: Optional[bool] = None, - function_boundary_type_conversion: Optional[Enum] = None, - memcpy_op: Optional[str] = None, - print_conflicts: Optional[bool] = None, - test_analysis_only: Optional[bool] = None, - loc=None, - ip=None - ): - ... 
+ @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + allow_return_allocs: Optional[bool] = None, + allow_unknown_ops: Optional[bool] = None, + bufferize_function_boundaries: Optional[bool] = None, + create_deallocs: Optional[bool] = None, + function_boundary_type_conversion: Optional[Enum] = None, + memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, + loc=None, + ip=None + ): + ... - def __init__( - self, - transformed_type_or_target: Type, - target_or_none: Optional[Union[Operation, OpView, Value]] = None, - *, - allow_unknown_ops: Optional[bool] = None, - bufferize_function_boundaries: Optional[bool] = None, - function_boundary_type_conversion: Optional[Enum] = None, - memcpy_op: Optional[str] = None, - print_conflicts: Optional[bool] = None, - test_analysis_only: Optional[bool] = None, - loc=None, - ip=None - ): - if isinstance(transformed_type_or_target, Type): - transformed_type = transformed_type_or_target - target = target_or_none - else: - transformed_type = transform.AnyOpType.get() - target = transformed_type_or_target + def __init__( + self, + transformed_type_or_target: Type, + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + allow_return_allocs: Optional[bool] = None, + allow_unknown_ops: Optional[bool] = None, + bufferize_function_boundaries: Optional[bool] = None, + create_deallocs: Optional[bool] = None, + function_boundary_type_conversion: Optional[Enum] = None, + memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, + loc=None, + ip=None + ): + if isinstance(transformed_type_or_target, Type): + transformed_type = transformed_type_or_target + target = target_or_none + else: + transformed_type = transform.AnyOpType.get() + target = transformed_type_or_target - super().__init__( - transformed_type, - target, - allow_unknown_ops=allow_unknown_ops, - 
bufferize_function_boundaries=bufferize_function_boundaries, - function_boundary_type_conversion=function_boundary_type_conversion, - memcpy_op=memcpy_op, - print_conflicts=print_conflicts, - test_analysis_only=test_analysis_only, - loc=loc, - ip=ip, - ) + super().__init__( + transformed_type, + target, + allow_return_allocs=allow_return_allocs, + allow_unknown_ops=allow_unknown_ops, + bufferize_function_boundaries=bufferize_function_boundaries, + create_deallocs=create_deallocs, + function_boundary_type_conversion=function_boundary_type_conversion, + memcpy_op=memcpy_op, + print_conflicts=print_conflicts, + test_analysis_only=test_analysis_only, + loc=loc, + ip=ip, + ) From 7d2958af79f576c3fe340a1bfd50c4f6da798362 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Thu, 14 Sep 2023 17:24:16 +0200 Subject: [PATCH 568/915] [mlir][linalg][transform][python] Drop _get_op_result... from mix-ins. (#65726) `_get_op_result_or_value` was used in mix-ins to unify the handling of op results and values. However, that function is now called in the generated constructors, such that doing so in the mix-ins is not necessary anymore. 
--- .../dialects/_structured_transform_ops_ext.py | 47 +++++++------------ 1 file changed, 18 insertions(+), 29 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index f368e56f9..212fbc5ba 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -4,7 +4,6 @@ try: from ..ir import * - from ._ods_common import get_op_result_or_value as _get_op_result_or_value from ..dialects import pdl, transform except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e @@ -101,7 +100,7 @@ def _dispatch_mixed_values( static_values.append(size) else: static_values.append(ShapedType.get_dynamic_size()) - dynamic_values.append(_get_op_result_or_value(size)) + dynamic_values.append(size) static_values = DenseI64ArrayAttr.get(static_values) return (dynamic_values, packed_values, static_values) @@ -204,9 +203,7 @@ class DecomposeOp: """Specialization for DecomposeOp class.""" def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - super().__init__( - pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip - ) + super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip) class FuseIntoContainingOp: @@ -277,9 +274,7 @@ class GeneralizeOp: """Specialization for GeneralizeOp class.""" def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - super().__init__( - pdl.OperationType.get(), _get_op_result_or_value(target), loc=loc, ip=ip - ) + super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip) class InterchangeOp: @@ -296,7 +291,7 @@ def __init__( pdl_operation_type = pdl.OperationType.get() super().__init__( pdl_operation_type, - _get_op_result_or_value(target), + target, iterator_interchange=iterator_interchange, loc=loc, ip=ip, @@ -415,7 +410,7 @@ def match_op_names( loc=None, ip=None, ): - ... + ... 
@overload @classmethod @@ -428,7 +423,7 @@ def match_op_names( loc=None, ip=None, ): - ... + ... @classmethod def match_op_names( @@ -441,20 +436,20 @@ def match_op_names( ip=None, ): if isinstance(result_type_or_target, Type): - result_type = result_type_or_target - target = target_or_names - names = names_or_none + result_type = result_type_or_target + target = target_or_names + names = names_or_none else: - result_type = transform.AnyOpType.get() - target = result_type_or_target - names = target_or_names + result_type = transform.AnyOpType.get() + target = result_type_or_target + names = target_or_names if isinstance(names, str): - names = [names] + names = [names] return cls( result_type, - _get_op_result_or_value(target), + target, ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))), loc=loc, ip=ip, @@ -479,7 +474,7 @@ def __init__( result_type, result_type, result_type, - _get_op_result_or_value(target), + target, dimension=dimension, target_size=target_size, divisor=divisor, @@ -530,9 +525,7 @@ class ScalarizeOp: def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): pdl_operation_type = pdl.OperationType.get() - super().__init__( - pdl_operation_type, _get_op_result_or_value(target), loc=loc, ip=ip - ) + super().__init__(pdl_operation_type, target, loc=loc, ip=ip) class SplitOp: @@ -552,9 +545,7 @@ def __init__( dynamic_split_point = None else: static_split_point = ShapedType.get_dynamic_size() - dynamic_split_point = _get_op_result_or_value(split_point) - - target = _get_op_result_or_value(target) + dynamic_split_point = split_point super().__init__( target.type, @@ -626,8 +617,6 @@ def __init__( ) target = target_or_none - target = _get_op_result_or_value(target) - super().__init__( target.type, loop_types, @@ -750,7 +739,7 @@ def __init__( pdl_operation_type = pdl.OperationType.get() super().__init__( pdl_operation_type, - _get_op_result_or_value(target), + target, 
disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns, vectorize_nd_extract=vectorize_nd_extract, From 952a9d09e9495122dfded138f459cab2bb178dae Mon Sep 17 00:00:00 2001 From: Arthur Eubanks Date: Thu, 14 Sep 2023 14:10:14 -0700 Subject: [PATCH 569/915] [NFC][CodeGen] Change CodeGenOpt::Level/CodeGenFileType into enum classes (#66295) This will make it easy for callers to see issues with and fix up calls to createTargetMachine after a future change to the params of TargetMachine. This matches other nearby enums. For downstream users, this should be a fairly straightforward replacement, e.g. s/CodeGenOpt::Aggressive/CodeGenOptLevel::Aggressive or s/CGFT_/CodeGenFileType:: --- mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index 067cf677e..507be9171 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -53,12 +53,11 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, // Create a transformer to run all LLVM optimization passes at the // specified optimization level. 
- auto llvmOptLevel = static_cast(optLevel); auto transformer = mlir::makeOptimizingTransformer( - llvmOptLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get()); + optLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get()); ExecutionEngineOptions jitOptions; jitOptions.transformer = transformer; - jitOptions.jitCodeGenOptLevel = llvmOptLevel; + jitOptions.jitCodeGenOptLevel = static_cast(optLevel); jitOptions.sharedLibPaths = libPaths; jitOptions.enableObjectDump = enableObjectDump; auto jitOrError = ExecutionEngine::create(unwrap(op), jitOptions); From 2ccb0ddb3b1e5d025f7106c6ae108f3c7e71151d Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Thu, 14 Sep 2023 18:45:29 -0700 Subject: [PATCH 570/915] [mlir] Add Python bindings for DenseResourceElementsAttr. (#66319) Only construction and type casting are implemented. The method to create is explicitly named "unsafe" and the documentation calls out what the caller is responsible for. There really isn't a better way to do this and retain the power-user feature this represents. --- mlir/include/mlir-c/BuiltinAttributes.h | 24 +++-- mlir/lib/Bindings/Python/IRAttributes.cpp | 103 ++++++++++++++++++ mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 123 ++++++++++++---------- 3 files changed, 185 insertions(+), 65 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 93c4ed569..01d1b6008 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -558,6 +558,23 @@ mlirDenseElementsAttrGetRawData(MlirAttribute attr); // Resource blob attributes. //===----------------------------------------------------------------------===// +MLIR_CAPI_EXPORTED bool +mlirAttributeIsADenseResourceElements(MlirAttribute attr); + +/// Unlike the typed accessors below, constructs the attribute with a raw +/// data buffer and no type/alignment checking. Use a more strongly typed +/// accessor if possible. 
If dataIsMutable is false, then an immutable +/// AsmResourceBlob will be created and that passed data contents will be +/// treated as const. +/// If the deleter is non NULL, then it will be called when the data buffer +/// can no longer be accessed (passing userData to it). +MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, void *data, size_t dataLength, + size_t dataAlignment, bool dataIsMutable, + void (*deleter)(void *userData, const void *data, size_t size, + size_t align), + void *userData); + MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet( MlirType shapedType, MlirStringRef name, intptr_t numElements, const int *elements); @@ -600,13 +617,6 @@ mlirUnmanagedDenseDoubleResourceElementsAttrGet(MlirType shapedType, intptr_t numElements, const double *elements); -/// Unlike the typed accessors above, constructs the attribute with a raw -/// data buffer and no type/alignment checking. Use a more strongly typed -/// accessor if possible. -MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBlobResourceElementsAttrGet( - MlirType shapedType, MlirStringRef name, const void *data, - size_t dataLength); - /// Returns the pos-th value (flat contiguous indexing) of a specific type /// contained by the given dense resource elements attribute. MLIR_CAPI_EXPORTED bool diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 105d2cecf..94fa2527e 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -72,6 +72,32 @@ or 255), then a splat will be created. type or if the buffer does not meet expectations. )"; +static const char kDenseResourceElementsAttrGetFromBufferDocstring[] = + R"(Gets a DenseResourceElementsAttr from a Python buffer or array. 
+ +This function does minimal validation or massaging of the data, and it is +up to the caller to ensure that the buffer meets the characteristics +implied by the shape. + +The backing buffer and any user objects will be retained for the lifetime +of the resource blob. This is typically bounded to the context but the +resource can have a shorter lifespan depending on how it is used in +subsequent processing. + +Args: + buffer: The array or buffer to convert. + name: Name to provide to the resource (may be changed upon collision). + type: The explicit ShapedType to construct the attribute with. + context: Explicit context, if not from context manager. + +Returns: + DenseResourceElementsAttr on success. + +Raises: + ValueError: If the type of the buffer or array cannot be matched to an MLIR + type or if the buffer does not meet expectations. +)"; + namespace { static MlirStringRef toMlirStringRef(const std::string &s) { @@ -997,6 +1023,82 @@ class PyDenseIntElementsAttribute } }; +class PyDenseResourceElementsAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = + mlirAttributeIsADenseResourceElements; + static constexpr const char *pyClassName = "DenseResourceElementsAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + + static PyDenseResourceElementsAttribute + getFromBuffer(py::buffer buffer, std::string name, PyType type, + std::optional alignment, bool isMutable, + DefaultingPyMlirContext contextWrapper) { + if (!mlirTypeIsAShaped(type)) { + throw std::invalid_argument( + "Constructing a DenseResourceElementsAttr requires a ShapedType."); + } + + // Do not request any conversions as we must ensure to use caller + // managed memory. + int flags = PyBUF_STRIDES; + std::unique_ptr view = std::make_unique(); + if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) { + throw py::error_already_set(); + } + + // This scope releaser will only release if we haven't yet transferred + // ownership. 
+ auto freeBuffer = llvm::make_scope_exit([&]() { + if (view) + PyBuffer_Release(view.get()); + }); + + if (!PyBuffer_IsContiguous(view.get(), 'A')) { + throw std::invalid_argument("Contiguous buffer is required."); + } + + // Infer alignment to be the stride of one element if not explicit. + size_t inferredAlignment; + if (alignment) + inferredAlignment = *alignment; + else + inferredAlignment = view->strides[view->ndim - 1]; + + // The userData is a Py_buffer* that the deleter owns. + auto deleter = [](void *userData, const void *data, size_t size, + size_t align) { + Py_buffer *ownedView = static_cast(userData); + PyBuffer_Release(ownedView); + delete ownedView; + }; + + size_t rawBufferSize = view->len; + MlirAttribute attr = mlirUnmanagedDenseResourceElementsAttrGet( + type, toMlirStringRef(name), view->buf, rawBufferSize, + inferredAlignment, isMutable, deleter, static_cast(view.get())); + if (mlirAttributeIsNull(attr)) { + throw std::invalid_argument( + "DenseResourceElementsAttr could not be constructed from the given " + "buffer. 
" + "This may mean that the Python buffer layout does not match that " + "MLIR expected layout and is a bug."); + } + view.release(); + return PyDenseResourceElementsAttribute(contextWrapper->getRef(), attr); + } + + static void bindDerived(ClassTy &c) { + c.def_static("get_from_buffer", + PyDenseResourceElementsAttribute::getFromBuffer, + py::arg("array"), py::arg("name"), py::arg("type"), + py::arg("alignment") = py::none(), + py::arg("is_mutable") = false, py::arg("context") = py::none(), + kDenseResourceElementsAttrGetFromBufferDocstring); + } +}; + class PyDictAttribute : public PyConcreteAttribute { public: static constexpr IsAFunctionTy isaFunction = mlirAttributeIsADictionary; @@ -1273,6 +1375,7 @@ void mlir::python::populateIRAttributes(py::module &m) { PyGlobals::get().registerTypeCaster( mlirDenseIntOrFPElementsAttrGetTypeID(), pybind11::cpp_function(denseIntOrFPElementsAttributeCaster)); + PyDenseResourceElementsAttribute::bind(m); PyDictAttribute::bind(m); PySymbolRefAttribute::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 84a958d01..b3066ee0c 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -770,6 +770,30 @@ const void *mlirDenseElementsAttrGetRawData(MlirAttribute attr) { // Resource blob attributes. 
//===----------------------------------------------------------------------===// +bool mlirAttributeIsADenseResourceElements(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} + +MlirAttribute mlirUnmanagedDenseResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, void *data, size_t dataLength, + size_t dataAlignment, bool dataIsMutable, + void (*deleter)(void *userData, const void *data, size_t size, + size_t align), + void *userData) { + AsmResourceBlob::DeleterFn cppDeleter = {}; + if (deleter) { + cppDeleter = [deleter, userData](void *data, size_t size, size_t align) { + deleter(userData, data, size, align); + }; + } + AsmResourceBlob blob( + llvm::ArrayRef(static_cast(data), dataLength), + dataAlignment, std::move(cppDeleter), dataIsMutable); + return wrap( + DenseResourceElementsAttr::get(llvm::cast(unwrap(shapedType)), + unwrap(name), std::move(blob))); +} + template static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name, intptr_t numElements, const T *elements) { @@ -778,139 +802,122 @@ static MlirAttribute getDenseResource(MlirType shapedType, MlirStringRef name, llvm::ArrayRef(elements, numElements)))); } -MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet( +MlirAttribute mlirUnmanagedDenseBoolResourceElementsAttrGet( MlirType shapedType, MlirStringRef name, intptr_t numElements, const int *elements) { return getDenseResource(shapedType, name, numElements, elements); } -MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseUInt8ResourceElementsAttrGet( +MlirAttribute mlirUnmanagedDenseUInt8ResourceElementsAttrGet( MlirType shapedType, MlirStringRef name, intptr_t numElements, const uint8_t *elements) { return getDenseResource(shapedType, name, numElements, elements); } -MLIR_CAPI_EXPORTED MlirAttribute -mlirUnmanagedDenseUInt16ResourceElementsAttrGet(MlirType shapedType, - MlirStringRef name, - intptr_t numElements, - const uint16_t *elements) { +MlirAttribute 
mlirUnmanagedDenseUInt16ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const uint16_t *elements) { return getDenseResource(shapedType, name, numElements, elements); } -MLIR_CAPI_EXPORTED MlirAttribute -mlirUnmanagedDenseUInt32ResourceElementsAttrGet(MlirType shapedType, - MlirStringRef name, - intptr_t numElements, - const uint32_t *elements) { +MlirAttribute mlirUnmanagedDenseUInt32ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const uint32_t *elements) { return getDenseResource(shapedType, name, numElements, elements); } -MLIR_CAPI_EXPORTED MlirAttribute -mlirUnmanagedDenseUInt64ResourceElementsAttrGet(MlirType shapedType, - MlirStringRef name, - intptr_t numElements, - const uint64_t *elements) { +MlirAttribute mlirUnmanagedDenseUInt64ResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const uint64_t *elements) { return getDenseResource(shapedType, name, numElements, elements); } -MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt8ResourceElementsAttrGet( +MlirAttribute mlirUnmanagedDenseInt8ResourceElementsAttrGet( MlirType shapedType, MlirStringRef name, intptr_t numElements, const int8_t *elements) { return getDenseResource(shapedType, name, numElements, elements); } -MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt16ResourceElementsAttrGet( +MlirAttribute mlirUnmanagedDenseInt16ResourceElementsAttrGet( MlirType shapedType, MlirStringRef name, intptr_t numElements, const int16_t *elements) { return getDenseResource(shapedType, name, numElements, elements); } -MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseInt32ResourceElementsAttrGet( +MlirAttribute mlirUnmanagedDenseInt32ResourceElementsAttrGet( MlirType shapedType, MlirStringRef name, intptr_t numElements, const int32_t *elements) { return getDenseResource(shapedType, name, numElements, elements); } -MLIR_CAPI_EXPORTED MlirAttribute 
mlirUnmanagedDenseInt64ResourceElementsAttrGet( +MlirAttribute mlirUnmanagedDenseInt64ResourceElementsAttrGet( MlirType shapedType, MlirStringRef name, intptr_t numElements, const int64_t *elements) { return getDenseResource(shapedType, name, numElements, elements); } -MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseFloatResourceElementsAttrGet( +MlirAttribute mlirUnmanagedDenseFloatResourceElementsAttrGet( MlirType shapedType, MlirStringRef name, intptr_t numElements, const float *elements) { return getDenseResource(shapedType, name, numElements, elements); } -MLIR_CAPI_EXPORTED MlirAttribute -mlirUnmanagedDenseDoubleResourceElementsAttrGet(MlirType shapedType, - MlirStringRef name, - intptr_t numElements, - const double *elements) { +MlirAttribute mlirUnmanagedDenseDoubleResourceElementsAttrGet( + MlirType shapedType, MlirStringRef name, intptr_t numElements, + const double *elements) { return getDenseResource(shapedType, name, numElements, elements); } -MLIR_CAPI_EXPORTED MlirAttribute mlirUnmanagedDenseBlobResourceElementsAttrGet( - MlirType shapedType, MlirStringRef name, const void *data, - size_t dataLength) { - return wrap(DenseResourceElementsAttr::get( - llvm::cast(unwrap(shapedType)), unwrap(name), - UnmanagedAsmResourceBlob::allocateInferAlign( - llvm::ArrayRef(static_cast(data), dataLength)))); -} - template static T getDenseResourceVal(MlirAttribute attr, intptr_t pos) { return (*llvm::cast(unwrap(attr)).tryGetAsArrayRef())[pos]; } -MLIR_CAPI_EXPORTED bool -mlirDenseBoolResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { +bool mlirDenseBoolResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { return getDenseResourceVal(attr, pos); } -MLIR_CAPI_EXPORTED uint8_t -mlirDenseUInt8ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { +uint8_t mlirDenseUInt8ResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { return getDenseResourceVal(attr, pos); } -MLIR_CAPI_EXPORTED uint16_t 
-mlirDenseUInt16ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { +uint16_t mlirDenseUInt16ResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { return getDenseResourceVal(attr, pos); } -MLIR_CAPI_EXPORTED uint32_t -mlirDenseUInt32ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { +uint32_t mlirDenseUInt32ResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { return getDenseResourceVal(attr, pos); } -MLIR_CAPI_EXPORTED uint64_t -mlirDenseUInt64ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { +uint64_t mlirDenseUInt64ResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { return getDenseResourceVal(attr, pos); } -MLIR_CAPI_EXPORTED int8_t -mlirDenseInt8ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { +int8_t mlirDenseInt8ResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { return getDenseResourceVal(attr, pos); } -MLIR_CAPI_EXPORTED int16_t -mlirDenseInt16ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { +int16_t mlirDenseInt16ResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { return getDenseResourceVal(attr, pos); } -MLIR_CAPI_EXPORTED int32_t -mlirDenseInt32ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { +int32_t mlirDenseInt32ResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { return getDenseResourceVal(attr, pos); } -MLIR_CAPI_EXPORTED int64_t -mlirDenseInt64ResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { +int64_t mlirDenseInt64ResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { return getDenseResourceVal(attr, pos); } -MLIR_CAPI_EXPORTED float -mlirDenseFloatResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { +float mlirDenseFloatResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { return getDenseResourceVal(attr, pos); } -MLIR_CAPI_EXPORTED double -mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr, intptr_t pos) { 
+double mlirDenseDoubleResourceElementsAttrGetValue(MlirAttribute attr, + intptr_t pos) { return getDenseResourceVal(attr, pos); } From fc76244bebdc0c097519fe21159d0056ad9f5f21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Fri, 15 Sep 2023 09:06:07 +0200 Subject: [PATCH 571/915] [mlir][transform][lingalg][python] Replace pdl.operation => transform.any_op. (#66392) For some reason, the mix-ins of the Python bindings of this dialect used the PDL type for "any op". However, PDL isn't involved here, so it makes more sense to use the corresponding type of the transform dialect. This PR changes that. --- .../dialects/_structured_transform_ops_ext.py | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 212fbc5ba..c5134b6e7 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -4,7 +4,7 @@ try: from ..ir import * - from ..dialects import pdl, transform + from ..dialects import transform except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e @@ -203,7 +203,8 @@ class DecomposeOp: """Specialization for DecomposeOp class.""" def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip) + transformed_type = transform.AnyOpType.get() + super().__init__(transformed_type, target, loc=loc, ip=ip) class FuseIntoContainingOp: @@ -274,7 +275,8 @@ class GeneralizeOp: """Specialization for GeneralizeOp class.""" def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - super().__init__(pdl.OperationType.get(), target, loc=loc, ip=ip) + transformed_type = transform.AnyOpType.get() + super().__init__(transformed_type, target, loc=loc, ip=ip) class InterchangeOp: @@ -288,9 +290,9 @@ 
def __init__( loc=None, ip=None, ): - pdl_operation_type = pdl.OperationType.get() + transformed_type = transform.AnyOpType.get() super().__init__( - pdl_operation_type, + transformed_type, target, iterator_interchange=iterator_interchange, loc=loc, @@ -503,11 +505,11 @@ def __init__( ): transpose_paddings = _get_int_array_array_attr(transpose_paddings) - pdl_operation_type = pdl.OperationType.get() + any_op_type = transform.AnyOpType.get() super().__init__( - pdl_operation_type, - pdl_operation_type, - pdl_operation_type, + any_op_type, + any_op_type, + any_op_type, target, padding_values=padding_values, padding_dimensions=padding_dimensions, @@ -524,8 +526,8 @@ class ScalarizeOp: """Specialization for ScalarizeOp class.""" def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - pdl_operation_type = pdl.OperationType.get() - super().__init__(pdl_operation_type, target, loc=loc, ip=ip) + result_type = transform.AnyOpType.get() + super().__init__(result_type, target, loc=loc, ip=ip) class SplitOp: @@ -736,9 +738,9 @@ def __init__( loc=None, ip=None, ): - pdl_operation_type = pdl.OperationType.get() + transformed_type = transform.AnyOpType.get() super().__init__( - pdl_operation_type, + transformed_type, target, disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns, disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns, From 60e959d9bebe2141f4e99947aa40f86afdb356ee Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Mon, 18 Sep 2023 16:44:48 +0200 Subject: [PATCH 572/915] [mlir][bufferization] Remove allow-return-allocs and create-deallocs pass options, remove bufferization.escape attribute (#66619) This commit removes the deallocation capabilities of one-shot-bufferization. One-shot-bufferization should never deallocate any memrefs as this should be entirely handled by the ownership-based-buffer-deallocation pass going forward. 
This means the `allow-return-allocs` pass option will default to true now, `create-deallocs` defaults to false and they, as well as the escape attribute indicating whether a memref escapes the current region, will be removed. A new `allow-return-allocs-from-loops` option is added as a temporary workaround for some bufferization limitations. --- .../dialects/_bufferization_transform_ops_ext.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py index ead337282..7e6c1b81c 100644 --- a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py @@ -62,10 +62,9 @@ def __init__( transformed_type: Type, target: Union[Operation, OpView, Value], *, - allow_return_allocs: Optional[bool] = None, + allow_return_allocs_from_loops: Optional[bool] = None, allow_unknown_ops: Optional[bool] = None, bufferize_function_boundaries: Optional[bool] = None, - create_deallocs: Optional[bool] = None, function_boundary_type_conversion: Optional[Enum] = None, memcpy_op: Optional[str] = None, print_conflicts: Optional[bool] = None, @@ -80,10 +79,9 @@ def __init__( self, target: Union[Operation, OpView, Value], *, - allow_return_allocs: Optional[bool] = None, + allow_return_allocs_from_loops: Optional[bool] = None, allow_unknown_ops: Optional[bool] = None, bufferize_function_boundaries: Optional[bool] = None, - create_deallocs: Optional[bool] = None, function_boundary_type_conversion: Optional[Enum] = None, memcpy_op: Optional[str] = None, print_conflicts: Optional[bool] = None, @@ -98,10 +96,9 @@ def __init__( transformed_type_or_target: Type, target_or_none: Optional[Union[Operation, OpView, Value]] = None, *, - allow_return_allocs: Optional[bool] = None, + allow_return_allocs_from_loops: Optional[bool] = None, allow_unknown_ops: Optional[bool] = None, 
bufferize_function_boundaries: Optional[bool] = None, - create_deallocs: Optional[bool] = None, function_boundary_type_conversion: Optional[Enum] = None, memcpy_op: Optional[str] = None, print_conflicts: Optional[bool] = None, @@ -119,10 +116,9 @@ def __init__( super().__init__( transformed_type, target, - allow_return_allocs=allow_return_allocs, + allow_return_allocs_from_loops=allow_return_allocs_from_loops, allow_unknown_ops=allow_unknown_ops, bufferize_function_boundaries=bufferize_function_boundaries, - create_deallocs=create_deallocs, function_boundary_type_conversion=function_boundary_type_conversion, memcpy_op=memcpy_op, print_conflicts=print_conflicts, From 50527bf907f1445f5d1190e950b700ac13ad1017 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Mon, 18 Sep 2023 20:12:12 -0700 Subject: [PATCH 573/915] [mlir][c] Expose AsmState. (#66693) Enable usage where capturing AsmState is good (e.g., avoiding creating AsmState over and over again when walking IR and printing). This also only changes one C API to verify plumbing. But using the AsmState makes the cost more explicit than the flags interface (which hides the traversals and construction here) and also enables a more efficient usage C side. 
--- mlir/include/mlir-c/IR.h | 26 ++++++++++++++- mlir/include/mlir/CAPI/IR.h | 1 + mlir/lib/Bindings/Python/IRCore.cpp | 4 ++- mlir/lib/CAPI/IR/IR.cpp | 49 +++++++++++++++++++++++++++-- 4 files changed, 76 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index b5c6a3094..68eccab6d 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -48,6 +48,7 @@ extern "C" { }; \ typedef struct name name +DEFINE_C_API_STRUCT(MlirAsmState, void); DEFINE_C_API_STRUCT(MlirBytecodeWriterConfig, void); DEFINE_C_API_STRUCT(MlirContext, void); DEFINE_C_API_STRUCT(MlirDialect, void); @@ -383,6 +384,29 @@ mlirOperationStateAddAttributes(MlirOperationState *state, intptr_t n, MLIR_CAPI_EXPORTED void mlirOperationStateEnableResultTypeInference(MlirOperationState *state); +//===----------------------------------------------------------------------===// +// AsmState API. +// While many of these are simple settings that could be represented in a +// struct, they are wrapped in a heap allocated object and accessed via +// functions to maximize the possibility of compatibility over time. +//===----------------------------------------------------------------------===// + +/// Creates new AsmState, as with AsmState the IR should not be mutated +/// in-between using this state. +/// Must be freed with a call to mlirAsmStateDestroy(). +// TODO: This should be expanded to handle location & resouce map. +MLIR_CAPI_EXPORTED MlirAsmState +mlirAsmStateCreateForOperation(MlirOperation op, MlirOpPrintingFlags flags); + +/// Creates new AsmState from value. +/// Must be freed with a call to mlirAsmStateDestroy(). +// TODO: This should be expanded to handle location & resouce map. +MLIR_CAPI_EXPORTED MlirAsmState +mlirAsmStateCreateForValue(MlirValue value, MlirOpPrintingFlags flags); + +/// Destroys printing flags created with mlirAsmStateCreate. 
+MLIR_CAPI_EXPORTED void mlirAsmStateDestroy(MlirAsmState state); + //===----------------------------------------------------------------------===// // Op Printing flags API. // While many of these are simple settings that could be represented in a @@ -815,7 +839,7 @@ mlirValuePrint(MlirValue value, MlirStringCallback callback, void *userData); /// Prints a value as an operand (i.e., the ValueID). MLIR_CAPI_EXPORTED void mlirValuePrintAsOperand(MlirValue value, - MlirOpPrintingFlags flags, + MlirAsmState state, MlirStringCallback callback, void *userData); diff --git a/mlir/include/mlir/CAPI/IR.h b/mlir/include/mlir/CAPI/IR.h index b8ccec896..1836cb0ac 100644 --- a/mlir/include/mlir/CAPI/IR.h +++ b/mlir/include/mlir/CAPI/IR.h @@ -21,6 +21,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/IR/Operation.h" +DEFINE_C_API_PTR_METHODS(MlirAsmState, mlir::AsmState) DEFINE_C_API_PTR_METHODS(MlirBytecodeWriterConfig, mlir::BytecodeWriterConfig) DEFINE_C_API_PTR_METHODS(MlirContext, mlir::MLIRContext) DEFINE_C_API_PTR_METHODS(MlirDialect, mlir::Dialect) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index b06937bc2..af713547c 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3430,9 +3430,11 @@ void mlir::python::populateIRCore(py::module &m) { MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); if (useLocalScope) mlirOpPrintingFlagsUseLocalScope(flags); - mlirValuePrintAsOperand(self.get(), flags, printAccum.getCallback(), + MlirAsmState state = mlirAsmStateCreateForValue(self.get(), flags); + mlirValuePrintAsOperand(self.get(), state, printAccum.getCallback(), printAccum.getUserData()); mlirOpPrintingFlagsDestroy(flags); + mlirAsmStateDestroy(state); return printAccum.join(); }, py::arg("use_local_scope") = false, kGetNameAsOperand) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index ef234a912..7f5c2aaee 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ 
b/mlir/lib/CAPI/IR/IR.cpp @@ -138,6 +138,51 @@ void mlirDialectRegistryDestroy(MlirDialectRegistry registry) { delete unwrap(registry); } +//===----------------------------------------------------------------------===// +// AsmState API. +//===----------------------------------------------------------------------===// + +MlirAsmState mlirAsmStateCreateForOperation(MlirOperation op, + MlirOpPrintingFlags flags) { + return wrap(new AsmState(unwrap(op), *unwrap(flags))); +} + +static Operation *findParent(Operation *op, bool shouldUseLocalScope) { + do { + // If we are printing local scope, stop at the first operation that is + // isolated from above. + if (shouldUseLocalScope && op->hasTrait()) + break; + + // Otherwise, traverse up to the next parent. + Operation *parentOp = op->getParentOp(); + if (!parentOp) + break; + op = parentOp; + } while (true); + return op; +} + +MlirAsmState mlirAsmStateCreateForValue(MlirValue value, + MlirOpPrintingFlags flags) { + Operation *op; + mlir::Value val = unwrap(value); + if (auto result = llvm::dyn_cast(val)) { + op = result.getOwner(); + } else { + op = llvm::cast(val).getOwner()->getParentOp(); + if (!op) { + emitError(val.getLoc()) << "<>"; + return {nullptr}; + } + } + op = findParent(op, unwrap(flags)->shouldUseLocalScope()); + return wrap(new AsmState(op, *unwrap(flags))); +} + +/// Destroys printing flags created with mlirAsmStateCreate. +void mlirAsmStateDestroy(MlirAsmState state) { delete unwrap(state); } + //===----------------------------------------------------------------------===// // Printing flags API. 
//===----------------------------------------------------------------------===// @@ -840,11 +885,11 @@ void mlirValuePrint(MlirValue value, MlirStringCallback callback, unwrap(value).print(stream); } -void mlirValuePrintAsOperand(MlirValue value, MlirOpPrintingFlags flags, +void mlirValuePrintAsOperand(MlirValue value, MlirAsmState state, MlirStringCallback callback, void *userData) { detail::CallbackOstream stream(callback, userData); Value cppValue = unwrap(value); - cppValue.printAsOperand(stream, *unwrap(flags)); + cppValue.printAsOperand(stream, *unwrap(state)); } MlirOpOperand mlirValueGetFirstUse(MlirValue value) { From 0a29b94dd8d1887ce971bbeef79062c07a3ae32d Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Mon, 18 Sep 2023 21:30:41 -0700 Subject: [PATCH 574/915] [mlir] Quality of life improvements to python API types. (#66723) * Moves several orphaned methods from Operation/OpView -> _OperationBase so that both hierarchies share them (whether unknown or known to ODS). * Adds typing information for missing `MLIRError` exception. * Adds `DiagnosticInfo` typing. * Adds `DenseResourceElementsAttr` typing that was missing. 
--- mlir/lib/Bindings/Python/IRCore.cpp | 67 ++++++++++--------- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 37 ++++++++-- .../mlir/_mlir_libs/_mlir/passmanager.pyi | 2 +- 3 files changed, 67 insertions(+), 39 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index af713547c..504ed8f3e 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2768,6 +2768,24 @@ void mlir::python::populateIRCore(py::module &m) { return PyOpAttributeMap( self.getOperation().getRef()); }) + .def_property_readonly( + "context", + [](PyOperationBase &self) { + PyOperation &concreteOperation = self.getOperation(); + concreteOperation.checkValid(); + return concreteOperation.getContext().getObject(); + }, + "Context that owns the Operation") + .def_property_readonly("name", + [](PyOperationBase &self) { + auto &concreteOperation = self.getOperation(); + concreteOperation.checkValid(); + MlirOperation operation = + concreteOperation.get(); + MlirStringRef name = mlirIdentifierStr( + mlirOperationGetName(operation)); + return py::str(name.data, name.length); + }) .def_property_readonly("operands", [](PyOperationBase &self) { return PyOpOperandList( @@ -2813,6 +2831,14 @@ void mlir::python::populateIRCore(py::module &m) { }, "Returns the source location the operation was defined or derived " "from.") + .def_property_readonly("parent", + [](PyOperationBase &self) -> py::object { + auto parent = + self.getOperation().getParentOperation(); + if (parent) + return parent->getObject(); + return py::none(); + }) .def( "__str__", [](PyOperationBase &self) { @@ -2855,6 +2881,12 @@ void mlir::python::populateIRCore(py::module &m) { .def("move_before", &PyOperationBase::moveBefore, py::arg("other"), "Puts self immediately before the other operation in its parent " "block.") + .def( + "clone", + [](PyOperationBase &self, py::object ip) { + return self.getOperation().clone(ip); + }, + py::arg("ip") = py::none()) .def( 
"detach_from_parent", [](PyOperationBase &self) { @@ -2866,7 +2898,8 @@ void mlir::python::populateIRCore(py::module &m) { operation.detachFromParent(); return operation.createOpView(); }, - "Detaches the operation from its parent block."); + "Detaches the operation from its parent block.") + .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); }); py::class_(m, "Operation", py::module_local()) .def_static("create", &PyOperation::create, py::arg("name"), @@ -2887,45 +2920,17 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("context") = py::none(), "Parses an operation. Supports both text assembly format and binary " "bytecode format.") - .def_property_readonly("parent", - [](PyOperation &self) -> py::object { - auto parent = self.getParentOperation(); - if (parent) - return parent->getObject(); - return py::none(); - }) - .def("erase", &PyOperation::erase) - .def("clone", &PyOperation::clone, py::arg("ip") = py::none()) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) - .def_property_readonly("name", - [](PyOperation &self) { - self.checkValid(); - MlirOperation operation = self.get(); - MlirStringRef name = mlirIdentifierStr( - mlirOperationGetName(operation)); - return py::str(name.data, name.length); - }) - .def_property_readonly( - "context", - [](PyOperation &self) { - self.checkValid(); - return self.getContext().getObject(); - }, - "Context that owns the Operation") + .def_property_readonly("operation", [](py::object self) { return self; }) .def_property_readonly("opview", &PyOperation::createOpView); auto opViewClass = py::class_(m, "OpView", py::module_local()) .def(py::init(), py::arg("operation")) .def_property_readonly("operation", &PyOpView::getOperationObject) - .def_property_readonly( - "context", - [](PyOpView &self) { - return self.getOperation().getContext().getObject(); - }, - "Context that owns the Operation") + 
.def_property_readonly("opview", [](py::object self) { return self; }) .def("__str__", [](PyOpView &self) { return py::str(self.getOperationObject()); }); diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 23f4687d0..e8f4440d2 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -43,11 +43,13 @@ __all__ = [ "DenseElementsAttr", "DenseFPElementsAttr", "DenseIntElementsAttr", + "DenseResourceElementsAttr", "Dialect", "DialectDescriptor", "Dialects", "Diagnostic", "DiagnosticHandler", + "DiagnosticInfo", "DiagnosticSeverity", "DictAttr", "Float8E4M3FNType", @@ -74,6 +76,7 @@ __all__ = [ "Location", "MemRefType", "Module", + "MLIRError", "NamedAttribute", "NoneType", "OpaqueType", @@ -123,10 +126,16 @@ class _OperationBase: @property def attributes(self) -> OpAttributeMap: ... @property + def context(self) -> Context: ... + @property def location(self) -> Location: ... @property + def name(self) -> str: ... + @property def operands(self) -> OpOperandList: ... @property + @property + def parent(self) -> Optional[_OperationBase]: ... def regions(self) -> RegionSequence: ... @property def result(self) -> OpResult: ... @@ -530,6 +539,10 @@ class DenseIntElementsAttr(DenseElementsAttr): @property def type(self) -> Type: ... +class DenseResourceElementsAttr(Attribute): + @staticmethod + def get_from_buffer(array: Any, name: str, type: Type, alignment: Optional[int] = None, is_mutable: bool = False, context: Optional[Context] = None) -> None: ... + class Dialect: def __init__(self, descriptor: DialectDescriptor) -> None: ... @property @@ -563,6 +576,17 @@ class DiagnosticHandler: def __enter__(self) -> DiagnosticHandler: ... def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... +class DiagnosticInfo: + def __init__(self, diag: Diagnostic) -> None: ... + @property + def severity(self) -> "DiagnosticSeverity": ... 
+ @property + def location(self) -> "Location": ... + @property + def message(self) -> str: ... + @property + def notes(self) -> Sequence["DiagnosticInfo"]: ... + class DiagnosticSeverity: ERROR: DiagnosticSeverity WARNING: DiagnosticSeverity @@ -871,6 +895,9 @@ class Module: @property def operation(self) -> Operation: ... +class MLIRError(Exception): + def __init__(self, message: str, error_diagnostics: List[DiagnosticInfo]) -> None: ... + class NamedAttribute: @property def attr(self) -> Attribute: ... @@ -950,9 +977,9 @@ class OpView(_OperationBase): loc: Optional[Location] = None, ip: Optional[InsertionPoint] = None) -> _TOperation: ... @property - def context(self) -> Context: ... - @property def operation(self) -> Operation: ... + @property + def opview(self) -> "OpView": ... class Operation(_OperationBase): def _CAPICreate(self) -> object: ... @@ -968,13 +995,9 @@ class Operation(_OperationBase): @property def _CAPIPtr(self) -> object: ... @property - def context(self) -> Context: ... - @property - def name(self) -> str: ... + def operation(self) -> "Operation": ... @property def opview(self) -> OpView: ... - @property - def parent(self) -> Optional[_OperationBase]: ... class OperationIterator: def __iter__(self) -> OperationIterator: ... diff --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi index 44d22255e..c072d5e0f 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi @@ -20,6 +20,6 @@ class PassManager: def enable_verifier(self, enable: bool) -> None: ... @staticmethod def parse(pipeline: str, context: Optional[_ir.Context] = None) -> PassManager: ... - def run(self, module: _ir.Module) -> None: ... + def run(self, module: _ir._OperationBase) -> None: ... @property def _CAPIPtr(self) -> object: ... 
From 585eb11995125ef8ee5d90df941aa63c4633aa4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Tue, 19 Sep 2023 10:34:47 +0200 Subject: [PATCH 575/915] [mlir][linalg][transform][python] Allow no args in MaskedVectorize. (#66541) The mix-in of this op did not allow to pass in no argument. This special case is now handled correctly and covered by the tests. --- .../mlir/dialects/_structured_transform_ops_ext.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index c5134b6e7..fd3dbca7c 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -366,7 +366,7 @@ class MaskedVectorizeOp: def __init__( self, target: Union[Operation, OpView, Value], - vector_sizes: Union[DynamicIndexList, ArrayAttr], + vector_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, *, vectorize_nd_extract: Optional[bool] = None, scalable_sizes: OptionalBoolList = None, @@ -374,7 +374,13 @@ def __init__( loc=None, ip=None, ): - if scalable_sizes is None and static_vector_sizes is None: + if ( + scalable_sizes is None + and static_vector_sizes is None + and vector_sizes is None + ): + dynamic_vector_sizes = [] + elif scalable_sizes is None and static_vector_sizes is None: ( dynamic_vector_sizes, static_vector_sizes, From ba264b1e503c5b30addaa3d44c9b517a8bf7f825 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 20 Sep 2023 16:57:20 +0200 Subject: [PATCH 576/915] [mlir][Vector] Add fastmath flags to vector.reduction (#66905) This revision pipes the fastmath attribute support through the vector.reduction op. This seemingly simple first step already requires quite some genuflexions, file and builder reorganization. In the process, retire the boolean reassoc flag deep in the LLVM dialect builders and just use the fastmath attribute. 
During conversions, templated builders for predicated intrinsics are partially cleaned up. In the future, to finalize the cleanups, one should consider adding fastmath to the VPIntrinsic ops. --- mlir/python/mlir/dialects/Vector.td | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 mlir/python/mlir/dialects/Vector.td diff --git a/mlir/python/mlir/dialects/Vector.td b/mlir/python/mlir/dialects/Vector.td new file mode 100644 index 000000000..f659f754b --- /dev/null +++ b/mlir/python/mlir/dialects/Vector.td @@ -0,0 +1,14 @@ +//===-- Vector.td - Entry point for Vector bindings --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_VECTOR +#define PYTHON_BINDINGS_VECTOR + +include "mlir/Dialect/Vector/IR/Vector.td" + +#endif // PYTHON_BINDINGS_VECTOR From e6aa575fd1d36de59e72a736712d578da34d416b Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Wed, 20 Sep 2023 18:53:29 +0200 Subject: [PATCH 577/915] [mlir] regenerate linalg named ops yaml (#65475) The Linalg named ops specification went out of sync with the OpDSL description, presumably due to direct manual modifications of the yaml file. Additionally, the unsigned division has been generating the signed scalar instruction, which is now fixed. 
--- mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index dee0c3e3f..6eae3d916 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -217,7 +217,7 @@ def div_unsigned( a `linalg.broadcast` + `linalg.div` sequence can be lowered to a `linalg.generic` with different affine maps for the two operands. """ - O[None] = lhs[None] / rhs[None] + O[None] = BinaryFn.div_unsigned(lhs[None], rhs[None]) @linalg_structured_op From 93c337a94b6f3b316fdcb9917836fafc1304fc29 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 20 Sep 2023 15:12:06 -0700 Subject: [PATCH 578/915] [mlir][python] Expose AsmState python side. (#66819) This does basic plumbing, ideally want a context approach to reduce needing to thread these manually, but the current is useful even in that state. Made Value.get_name change backwards compatible, so one could either set a field or create a state to pass in. 
--- mlir/lib/Bindings/Python/IRCore.cpp | 38 ++++++++++++++++++++++------- mlir/lib/Bindings/Python/IRModule.h | 25 +++++++++++++++++++ 2 files changed, 54 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 504ed8f3e..c74b37a51 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3430,19 +3430,35 @@ void mlir::python::populateIRCore(py::module &m) { kValueDunderStrDocstring) .def( "get_name", - [](PyValue &self, bool useLocalScope) { + [](PyValue &self, std::optional useLocalScope, + std::optional> state) { PyPrintAccumulator printAccum; - MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); - if (useLocalScope) - mlirOpPrintingFlagsUseLocalScope(flags); - MlirAsmState state = mlirAsmStateCreateForValue(self.get(), flags); - mlirValuePrintAsOperand(self.get(), state, printAccum.getCallback(), + MlirOpPrintingFlags flags; + MlirAsmState valueState; + // Use state if provided, else create a new state. + if (state) { + valueState = state.value().get().get(); + // Don't allow setting using local scope and state at same time. + if (useLocalScope.has_value()) + throw py::value_error( + "setting AsmState and local scope together not supported"); + } else { + flags = mlirOpPrintingFlagsCreate(); + if (useLocalScope.value_or(false)) + mlirOpPrintingFlagsUseLocalScope(flags); + valueState = mlirAsmStateCreateForValue(self.get(), flags); + } + mlirValuePrintAsOperand(self.get(), valueState, printAccum.getCallback(), printAccum.getUserData()); - mlirOpPrintingFlagsDestroy(flags); - mlirAsmStateDestroy(state); + // Release state if allocated locally. 
+ if (!state) { + mlirOpPrintingFlagsDestroy(flags); + mlirAsmStateDestroy(valueState); + } return printAccum.join(); }, - py::arg("use_local_scope") = false, kGetNameAsOperand) + py::arg("use_local_scope") = std::nullopt, + py::arg("state") = std::nullopt, kGetNameAsOperand) .def_property_readonly( "type", [](PyValue &self) { return mlirValueGetType(self.get()); }) .def( @@ -3461,6 +3477,10 @@ void mlir::python::populateIRCore(py::module &m) { PyOpResult::bind(m); PyOpOperand::bind(m); + py::class_(m, "AsmState", py::module_local()) + .def(py::init(), py::arg("value"), + py::arg("use_local_scope") = false); + //---------------------------------------------------------------------------- // Mapping of SymbolTable. //---------------------------------------------------------------------------- diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index d1911730c..23338f7fd 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -748,6 +748,31 @@ class PyRegion { MlirRegion region; }; +/// Wrapper around an MlirAsmState. +class PyAsmState { + public: + PyAsmState(MlirValue value, bool useLocalScope) { + flags = mlirOpPrintingFlagsCreate(); + // The OpPrintingFlags are not exposed Python side, create locally and + // associate lifetime with the state. + if (useLocalScope) + mlirOpPrintingFlagsUseLocalScope(flags); + state = mlirAsmStateCreateForValue(value, flags); + } + ~PyAsmState() { + mlirOpPrintingFlagsDestroy(flags); + } + // Delete copy constructors. + PyAsmState(PyAsmState &other) = delete; + PyAsmState(const PyAsmState &other) = delete; + + MlirAsmState get() { return state; } + + private: + MlirAsmState state; + MlirOpPrintingFlags flags; +}; + /// Wrapper around an MlirBlock. /// Blocks are managed completely by their containing operation. Unlike the /// C++ API, the python API does not support detached blocks. 
From 1f105f46f2fbf75bbf572c0c58d72534a969d3f0 Mon Sep 17 00:00:00 2001 From: Peiming Liu <36770114+PeimingLiu@users.noreply.github.com> Date: Wed, 20 Sep 2023 15:35:50 -0700 Subject: [PATCH 579/915] [mlir][sparse] Generates python bindings for SparseTensorTransformOps. (#66937) --- mlir/python/CMakeLists.txt | 9 +++++++++ .../mlir/dialects/SparseTensorTransformOps.td | 14 ++++++++++++++ .../mlir/dialects/transform/sparse_tensor.py | 5 +++++ 3 files changed, 28 insertions(+) create mode 100644 mlir/python/mlir/dialects/SparseTensorTransformOps.td create mode 100644 mlir/python/mlir/dialects/transform/sparse_tensor.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 5d2f233ca..25be18fce 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -213,6 +213,15 @@ declare_mlir_dialect_extension_python_bindings( "../../include/mlir/Dialect/Linalg/TransformOps/LinalgTransformEnums.td" ) +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/SparseTensorTransformOps.td + SOURCES + dialects/transform/sparse_tensor.py + DIALECT_NAME transform + EXTENSION_NAME sparse_tensor_transform) + declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/SparseTensorTransformOps.td b/mlir/python/mlir/dialects/SparseTensorTransformOps.td new file mode 100644 index 000000000..f4c4464ee --- /dev/null +++ b/mlir/python/mlir/dialects/SparseTensorTransformOps.td @@ -0,0 +1,14 @@ +//===-- SparseTensorTransformOps.td ------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_SPARSE_TENSOR_TRANSFORM_OPS +#define PYTHON_BINDINGS_SPARSE_TENSOR_TRANSFORM_OPS + +include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/transform/sparse_tensor.py b/mlir/python/mlir/dialects/transform/sparse_tensor.py new file mode 100644 index 000000000..8b33270dc --- /dev/null +++ b/mlir/python/mlir/dialects/transform/sparse_tensor.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._sparse_tensor_transform_ops_gen import * From d7d52bdfe95353d6198637cd939fd1f8ec381711 Mon Sep 17 00:00:00 2001 From: Jevin Jiang Date: Thu, 21 Sep 2023 08:06:44 +0000 Subject: [PATCH 580/915] Fix induction variable type in scf.for py binding. - make sure that the type of induction variable should be determined by the type of the lower bound type. 
Reviewed By: ftynse Differential Revision: https://reviews.llvm.org/D159534 --- mlir/python/mlir/dialects/_scf_ops_ext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py index 4b2519ef3..4b0a31327 100644 --- a/mlir/python/mlir/dialects/_scf_ops_ext.py +++ b/mlir/python/mlir/dialects/_scf_ops_ext.py @@ -52,7 +52,7 @@ def __init__( ip=ip, ) ) - self.regions[0].blocks.append(IndexType.get(), *results) + self.regions[0].blocks.append(self.operands[0].type, *results) @property def body(self): From 6f74c849132a742c578dc47043106cf16b8f2723 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Thu, 21 Sep 2023 12:57:41 +0200 Subject: [PATCH 581/915] [mlir][python] smaller scope for vector enumgen (#66992) Don't generate enums from the main VectorOps.td file as that transitively includes enums from Arith. --------- Co-authored-by: Nicolas Vasilache --- mlir/python/CMakeLists.txt | 3 ++- mlir/python/mlir/dialects/VectorAttributes.td | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 mlir/python/mlir/dialects/VectorAttributes.td diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 25be18fce..9368cb4c2 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -381,7 +381,8 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/VectorOps.td SOURCES dialects/vector.py DIALECT_NAME vector - GEN_ENUM_BINDINGS) + GEN_ENUM_BINDINGS_TD_FILE + "dialects/VectorAttributes.td") ################################################################################ # Python extensions. 
diff --git a/mlir/python/mlir/dialects/VectorAttributes.td b/mlir/python/mlir/dialects/VectorAttributes.td new file mode 100644 index 000000000..038e0ba21 --- /dev/null +++ b/mlir/python/mlir/dialects/VectorAttributes.td @@ -0,0 +1,14 @@ +//===-- VectorAttributes.td - Entry point for bindings -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_VECTOR_ATTRDEFS_TD +#define PYTHON_BINDINGS_VECTOR_ATTRDEFS_TD + +include "mlir/Dialect/Vector/IR/VectorAttributes.td" + +#endif // PYTHON_BINDINGS_VECTOR_ATTRDEFS_TD From edffe89b98b2a9adca5c72efbef182b60821bc52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Thu, 21 Sep 2023 15:38:29 +0200 Subject: [PATCH 582/915] [mlir][linalg][transform] Rename {masked_vectorize => vectorize => vectorize_children_and...}. (#66575) This PR renames the vectorization transform ops as follows: * `structured.masked_vectorize` => `structured.vectorize`. This reflects the fact that since [recently](https://reviews.llvm.org/D157774) the op can also handle the unmasked case. * `structured.vectorize` => `structured.vectorize_children_and_applies_patterns`. This reflects the fact that the op does not just vectorize the given payload op but all vectorizable children contained in it, and applies patterns before and after for preparation and clean-up. This rename was discussed first [here](https://reviews.llvm.org/D157774). The PR also adapts and cleans ups the tablegen description of the `VectorizeChildrenAndApplyPatternsOp` (formerly `VectorizeOp`). 
--- .../python/mlir/dialects/_structured_transform_ops_ext.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index fd3dbca7c..6273452c0 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -360,8 +360,8 @@ def __init__( ) -class MaskedVectorizeOp: - """Specialization for MaskedVectorizeOp class.""" +class VectorizeOp: + """Specialization for VectorizeOp class.""" def __init__( self, @@ -730,8 +730,8 @@ def __init__( ) -class VectorizeOp: - """Specialization for VectorizeOp class.""" +class VectorizeChildrenAndApplyPatternsOp: + """Specialization for VectorizeChildrenAndApplyPatternsOp class.""" def __init__( self, From c0e557b0c89cda1b6cedae6b82de51e7b8a6fdb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Thu, 21 Sep 2023 18:17:00 +0200 Subject: [PATCH 583/915] [mlir][memref][transform] Add new alloca_to_global op. (#66511) This PR adds a new transform op that replaces `memref.alloca`s with `memref.get_global`s to newly inserted `memref.global`s. This is useful, for example, for allocations that should reside in the shared memory of a GPU, which have to be declared as globals. 
--- .../dialects/_memref_transform_ops_ext.py | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py index 4afe8e7b8..1cc00bdcb 100644 --- a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py @@ -11,6 +11,52 @@ from typing import Optional, overload, Union +class MemRefAllocaToGlobalOp: + """Specialization for MemRefAllocaToGlobalOp class.""" + + @overload + def __init__( + self, + get_global_type: Type, + global_type: Type, + alloca: Union[Operation, OpView, Value], + *, + loc=None, + ip=None + ): + ... + + @overload + def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None): + ... + + def __init__( + self, + get_global_type_or_alloca: Union[Operation, OpView, Type, Value], + global_type_or_none: Optional[Type] = None, + alloca_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + loc=None, + ip=None + ): + if isinstance(get_global_type_or_alloca, Type): + get_global_type = get_global_type_or_alloca + global_type = global_type_or_none + alloca = alloca_or_none + else: + get_global_type = transform.AnyOpType.get() + global_type = transform.AnyOpType.get() + alloca = get_global_type_or_alloca + + super().__init__( + get_global_type, + global_type, + alloca, + loc=loc, + ip=ip, + ) + + class MemRefMultiBufferOp: """Specialization for MemRefMultiBufferOp class.""" From 28117daebbc8998772990089e2e259b64ce4d01a Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Fri, 22 Sep 2023 15:51:25 -0700 Subject: [PATCH 584/915] [mlir][sparse] add lvlToDim field to sparse tensor encoding (#67194) Note the new surface syntax allows for defining a dimToLvl and lvlToDim map at once (where usually the latter can be inferred from the former, but not always). 
This revision adds storage for the latter, together with some initial boilerplate. The actual support (inference, validation, printing, etc.) is still TBD of course. --- mlir/include/mlir-c/Dialect/SparseTensor.h | 6 ++++++ mlir/lib/Bindings/Python/DialectSparseTensor.cpp | 1 + mlir/lib/CAPI/Dialect/SparseTensor.cpp | 10 ++++++++-- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index b2e4b96c6..fecbeaf6b 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -51,6 +51,7 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr); /// Creates a `sparse_tensor.encoding` attribute with the given parameters. +/// TODO: add a version that supplies lvlToDim when it cannot be inferred MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet( MlirContext ctx, intptr_t lvlRank, enum MlirSparseTensorDimLevelType const *lvlTypes, MlirAffineMap dimToLvl, @@ -69,6 +70,11 @@ mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl); MLIR_CAPI_EXPORTED MlirAffineMap mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr); +/// Returns the level-to-dimension mapping of the `sparse_tensor.encoding` +/// attribute. +MLIR_CAPI_EXPORTED MlirAffineMap +mlirSparseTensorEncodingAttrGetLvlToDim(MlirAttribute attr); + /// Returns the position bitwidth of the `sparse_tensor.encoding` attribute. 
MLIR_CAPI_EXPORTED int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr); diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 70805005f..3061e042c 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -43,6 +43,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { [](py::object cls, std::vector lvlTypes, std::optional dimToLvl, int posWidth, int crdWidth, MlirContext context) { + // TODO: provide lvlToDim return cls(mlirSparseTensorEncodingAttrGet( context, lvlTypes.size(), lvlTypes.data(), dimToLvl ? *dimToLvl : MlirAffineMap{nullptr}, posWidth, diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index e18da1027..bf3a4ad5e 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -54,14 +54,20 @@ mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank, cppLvlTypes.reserve(lvlRank); for (intptr_t l = 0; l < lvlRank; ++l) cppLvlTypes.push_back(static_cast(lvlTypes[l])); - return wrap(SparseTensorEncodingAttr::get( - unwrap(ctx), cppLvlTypes, unwrap(dimToLvl), posWidth, crdWidth)); + mlir::AffineMap lvlToDim; // TODO: provide in API + return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes, + unwrap(dimToLvl), lvlToDim, + posWidth, crdWidth)); } MlirAffineMap mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr) { return wrap(cast(unwrap(attr)).getDimToLvl()); } +MlirAffineMap mlirSparseTensorEncodingAttrGetLvlToDim(MlirAttribute attr) { + return wrap(cast(unwrap(attr)).getLvlToDim()); +} + intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) { return cast(unwrap(attr)).getLvlRank(); } From eedfa53a162a0b7db18c58e346bcdfb5dbb30b4a Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Mon, 25 Sep 2023 12:25:08 -0700 Subject: [PATCH 585/915] [mlir][py] Enable AsmState overload for 
operation. --- mlir/lib/Bindings/Python/IRCore.cpp | 2 ++ mlir/lib/Bindings/Python/IRModule.h | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index c74b37a51..aad74f511 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3479,6 +3479,8 @@ void mlir::python::populateIRCore(py::module &m) { py::class_(m, "AsmState", py::module_local()) .def(py::init(), py::arg("value"), + py::arg("use_local_scope") = false) + .def(py::init(), py::arg("op"), py::arg("use_local_scope") = false); //---------------------------------------------------------------------------- diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 23338f7fd..3ca7dd851 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -759,6 +759,16 @@ class PyAsmState { mlirOpPrintingFlagsUseLocalScope(flags); state = mlirAsmStateCreateForValue(value, flags); } + + PyAsmState(PyOperationBase &operation, bool useLocalScope) { + flags = mlirOpPrintingFlagsCreate(); + // The OpPrintingFlags are not exposed Python side, create locally and + // associate lifetime with the state. + if (useLocalScope) + mlirOpPrintingFlagsUseLocalScope(flags); + state = + mlirAsmStateCreateForOperation(operation.getOperation().get(), flags); + } ~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); } From b4e2e83d8e8e5a251f1d402ff5b036f9b24399b7 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Tue, 26 Sep 2023 09:14:29 +0200 Subject: [PATCH 586/915] [mlir] cleanup of structured.tile* transform ops (#67320) Rename and restructure tiling-related transform ops from the structured extension to be more homogeneous. 
In particular, all ops now follow a consistent naming scheme: - `transform.structured.tile_using_for`; - `transform.structured.tile_using_forall`; - `transform.structured.tile_reduction_using_for`; - `transform.structured.tile_reduction_using_forall`. This drops the "_op" naming artifact from `tile_to_forall_op` that shouldn't have been included in the first place, consistently specifies the name of the control flow op to be produced for loops (instead of `tile_reduction_using_scf` since `scf.forall` also belongs to `scf`), and opts for the `using` connector to avoid ambiguity. The loops produced by tiling are now systematically placed as *trailing* results of the transform op. While this required changing 3 out of 4 ops (except for `tile_using_for`), this is the only choice that makes sense when producing multiple `scf.for` ops that can be associated with a variadic number of handles. This choice is also most consistent with *other* transform ops from the structured extension, in particular with fusion ops, that produce the structured op as the leading result and the loop as the trailing result. --- .../mlir/dialects/_structured_transform_ops_ext.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py index 6273452c0..3757a3d3b 100644 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py @@ -567,8 +567,8 @@ def __init__( ) -class TileOp: - """Specialization for TileOp class.""" +class TileUsingForOp: + """Specialization for TileUsingForOp class.""" @overload def __init__( @@ -616,7 +616,9 @@ def __init__( if isinstance(loop_types_or_target, (Operation, Value, OpView)): loop_types = [transform.AnyOpType.get()] * num_loops target = loop_types_or_target - assert target_or_none is None, "Cannot construct TileOp with two targets." 
+ assert ( + target_or_none is None + ), "Cannot construct TileUsingForOp with two targets." else: loop_types = ( ([loop_types_or_target] * num_loops) @@ -638,8 +640,8 @@ def __init__( ) -class TileToForallOp: - """Specialization for TileToForallOp class.""" +class TileUsingForallOp: + """Specialization for TileUsingForallOp class.""" @overload def __init__( From 034bc45d8a9749bd8c49ba4fb6849ba5af920844 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Tue, 26 Sep 2023 01:53:17 -0700 Subject: [PATCH 587/915] [MLIR] Introduce new C bindings to differentiate between discardable and inherent attributes (#66332) This is part of the transition toward properly splitting the two groups. This only introduces new C APIs, the Python bindings are unaffected. No API is removed. --- mlir/include/mlir-c/IR.h | 52 ++++++++++++++++++++++++++++++++++++++++ mlir/lib/CAPI/IR/IR.cpp | 47 ++++++++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 68eccab6d..a6408317d 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -576,25 +576,77 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumSuccessors(MlirOperation op); MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos); +/// Returns true if this operation defines an inherent attribute with this name. +/// Note: the attribute can be optional, so +/// `mlirOperationGetInherentAttributeByName` can still return a null attribute. +MLIR_CAPI_EXPORTED bool +mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name); + +/// Returns an inherent attribute attached to the operation given its name. +MLIR_CAPI_EXPORTED MlirAttribute +mlirOperationGetInherentAttributeByName(MlirOperation op, MlirStringRef name); + +/// Sets an inherent attribute by name, replacing the existing if it exists. +/// This has no effect if "name" does not match an inherent attribute. 
+MLIR_CAPI_EXPORTED void +mlirOperationSetInherentAttributeByName(MlirOperation op, MlirStringRef name, + MlirAttribute attr); + +/// Returns the number of discardable attributes attached to the operation. +MLIR_CAPI_EXPORTED intptr_t +mlirOperationGetNumDiscardableAttributes(MlirOperation op); + +/// Return `pos`-th discardable attribute of the operation. +MLIR_CAPI_EXPORTED MlirNamedAttribute +mlirOperationGetDiscardableAttribute(MlirOperation op, intptr_t pos); + +/// Returns a discardable attribute attached to the operation given its name. +MLIR_CAPI_EXPORTED MlirAttribute mlirOperationGetDiscardableAttributeByName( + MlirOperation op, MlirStringRef name); + +/// Sets a discardable attribute by name, replacing the existing if it exists or +/// adding a new one otherwise. The new `attr` Attribute is not allowed to be +/// null, use `mlirOperationRemoveDiscardableAttributeByName` to remove an +/// Attribute instead. +MLIR_CAPI_EXPORTED void +mlirOperationSetDiscardableAttributeByName(MlirOperation op, MlirStringRef name, + MlirAttribute attr); + +/// Removes a discardable attribute by name. Returns false if the attribute was +/// not found and true if removed. +MLIR_CAPI_EXPORTED bool +mlirOperationRemoveDiscardableAttributeByName(MlirOperation op, + MlirStringRef name); + /// Returns the number of attributes attached to the operation. +/// Deprecated, please use `mlirOperationGetNumInherentAttributes` or +/// `mlirOperationGetNumDiscardableAttributes`. MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumAttributes(MlirOperation op); /// Return `pos`-th attribute of the operation. +/// Deprecated, please use `mlirOperationGetInherentAttribute` or +/// `mlirOperationGetDiscardableAttribute`. MLIR_CAPI_EXPORTED MlirNamedAttribute mlirOperationGetAttribute(MlirOperation op, intptr_t pos); /// Returns an attribute attached to the operation given its name. 
+/// Deprecated, please use `mlirOperationGetInherentAttributeByName` or +/// `mlirOperationGetDiscardableAttributeByName`. MLIR_CAPI_EXPORTED MlirAttribute mlirOperationGetAttributeByName(MlirOperation op, MlirStringRef name); /// Sets an attribute by name, replacing the existing if it exists or /// adding a new one otherwise. +/// Deprecated, please use `mlirOperationSetInherentAttributeByName` or +/// `mlirOperationSetDiscardableAttributeByName`. MLIR_CAPI_EXPORTED void mlirOperationSetAttributeByName(MlirOperation op, MlirStringRef name, MlirAttribute attr); /// Removes an attribute by name. Returns false if the attribute was not found /// and true if removed. +/// Deprecated, please use `mlirOperationRemoveInherentAttributeByName` or +/// `mlirOperationRemoveDiscardableAttributeByName`. MLIR_CAPI_EXPORTED bool mlirOperationRemoveAttributeByName(MlirOperation op, MlirStringRef name); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 7f5c2aaee..04b386b82 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -595,6 +595,53 @@ MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos) { return wrap(unwrap(op)->getSuccessor(static_cast(pos))); } +MLIR_CAPI_EXPORTED bool +mlirOperationHasInherentAttributeByName(MlirOperation op, MlirStringRef name) { + std::optional attr = unwrap(op)->getInherentAttr(unwrap(name)); + return attr.has_value(); +} + +MlirAttribute mlirOperationGetInherentAttributeByName(MlirOperation op, + MlirStringRef name) { + std::optional attr = unwrap(op)->getInherentAttr(unwrap(name)); + if (attr.has_value()) + return wrap(*attr); + return {}; +} + +void mlirOperationSetInherentAttributeByName(MlirOperation op, + MlirStringRef name, + MlirAttribute attr) { + unwrap(op)->setInherentAttr( + StringAttr::get(unwrap(op)->getContext(), unwrap(name)), unwrap(attr)); +} + +intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op) { + return static_cast(unwrap(op)->getDiscardableAttrs().size()); +} + 
+MlirNamedAttribute mlirOperationGetDiscardableAttribute(MlirOperation op, + intptr_t pos) { + NamedAttribute attr = unwrap(op)->getDiscardableAttrs()[pos]; + return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())}; +} + +MlirAttribute mlirOperationGetDiscardableAttributeByName(MlirOperation op, + MlirStringRef name) { + return wrap(unwrap(op)->getDiscardableAttr(unwrap(name))); +} + +void mlirOperationSetDiscardableAttributeByName(MlirOperation op, + MlirStringRef name, + MlirAttribute attr) { + unwrap(op)->setDiscardableAttr(unwrap(name), unwrap(attr)); +} + +bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op, + MlirStringRef name) { + return !!unwrap(op)->removeDiscardableAttr(unwrap(name)); +} + intptr_t mlirOperationGetNumAttributes(MlirOperation op) { return static_cast(unwrap(op)->getAttrs().size()); } From 0e46893ec8a40c98218c055d20a6d3bcfd99256a Mon Sep 17 00:00:00 2001 From: martin-luecke Date: Tue, 26 Sep 2023 16:10:24 +0200 Subject: [PATCH 588/915] [mlir][python] Expose transform param types (#67421) This exposes the Transform dialect types `AnyParamType` and `ParamType` via the Python bindings. 
--- mlir/include/mlir-c/Dialect/Transform.h | 19 ++++++++++ mlir/lib/Bindings/Python/DialectTransform.cpp | 35 +++++++++++++++++++ mlir/lib/CAPI/Dialect/Transform.cpp | 28 +++++++++++++++ 3 files changed, 82 insertions(+) diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h index 954575925..91c99b1f8 100644 --- a/mlir/include/mlir-c/Dialect/Transform.h +++ b/mlir/include/mlir-c/Dialect/Transform.h @@ -27,6 +27,14 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type); MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx); +//===---------------------------------------------------------------------===// +// AnyParamType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyParamType(MlirType type); + +MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx); + //===---------------------------------------------------------------------===// // AnyValueType //===---------------------------------------------------------------------===// @@ -49,6 +57,17 @@ mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef operationName); MLIR_CAPI_EXPORTED MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type); +//===---------------------------------------------------------------------===// +// ParamType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsATransformParamType(MlirType type); + +MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGet(MlirContext ctx, + MlirType type); + +MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGetType(MlirType type); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp index 932e40220..cbbf8332b 100644 --- a/mlir/lib/Bindings/Python/DialectTransform.cpp +++ b/mlir/lib/Bindings/Python/DialectTransform.cpp @@ 
-31,6 +31,20 @@ void populateDialectTransformSubmodule(const pybind11::module &m) { "Get an instance of AnyOpType in the given context.", py::arg("cls"), py::arg("context") = py::none()); + //===-------------------------------------------------------------------===// + // AnyParamType + //===-------------------------------------------------------------------===// + + auto anyParamType = + mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType); + anyParamType.def_classmethod( + "get", + [](py::object cls, MlirContext ctx) { + return cls(mlirTransformAnyParamTypeGet(ctx)); + }, + "Get an instance of AnyParamType in the given context.", py::arg("cls"), + py::arg("context") = py::none()); + //===-------------------------------------------------------------------===// // AnyValueType //===-------------------------------------------------------------------===// @@ -71,6 +85,27 @@ void populateDialectTransformSubmodule(const pybind11::module &m) { return py::str(operationName.data, operationName.length); }, "Get the name of the payload operation accepted by the handle."); + + //===-------------------------------------------------------------------===// + // ParamType + //===-------------------------------------------------------------------===// + + auto paramType = + mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType); + paramType.def_classmethod( + "get", + [](py::object cls, MlirType type, MlirContext ctx) { + return cls(mlirTransformParamTypeGet(ctx, type)); + }, + "Get an instance of ParamType for the given type in the given context.", + py::arg("cls"), py::arg("type"), py::arg("context") = py::none()); + paramType.def_property_readonly( + "type", + [](MlirType type) { + MlirType paramType = mlirTransformParamTypeGetType(type); + return paramType; + }, + "Get the type this ParamType is associated with."); } PYBIND11_MODULE(_mlirDialectsTransform, m) { diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp 
index 5841f6783..3f7f8b8e2 100644 --- a/mlir/lib/CAPI/Dialect/Transform.cpp +++ b/mlir/lib/CAPI/Dialect/Transform.cpp @@ -29,6 +29,18 @@ MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) { return wrap(transform::AnyOpType::get(unwrap(ctx))); } +//===---------------------------------------------------------------------===// +// AnyParamType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsATransformAnyParamType(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) { + return wrap(transform::AnyParamType::get(unwrap(ctx))); +} + //===---------------------------------------------------------------------===// // AnyValueType //===---------------------------------------------------------------------===// @@ -62,3 +74,19 @@ MlirType mlirTransformOperationTypeGet(MlirContext ctx, MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) { return wrap(cast(unwrap(type)).getOperationName()); } + +//===---------------------------------------------------------------------===// +// AnyOpType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsATransformParamType(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) { + return wrap(transform::ParamType::get(unwrap(ctx), unwrap(type))); +} + +MlirType mlirTransformParamTypeGetType(MlirType type) { + return wrap(cast(unwrap(type)).getType()); +} From a18412d0e368c929a1c3e20a9fa90221cb11e744 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Tue, 26 Sep 2023 11:44:37 -0700 Subject: [PATCH 589/915] Simplify diagnostic error management for MLIR properties API (NFC) (#67409) This is a follow-up to 0b42f5962596 which lazy-initialized the diagnostic and removed the need to dynamically abandon() an InFlightDiagnostic. 
This further simplifies the code to not needed to return a reference to an InFlightDiagnostic and instead eagerly emit errors. Also use `emitError` as name instead of `getDiag` which seems more explicit and in-line with the common usage. --- mlir/lib/CAPI/IR/IR.cpp | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 04b386b82..65b2b7466 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -416,18 +416,13 @@ static LogicalResult inferOperationTypes(OperationState &state) { auto prop = std::make_unique(info->getOpPropertyByteSize()); properties = OpaqueProperties(prop.get()); if (properties) { - std::unique_ptr diagnostic; - auto getDiag = [&]() -> InFlightDiagnostic & { - if (!diagnostic) { - diagnostic = std::make_unique( - emitError(state.location) - << " failed properties conversion while building " - << state.name.getStringRef() << " with `" << attributes << "`: "); - } - return *diagnostic; + auto emitError = [&]() { + return mlir::emitError(state.location) + << " failed properties conversion while building " + << state.name.getStringRef() << " with `" << attributes << "`: "; }; if (failed(info->setOpPropertiesFromAttribute(state.name, properties, - attributes, getDiag))) + attributes, emitError))) return failure(); } if (succeeded(inferInterface->inferReturnTypes( From 236295371516cb2261b6dce84ec3676221cc46e7 Mon Sep 17 00:00:00 2001 From: Andrzej Warzynski Date: Tue, 26 Sep 2023 19:57:30 +0000 Subject: [PATCH 590/915] [mlir][transform] Update transform.loop.peel (reland #67482) This patch updates `transform.loop.peel` so that this Op returns two rather than one handle: * one for the peeled loop, and * one for the remainder loop. Also, following this change this Op will fail if peeling fails. This is consistent with other similar Ops that also fail if no transformation takes place. 
Relands #67482 with an extra fix for transform_loop_ext.py --- mlir/python/mlir/dialects/_loop_transform_ops_ext.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py index 3536d45ab..1cdb2b9e7 100644 --- a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py @@ -66,7 +66,8 @@ class LoopPeelOp: def __init__( self, - result_type: Type, + main_loop_type: Type, + remainder_loop_type: Type, target: Union[Operation, Value], *, fail_if_already_divisible: Union[bool, BoolAttr] = False, @@ -74,7 +75,8 @@ def __init__( loc=None, ): super().__init__( - result_type, + main_loop_type, + remainder_loop_type, _get_op_result_or_value(target), fail_if_already_divisible=( fail_if_already_divisible From bd5511bca5a22971c2139fb5d1f40262fe1cf682 Mon Sep 17 00:00:00 2001 From: Yinying Li <107574043+yinying-lisa-li@users.noreply.github.com> Date: Mon, 2 Oct 2023 15:06:40 +0000 Subject: [PATCH 591/915] [mlir][sparse] Update Enum name for CompressedWithHigh (#67845) Change CompressedWithHigh to LooseCompressed. --- mlir/include/mlir-c/Dialect/SparseTensor.h | 28 +++++++++---------- .../Bindings/Python/DialectSparseTensor.cpp | 14 +++++----- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index fecbeaf6b..7e47e54e7 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -26,20 +26,20 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor); /// If updating, keep them in sync and update the static_assert in the impl /// file. 
enum MlirSparseTensorDimLevelType { - MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE = 4, // 0b00001_00 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED = 8, // 0b00010_00 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU = 9, // 0b00010_01 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO = 10, // 0b00010_10 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO = 11, // 0b00010_11 - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON = 16, // 0b00100_00 - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU = 17, // 0b00100_01 - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO = 18, // 0b00100_10 - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO = 19, // 0b00100_11 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI = 32, // 0b01000_00 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU = 33, // 0b01000_01 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NO = 34, // 0b01000_10 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU_NO = 35, // 0b01000_11 - MLIR_SPARSE_TENSOR_DIM_LEVEL_TWO_OUT_OF_FOUR = 64, // 0b10000_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE = 4, // 0b00001_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED = 8, // 0b00010_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU = 9, // 0b00010_01 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO = 10, // 0b00010_10 + MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO = 11, // 0b00010_11 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON = 16, // 0b00100_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU = 17, // 0b00100_01 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO = 18, // 0b00100_10 + MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO = 19, // 0b00100_11 + MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED = 32, // 0b01000_00 + MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED_NU = 33, // 0b01000_01 + MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED_NO = 34, // 0b01000_10 + MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED_NU_NO = 35, // 0b01000_11 + MLIR_SPARSE_TENSOR_DIM_LEVEL_TWO_OUT_OF_FOUR = 64, // 0b10000_00 }; //===----------------------------------------------------------------------===// diff --git 
a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 3061e042c..8e9e0b6ba 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -28,13 +28,13 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { .value("singleton_nu", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU) .value("singleton_no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO) .value("singleton_nu_no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO) - .value("compressed_hi", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI) - .value("compressed_hi_nu", - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU) - .value("compressed_hi_no", - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NO) - .value("compressed_hi_nu_no", - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_WITH_HI_NU_NO); + .value("loose_compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED) + .value("loose_compressed_nu", + MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED_NU) + .value("loose_compressed_no", + MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED_NO) + .value("loose_compressed_nu_no", + MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED_NU_NO); mlir_attribute_subclass(m, "EncodingAttr", mlirAttributeIsASparseTensorEncodingAttr) From ff2e01bbde3211e64b3674d058c8dadd440b31dd Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 2 Oct 2023 15:37:25 -0500 Subject: [PATCH 592/915] [mlir][CAPI, python bindings] Expose `Operation::setSuccessor` (#67922) This is useful for emitting (using the python bindings) `cf.br` to blocks that are declared lexically post block creation. 
--- mlir/include/mlir-c/IR.h | 4 ++ mlir/lib/Bindings/Python/IRCore.cpp | 82 +++++++++++++++++++++++++---- mlir/lib/CAPI/IR/IR.cpp | 5 ++ 3 files changed, 80 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index a6408317d..e361f33a0 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -576,6 +576,10 @@ MLIR_CAPI_EXPORTED intptr_t mlirOperationGetNumSuccessors(MlirOperation op); MLIR_CAPI_EXPORTED MlirBlock mlirOperationGetSuccessor(MlirOperation op, intptr_t pos); +/// Set `pos`-th successor of the operation. +MLIR_CAPI_EXPORTED void +mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, MlirBlock block); + /// Returns true if this operation defines an inherent attribute with this name. /// Note: the attribute can be optional, so /// `mlirOperationGetInherentAttributeByName` can still return a null attribute. diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index aad74f511..84f980d79 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2207,9 +2207,9 @@ class PyBlockArgumentList }; /// A list of operation operands. Internally, these are stored as consecutive -/// elements, random access is cheap. The result list is associated with the -/// operation whose results these are, and extends the lifetime of this -/// operation. +/// elements, random access is cheap. The (returned) operand list is associated +/// with the operation whose operands these are, and thus extends the lifetime +/// of this operation. class PyOpOperandList : public Sliceable { public: static constexpr const char *pyClassName = "OpOperandList"; @@ -2262,9 +2262,9 @@ class PyOpOperandList : public Sliceable { }; /// A list of operation results. Internally, these are stored as consecutive -/// elements, random access is cheap. 
The result list is associated with the -/// operation whose results these are, and extends the lifetime of this -/// operation. +/// elements, random access is cheap. The (returned) result list is associated +/// with the operation whose results these are, and thus extends the lifetime of +/// this operation. class PyOpResultList : public Sliceable { public: static constexpr const char *pyClassName = "OpResultList"; @@ -2307,6 +2307,52 @@ class PyOpResultList : public Sliceable { PyOperationRef operation; }; +/// A list of operation successors. Internally, these are stored as consecutive +/// elements, random access is cheap. The (returned) successor list is +/// associated with the operation whose successors these are, and thus extends +/// the lifetime of this operation. +class PyOpSuccessors : public Sliceable { +public: + static constexpr const char *pyClassName = "OpSuccessors"; + + PyOpSuccessors(PyOperationRef operation, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirOperationGetNumSuccessors(operation->get()) + : length, + step), + operation(operation) {} + + void dunderSetItem(intptr_t index, PyBlock block) { + index = wrapIndex(index); + mlirOperationSetSuccessor(operation->get(), index, block.get()); + } + + static void bindDerived(ClassTy &c) { + c.def("__setitem__", &PyOpSuccessors::dunderSetItem); + } + +private: + /// Give the parent CRTP class access to hook implementations below. 
+ friend class Sliceable; + + intptr_t getRawNumElements() { + operation->checkValid(); + return mlirOperationGetNumSuccessors(operation->get()); + } + + PyBlock getRawElement(intptr_t pos) { + MlirBlock block = mlirOperationGetSuccessor(operation->get(), pos); + return PyBlock(operation, block); + } + + PyOpSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) { + return PyOpSuccessors(operation, startIndex, length, step); + } + + PyOperationRef operation; +}; + /// A list of operation attributes. Can be indexed by name, producing /// attributes, or by index, producing named attributes. class PyOpAttributeMap { @@ -2924,16 +2970,28 @@ void mlir::python::populateIRCore(py::module &m) { &PyOperation::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) .def_property_readonly("operation", [](py::object self) { return self; }) - .def_property_readonly("opview", &PyOperation::createOpView); + .def_property_readonly("opview", &PyOperation::createOpView) + .def_property_readonly( + "successors", + [](PyOperationBase &self) { + return PyOpSuccessors(self.getOperation().getRef()); + }, + "Returns the list of Operation successors."); auto opViewClass = py::class_(m, "OpView", py::module_local()) .def(py::init(), py::arg("operation")) .def_property_readonly("operation", &PyOpView::getOperationObject) .def_property_readonly("opview", [](py::object self) { return self; }) - .def("__str__", [](PyOpView &self) { - return py::str(self.getOperationObject()); - }); + .def( + "__str__", + [](PyOpView &self) { return py::str(self.getOperationObject()); }) + .def_property_readonly( + "successors", + [](PyOperationBase &self) { + return PyOpSuccessors(self.getOperation().getRef()); + }, + "Returns the list of Operation successors."); opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); @@ -3448,7 +3506,8 @@ void 
mlir::python::populateIRCore(py::module &m) { mlirOpPrintingFlagsUseLocalScope(flags); valueState = mlirAsmStateCreateForValue(self.get(), flags); } - mlirValuePrintAsOperand(self.get(), valueState, printAccum.getCallback(), + mlirValuePrintAsOperand(self.get(), valueState, + printAccum.getCallback(), printAccum.getUserData()); // Release state if allocated locally. if (!state) { @@ -3523,6 +3582,7 @@ void mlir::python::populateIRCore(py::module &m) { PyOpOperandIterator::bind(m); PyOpOperandList::bind(m); PyOpResultList::bind(m); + PyOpSuccessors::bind(m); PyRegionIterator::bind(m); PyRegionList::bind(m); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 65b2b7466..c1abbbe36 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -637,6 +637,11 @@ bool mlirOperationRemoveDiscardableAttributeByName(MlirOperation op, return !!unwrap(op)->removeDiscardableAttr(unwrap(name)); } +void mlirOperationSetSuccessor(MlirOperation op, intptr_t pos, + MlirBlock block) { + unwrap(op)->setSuccessor(unwrap(block), static_cast(pos)); +} + intptr_t mlirOperationGetNumAttributes(MlirOperation op) { return static_cast(unwrap(op)->getAttrs().size()); } From 1fb76fc1e1adfc37521086de3b1d45cbd63c10dd Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Mon, 2 Oct 2023 21:17:49 -0700 Subject: [PATCH 593/915] [mlir][py] Use overloads instead (NFC) Was using a local, pseudo overload rather than just using an overload proper. 
--- mlir/lib/Bindings/Python/IRCore.cpp | 44 +++++++++++++---------------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 84f980d79..c8373e06f 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3488,36 +3488,32 @@ void mlir::python::populateIRCore(py::module &m) { kValueDunderStrDocstring) .def( "get_name", - [](PyValue &self, std::optional useLocalScope, - std::optional> state) { + [](PyValue &self, bool useLocalScope) { PyPrintAccumulator printAccum; - MlirOpPrintingFlags flags; - MlirAsmState valueState; - // Use state if provided, else create a new state. - if (state) { - valueState = state.value().get().get(); - // Don't allow setting using local scope and state at same time. - if (useLocalScope.has_value()) - throw py::value_error( - "setting AsmState and local scope together not supported"); - } else { - flags = mlirOpPrintingFlagsCreate(); - if (useLocalScope.value_or(false)) - mlirOpPrintingFlagsUseLocalScope(flags); - valueState = mlirAsmStateCreateForValue(self.get(), flags); - } + MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); + if (useLocalScope) + mlirOpPrintingFlagsUseLocalScope(flags); + MlirAsmState valueState = + mlirAsmStateCreateForValue(self.get(), flags); + mlirValuePrintAsOperand(self.get(), valueState, + printAccum.getCallback(), + printAccum.getUserData()); + mlirOpPrintingFlagsDestroy(flags); + mlirAsmStateDestroy(valueState); + return printAccum.join(); + }, + py::arg("use_local_scope") = false) + .def( + "get_name", + [](PyValue &self, std::reference_wrapper state) { + PyPrintAccumulator printAccum; + MlirAsmState valueState = state.get().get(); mlirValuePrintAsOperand(self.get(), valueState, printAccum.getCallback(), printAccum.getUserData()); - // Release state if allocated locally. 
- if (!state) { - mlirOpPrintingFlagsDestroy(flags); - mlirAsmStateDestroy(valueState); - } return printAccum.join(); }, - py::arg("use_local_scope") = std::nullopt, - py::arg("state") = std::nullopt, kGetNameAsOperand) + py::arg("state"), kGetNameAsOperand) .def_property_readonly( "type", [](PyValue &self) { return mlirValueGetType(self.get()); }) .def( From 1b7ca4397ebd8dfba3f2a7111ac3da201abe85bb Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Tue, 3 Oct 2023 14:52:52 +0200 Subject: [PATCH 594/915] [mlir] enable python bindings for nvgpu transforms (#68088) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Expose the autogenerated bindings. Co-authored-by: Martin Lücke --- mlir/python/CMakeLists.txt | 9 +++++++++ .../python/mlir/dialects/NVGPUTransformOps.td | 20 +++++++++++++++++++ mlir/python/mlir/dialects/transform/nvgpu.py | 5 +++++ 3 files changed, 34 insertions(+) create mode 100644 mlir/python/mlir/dialects/NVGPUTransformOps.td create mode 100644 mlir/python/mlir/dialects/transform/nvgpu.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 9368cb4c2..088d9a765 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -200,6 +200,15 @@ declare_mlir_dialect_extension_python_bindings( DIALECT_NAME transform EXTENSION_NAME memref_transform) +declare_mlir_dialect_extension_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/NVGPUTransformOps.td + SOURCES + dialects/transform/nvgpu.py + DIALECT_NAME transform + EXTENSION_NAME nvgpu_transform) + declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/NVGPUTransformOps.td b/mlir/python/mlir/dialects/NVGPUTransformOps.td new file mode 100644 index 000000000..1f504e322 --- /dev/null +++ 
b/mlir/python/mlir/dialects/NVGPUTransformOps.td @@ -0,0 +1,20 @@ +//===-- NVGPUTransformOps.td -------------------------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the Python bindings generator for the transform ops provided +// by the NVGPU dialect. +// +//===----------------------------------------------------------------------===// + + +#ifndef PYTHON_BINDINGS_NVGPU_TRANSFORM_OPS +#define PYTHON_BINDINGS_NVGPU_TRANSFORM_OPS + +include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.td" + +#endif // PYTHON_BINDINGS_NVGPU_TRANSFORM_OPS diff --git a/mlir/python/mlir/dialects/transform/nvgpu.py b/mlir/python/mlir/dialects/transform/nvgpu.py new file mode 100644 index 000000000..74ba4c9ae --- /dev/null +++ b/mlir/python/mlir/dialects/transform/nvgpu.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._nvgpu_transform_ops_gen import * From 3f5081f6b5db530a9b7fffef516c5643460bae85 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Wed, 4 Oct 2023 20:35:24 -0500 Subject: [PATCH 595/915] [mlir][python] Enable py312. (#68009) Python 3.12 has been released so why not support it. 
--- mlir/python/requirements.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index 7671d3329..a596f8747 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,4 +1,4 @@ -numpy>=1.19.5, <=1.23.5 +numpy>=1.19.5, <=1.26 pybind11>=2.9.0, <=2.10.3 -PyYAML>= 5.3.1, <=6.0 -dataclasses>=0.6, <=0.8 +PyYAML>=5.3.1, <=6.0.1 +dataclasses>=0.6, <=0.8 \ No newline at end of file From 568f8cf6ef99762615d0f7068bea1320cdf831ef Mon Sep 17 00:00:00 2001 From: Jack Frankland <30410009+FranklandJack@users.noreply.github.com> Date: Sat, 7 Oct 2023 01:10:39 +0100 Subject: [PATCH 596/915] [mlir][tosa][linalg] Apply direct tosa -> linalg Conv2D lowering (#68304) TOSA defines the filter channel ordering for 2D convolution operation `tosa.conv2d` as `[OC, KH, KW, IC]`. The LinAlg dialect supports `[F, H, W, C]` and `[H, W, C, F]` orderings via the `linalg.conv_2d_nhwc_fhwc` and `linalg.conv_2d_nhwc_hwcf` operations respectively. Where `F == OC`, `KH == H`, `KW == W` and `C == IC`. Currently `tosa.conv2d` is lowered to `linalg.conv2d_nhwc_hwcf` meaning we need to insert a transposition operation to permute the filter channels before they can be passed as weights to the linalg op, that is `[F, H, W, C]` -> `[H, W, C, F]`. An analogous transformation needs to be applied to the quantized operation that lowers to `linalg.conv_2d_nhwc_hwcf_q`. This commit updates the TOSA->LinAlg lowering so that `tosa.conv2d` is lowered to `linalg.conv2d_nhwc_fhwc` removing the need for the introduction of a transposition operation and making the mapping 1-1. It also adds a `linalg.conv_2d_nhwc_fhwc_q` quantized operation to the LinAlg dialect so the same direct 1-1 mapping can be applied to the quantized variant. 
This commit does not add any new lit tests but repurposes the current TosaToLinalgNamed tests by removing the checks for transpositions and updating the targeted LinAlg operations from `linalg.conv2d_nhwc_hwcf` to linalg.conv2d_nhwc_fhwc`. Signed-off-by: Jack Frankland --- .../linalg/opdsl/ops/core_named_ops.py | 30 +++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 6eae3d916..a8f8f8e0f 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -693,6 +693,36 @@ def conv_2d_nhwc_hwcf_q( ) * (TypeFn.cast_signed(U, K[D.kh, D.kw, D.c, D.f]) - TypeFn.cast_signed(U, KZp)) +@linalg_structured_op +def conv_2d_nhwc_fhwc_q( + I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C), + K=TensorDef(T2, S.F, S.KH, S.KW, S.C), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.F, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D convolution with zero point offsets. + + Layout: + * Input: NHWC. + * Kernel: FHWC. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. This includes the zero + point offsets common to quantized operations. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.f, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.f] += ( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c] + ) + - TypeFn.cast_signed(U, IZp) + ) * (TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) - TypeFn.cast_signed(U, KZp)) + + @linalg_structured_op def conv_2d_nchw_fchw( I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), From cdf6a60cf0b886ba2141423215ed6f48e4774571 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 9 Oct 2023 14:16:28 -0700 Subject: [PATCH 597/915] [mlir][python] generate value builders (#68308) This PR adds the additional generation of what I'm calling "value builders" (a term I'm not married to) that look like this: ```python def empty(sizes, element_type, *, loc=None, ip=None): return get_result_or_results(tensor.EmptyOp(sizes=sizes, element_type=element_type, loc=loc, ip=ip)) ``` which instantiates a `tensor.EmptyOp` and then immediately grabs the result (`OpResult`) and then returns that *instead of a handle to the op*. What's the point of adding these when `EmptyOp.result` already exists? My claim/feeling/intuition is that eDSL users are more comfortable with a value centric programming model (i.e., passing values as operands) as opposed to an operator instantiation programming model. Thus this change enables (or at least goes towards) the bindings supporting such a user and use case. For example, ```python i32 = IntegerType.get_signless(32) ... ten1 = tensor.empty((10, 10), i32) ten2 = tensor.empty((10, 10), i32) ten3 = arith.addi(ten1, ten2) ``` Note, in order to present a "pythonic" API and enable "pythonic" eDSLs, the generated identifiers (op names and operand names) are snake case instead of camel case and thus `llvm::convertToSnakeFromCamelCase` needed a small fix. Thus this PR is stacked on top of https://github.com/llvm/llvm-project/pull/68375. 
In addition, as a kind of victory lap, this PR adds a "rangefor" that looks and acts exactly like python's `range` but emits `scf.for`. --- mlir/python/mlir/dialects/_ods_common.py | 15 +++++++++ mlir/python/mlir/dialects/_scf_ops_ext.py | 5 +-- mlir/python/mlir/dialects/scf.py | 38 +++++++++++++++++++++++ 3 files changed, 56 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 7655629a5..895c32281 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -13,6 +13,7 @@ "get_default_loc_context", "get_op_result_or_value", "get_op_results_or_values", + "get_op_result_or_op_results", "segmented_accessor", ] @@ -167,3 +168,17 @@ def get_op_results_or_values( return arg.results else: return [get_op_result_or_value(element) for element in arg] + + +def get_op_result_or_op_results( + op: _Union[_cext.ir.OpView, _cext.ir.Operation], +) -> _Union[_cext.ir.Operation, _cext.ir.OpResult, _Sequence[_cext.ir.OpResult]]: + if isinstance(op, _cext.ir.OpView): + op = op.operation + return ( + list(get_op_results_or_values(op)) + if len(op.results) > 1 + else get_op_result_or_value(op) + if len(op.results) > 0 + else op + ) diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py index 4b0a31327..89cc8a198 100644 --- a/mlir/python/mlir/dialects/_scf_ops_ext.py +++ b/mlir/python/mlir/dialects/_scf_ops_ext.py @@ -7,7 +7,8 @@ except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e -from typing import Any, Optional, Sequence, Union +from typing import Optional, Sequence, Union + from ._ods_common import ( get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, @@ -25,7 +26,7 @@ def __init__( iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None, *, loc=None, - ip=None + ip=None, ): """Creates an SCF `for` 
operation. diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py index 302a49d56..49685ca22 100644 --- a/mlir/python/mlir/dialects/scf.py +++ b/mlir/python/mlir/dialects/scf.py @@ -2,4 +2,42 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from typing import Optional, Sequence + from ._scf_ops_gen import * +from .arith import constant +from ..ir import * + + +def for_( + start, + stop=None, + step=None, + iter_args: Optional[Sequence[Value]] = None, + *, + loc=None, + ip=None, +): + if step is None: + step = 1 + if stop is None: + stop = start + start = 0 + params = [start, stop, step] + for i, p in enumerate(params): + if isinstance(p, int): + p = constant(p) + elif isinstance(p, float): + raise ValueError(f"{p=} must be int.") + params[i] = p + + for_op = ForOp(start, stop, step, iter_args, loc=loc, ip=ip) + iv = for_op.induction_variable + iter_args = tuple(for_op.inner_iter_args) + with InsertionPoint(for_op.body): + if len(iter_args) > 1: + yield iv, iter_args + elif len(iter_args) == 1: + yield iv, iter_args[0] + else: + yield iv From b313f69db6bb37476522a9fb5b64b31ffb8d28f0 Mon Sep 17 00:00:00 2001 From: Amy Wang Date: Wed, 11 Oct 2023 16:37:11 -0400 Subject: [PATCH 598/915] [mlir][python] python binding for the affine.store op (#68816) This PR creates the necessary files to support bindings for operations in the affine dialect. This is the first of many PRs which will progressively introduce affine.load, affine.for, etc operations. I would like to acknowledge the work by Nelli's author @makslevental : https://github.com/makslevental/nelli/blob/main/nelli/mlir/affine/affine.py which jump-starts the work. 
--- mlir/python/CMakeLists.txt | 10 ++++ mlir/python/mlir/dialects/AffineOps.td | 14 +++++ mlir/python/mlir/dialects/_affine_ops_ext.py | 56 ++++++++++++++++++++ mlir/python/mlir/dialects/affine.py | 5 ++ 4 files changed, 85 insertions(+) create mode 100644 mlir/python/mlir/dialects/AffineOps.td create mode 100644 mlir/python/mlir/dialects/_affine_ops_ext.py create mode 100644 mlir/python/mlir/dialects/affine.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 088d9a765..c7b3c283a 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -46,6 +46,16 @@ declare_mlir_python_sources(MLIRPythonCAPI.HeaderSources # Dialect bindings ################################################################################ +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/AffineOps.td + SOURCES + dialects/affine.py + dialects/_affine_ops_ext.py + DIALECT_NAME affine + GEN_ENUM_BINDINGS) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/AffineOps.td b/mlir/python/mlir/dialects/AffineOps.td new file mode 100644 index 000000000..e12ffafb8 --- /dev/null +++ b/mlir/python/mlir/dialects/AffineOps.td @@ -0,0 +1,14 @@ +//===-- AffineOps.td - Entry point for Affine bindings -----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_AFFINE_OPS +#define PYTHON_BINDINGS_AFFINE_OPS + +include "mlir/Dialect/Affine/IR/AffineOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/_affine_ops_ext.py b/mlir/python/mlir/dialects/_affine_ops_ext.py new file mode 100644 index 000000000..dc465ce7a --- /dev/null +++ b/mlir/python/mlir/dialects/_affine_ops_ext.py @@ -0,0 +1,56 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * + from ._ods_common import get_op_result_or_value as _get_op_result_or_value + from ._ods_common import get_op_results_or_values as _get_op_results_or_values +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Sequence, Union + + +class AffineStoreOp: + """Specialization for the Affine store operation.""" + + def __init__( + self, + value: Union[Operation, OpView, Value], + memref: Union[Operation, OpView, Value], + map: AffineMap=None, + *, + map_operands=None, + loc=None, + ip=None + ): + """Creates an affine store operation. + + - `value`: the value to store into the memref. + - `memref`: the buffer to store into. + - `map`: the affine map that maps the map_operands to the index of the + memref. + - `map_operands`: the list of arguments to substitute the dimensions, + then symbols in the affine map, in increasing order. 
+ """ + map = map if map is not None else [] + map_operands = map_operands if map_operands is not None else [] + operands = [ + _get_op_result_or_value(value), + _get_op_result_or_value(memref), + *[_get_op_result_or_value(op) for op in map_operands] + ] + results = [] + attributes = {"map": AffineMapAttr.get(map)} + regions = None + _ods_successors = None + super().__init__(self.build_generic( + attributes=attributes, + results=results, + operands=operands, + successors=_ods_successors, + regions=regions, + loc=loc, + ip=ip + )) diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py new file mode 100644 index 000000000..8a2a64c7c --- /dev/null +++ b/mlir/python/mlir/dialects/affine.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._affine_ops_gen import * From c59f7c5b7305a1c308b996d2c865b7d92f1dea9f Mon Sep 17 00:00:00 2001 From: Yinying Li <107574043+yinying-lisa-li@users.noreply.github.com> Date: Tue, 17 Oct 2023 16:09:39 -0400 Subject: [PATCH 599/915] [mlir][sparse] Populate lvlToDim (#68937) Updates: 1. Infer lvlToDim from dimToLvl 2. Add more tests for block sparsity 3. Finish TODOs related to lvlToDim, including adding lvlToDim to python binding Verification of lvlToDim that user provides will be implemented in the next PR. 
--- mlir/include/mlir-c/Dialect/SparseTensor.h | 3 +-- .../lib/Bindings/Python/DialectSparseTensor.cpp | 17 +++++++++++++---- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 7 +++---- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index 7e47e54e7..859a4f0dd 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -51,11 +51,10 @@ MLIR_CAPI_EXPORTED bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr); /// Creates a `sparse_tensor.encoding` attribute with the given parameters. -/// TODO: add a version that supplied lvlToDim when it cannot be inferred MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet( MlirContext ctx, intptr_t lvlRank, enum MlirSparseTensorDimLevelType const *lvlTypes, MlirAffineMap dimToLvl, - int posWidth, int crdWidth); + MlirAffineMap lvlTodim, int posWidth, int crdWidth); /// Returns the level-rank of the `sparse_tensor.encoding` attribute. MLIR_CAPI_EXPORTED intptr_t diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 8e9e0b6ba..9bde3a443 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -41,16 +41,17 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { .def_classmethod( "get", [](py::object cls, std::vector lvlTypes, - std::optional dimToLvl, int posWidth, int crdWidth, + std::optional dimToLvl, + std::optional lvlToDim, int posWidth, int crdWidth, MlirContext context) { - // TODO: provide dimToLvl return cls(mlirSparseTensorEncodingAttrGet( context, lvlTypes.size(), lvlTypes.data(), - dimToLvl ? *dimToLvl : MlirAffineMap{nullptr}, posWidth, + dimToLvl ? *dimToLvl : MlirAffineMap{nullptr}, + lvlToDim ? 
*lvlToDim : MlirAffineMap{nullptr}, posWidth, crdWidth)); }, py::arg("cls"), py::arg("lvl_types"), py::arg("dim_to_lvl"), - py::arg("pos_width"), py::arg("crd_width"), + py::arg("lvl_to_dim"), py::arg("pos_width"), py::arg("crd_width"), py::arg("context") = py::none(), "Gets a sparse_tensor.encoding from parameters.") .def_property_readonly( @@ -71,6 +72,14 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { return {}; return ret; }) + .def_property_readonly( + "lvl_to_dim", + [](MlirAttribute self) -> std::optional { + MlirAffineMap ret = mlirSparseTensorEncodingAttrGetLvlToDim(self); + if (mlirAffineMapIsNull(ret)) + return {}; + return ret; + }) .def_property_readonly("pos_width", mlirSparseTensorEncodingAttrGetPosWidth) .def_property_readonly("crd_width", diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index bf3a4ad5e..c3ad95527 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -48,15 +48,14 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { MlirAttribute mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank, MlirSparseTensorDimLevelType const *lvlTypes, - MlirAffineMap dimToLvl, int posWidth, - int crdWidth) { + MlirAffineMap dimToLvl, MlirAffineMap lvlToDim, + int posWidth, int crdWidth) { SmallVector cppLvlTypes; cppLvlTypes.reserve(lvlRank); for (intptr_t l = 0; l < lvlRank; ++l) cppLvlTypes.push_back(static_cast(lvlTypes[l])); - mlir::AffineMap lvlToDim; // TODO: provide in API return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes, - unwrap(dimToLvl), lvlToDim, + unwrap(dimToLvl), unwrap(lvlToDim), posWidth, crdWidth)); } From f191c96d6219b11e723889f4bc938d23c6ead11c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Longeri?= Date: Wed, 18 Oct 2023 16:53:18 +0200 Subject: [PATCH 600/915] [mlir][python] Expose `PyInsertionPoint`'s reference operation (#69082) The reason I want this is that I 
am writing my own Python bindings and would like to use the insertion point from `PyThreadContextEntry::getDefaultInsertionPoint()` to call C++ functions that take an `OpBuilder` (I don't need to expose it in Python but it also seems appropriate). AFAICT, there is currently no way to translate a `PyInsertionPoint` into an `OpBuilder` because the operation is inaccessible. --- mlir/lib/Bindings/Python/IRCore.cpp | 13 ++++++++++++- mlir/lib/Bindings/Python/IRModule.h | 1 + mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 2 ++ 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index c8373e06f..389a4621c 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3207,7 +3207,18 @@ void mlir::python::populateIRCore(py::module &m) { "Inserts an operation.") .def_property_readonly( "block", [](PyInsertionPoint &self) { return self.getBlock(); }, - "Returns the block that this InsertionPoint points to."); + "Returns the block that this InsertionPoint points to.") + .def_property_readonly( + "ref_operation", + [](PyInsertionPoint &self) -> py::object { + auto ref_operation = self.getRefOperation(); + if (ref_operation) + return ref_operation->getObject(); + return py::none(); + }, + "The reference operation before which new operations are " + "inserted, or None if the insertion point is at the end of " + "the block"); //---------------------------------------------------------------------------- // Mapping of PyAttribute. 
diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 3ca7dd851..c5412e735 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -833,6 +833,7 @@ class PyInsertionPoint { const pybind11::object &excTb); PyBlock &getBlock() { return block; } + std::optional &getRefOperation() { return refOperation; } private: // Trampoline constructor that avoids null initializing members while diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index e8f4440d2..2609117dd 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -755,6 +755,8 @@ class InsertionPoint: def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... @property def block(self) -> Block: ... + @property + def ref_operation(self) -> Optional[_OperationBase]: ... # TODO: Auto-generated. Audit and fix. class IntegerAttr(Attribute): From c1e006220e89c8cb9d9aa2220bbb73b200109337 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Thu, 19 Oct 2023 16:20:14 -0500 Subject: [PATCH 601/915] [mlir][python] remove mixins (#68853) This PR replaces the mixin `OpView` extension mechanism with the standard inheritance mechanism. Why? Firstly, mixins are not very pythonic (inheritance is usually used for this), a little convoluted, and too "tight" (can only be used in the immediately adjacent `_ext.py`). Secondly, they (mixins) are now blocking a correct implementation of "value builders" (see [here](https://github.com/llvm/llvm-project/pull/68764)) where the problem becomes how to choose the correct base class that the value builder should call. This PR looks big/complicated but appearances are deceiving; 4 things were needed to make this work: 1. Drop `skipDefaultBuilders` in `OpPythonBindingGen::emitDefaultOpBuilders` 2. Former mixin extension classes are converted to inherit from the generated `OpView` instead of being "mixins" a.
extension classes that simply were calling into an already generated `super().__init__` continue to do so b. (almost all) extension classes that were calling `self.build_generic` because of a lack of default builder being generated can now also just call `super().__init__` 3. To handle the [lone single use-case](https://sourcegraph.com/search?q=context%3Aglobal+select_opview_mixin&patternType=standard&sm=1&groupBy=repo) of `select_opview_mixin`, namely [linalg](https://github.com/llvm/llvm-project/blob/main/mlir/python/mlir/dialects/_linalg_ops_ext.py#L38), only a small change was necessary in `opdsl/lang/emitter.py` (thanks to the emission/generation of default builders/`__init__`s) 4. since the `extend_opview_class` decorator is removed, we need a way to register extension classes as the desired `OpView` that `op.opview` conjures into existence; so we do the standard thing and just enable replacing the existing registered `OpView` i.e., `register_operation(_Dialect, replace=True)`. Note, the upgrade path for the common case is to change an extension to inherit from the generated builder and decorate it with `register_operation(_Dialect, replace=True)`. In the slightly more complicated case where `super().__init(self.build_generic(...))` is called in the extension's `__init__`, this needs to be updated to call `__init__` in `OpView`, i.e., the grandparent (see updated docs). Note, also `_ext.py` files/modules will no longer be automatically loaded. Note, the PR has 3 base commits that look funny but this was done for the purpose of tracking the line history of moving the `_ops_ext.py` class into `.py` and updating (commit labeled "fix"). 
--- mlir/lib/Bindings/Python/Globals.h | 4 +- mlir/lib/Bindings/Python/IRModule.cpp | 4 +- mlir/lib/Bindings/Python/MainModule.cpp | 11 +- mlir/python/CMakeLists.txt | 19 - mlir/python/mlir/dialects/_affine_ops_ext.py | 56 -- mlir/python/mlir/dialects/_arith_ops_ext.py | 69 -- .../mlir/dialects/_bufferization_ops_ext.py | 41 - .../_bufferization_transform_ops_ext.py | 128 --- mlir/python/mlir/dialects/_builtin_ops_ext.py | 20 - mlir/python/mlir/dialects/_func_ops_ext.py | 319 -------- .../mlir/dialects/_gpu_transform_ops_ext.py | 124 --- mlir/python/mlir/dialects/_linalg_ops_ext.py | 47 -- .../mlir/dialects/_loop_transform_ops_ext.py | 134 --- mlir/python/mlir/dialects/_memref_ops_ext.py | 36 - .../dialects/_memref_transform_ops_ext.py | 114 --- .../mlir/dialects/_ml_program_ops_ext.py | 113 --- mlir/python/mlir/dialects/_ods_common.py | 59 -- mlir/python/mlir/dialects/_pdl_ops_ext.py | 271 ------ mlir/python/mlir/dialects/_scf_ops_ext.py | 107 --- .../dialects/_structured_transform_ops_ext.py | 759 ----------------- mlir/python/mlir/dialects/_tensor_ops_ext.py | 44 - .../dialects/_tensor_transform_ops_ext.py | 64 -- .../mlir/dialects/_transform_ops_ext.py | 176 ---- .../_transform_pdl_extension_ops_ext.py | 55 -- mlir/python/mlir/dialects/affine.py | 51 +- mlir/python/mlir/dialects/arith.py | 71 ++ mlir/python/mlir/dialects/bufferization.py | 36 + mlir/python/mlir/dialects/builtin.py | 20 + mlir/python/mlir/dialects/func.py | 323 ++++++++ .../dialects/linalg/opdsl/lang/emitter.py | 2 +- .../linalg/opdsl/ops/core_named_ops.py | 107 +-- mlir/python/mlir/dialects/memref.py | 38 + mlir/python/mlir/dialects/ml_program.py | 114 +++ mlir/python/mlir/dialects/pdl.py | 285 +++++++ mlir/python/mlir/dialects/python_test.py | 7 +- mlir/python/mlir/dialects/scf.py | 115 ++- mlir/python/mlir/dialects/tensor.py | 37 + .../mlir/dialects/transform/__init__.py | 170 ++++ .../mlir/dialects/transform/bufferization.py | 129 +++ mlir/python/mlir/dialects/transform/gpu.py | 125 +++ 
mlir/python/mlir/dialects/transform/loop.py | 140 ++++ mlir/python/mlir/dialects/transform/memref.py | 115 +++ mlir/python/mlir/dialects/transform/pdl.py | 50 ++ .../mlir/dialects/transform/structured.py | 773 ++++++++++++++++++ mlir/python/mlir/dialects/transform/tensor.py | 64 ++ mlir/python/mlir/runtime/np_to_memref.py | 8 +- 46 files changed, 2731 insertions(+), 2823 deletions(-) delete mode 100644 mlir/python/mlir/dialects/_affine_ops_ext.py delete mode 100644 mlir/python/mlir/dialects/_arith_ops_ext.py delete mode 100644 mlir/python/mlir/dialects/_bufferization_ops_ext.py delete mode 100644 mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py delete mode 100644 mlir/python/mlir/dialects/_builtin_ops_ext.py delete mode 100644 mlir/python/mlir/dialects/_func_ops_ext.py delete mode 100644 mlir/python/mlir/dialects/_gpu_transform_ops_ext.py delete mode 100644 mlir/python/mlir/dialects/_linalg_ops_ext.py delete mode 100644 mlir/python/mlir/dialects/_loop_transform_ops_ext.py delete mode 100644 mlir/python/mlir/dialects/_memref_ops_ext.py delete mode 100644 mlir/python/mlir/dialects/_memref_transform_ops_ext.py delete mode 100644 mlir/python/mlir/dialects/_ml_program_ops_ext.py delete mode 100644 mlir/python/mlir/dialects/_pdl_ops_ext.py delete mode 100644 mlir/python/mlir/dialects/_scf_ops_ext.py delete mode 100644 mlir/python/mlir/dialects/_structured_transform_ops_ext.py delete mode 100644 mlir/python/mlir/dialects/_tensor_ops_ext.py delete mode 100644 mlir/python/mlir/dialects/_tensor_transform_ops_ext.py delete mode 100644 mlir/python/mlir/dialects/_transform_ops_ext.py delete mode 100644 mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 97cd70089..21899bdce 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -77,10 +77,10 @@ class PyGlobals { pybind11::object pyClass); /// Adds a concrete implementation 
operation class. - /// Raises an exception if the mapping already exists. + /// Raises an exception if the mapping already exists and replace == false. /// This is intended to be called by implementation code. void registerOperationImpl(const std::string &operationName, - pybind11::object pyClass); + pybind11::object pyClass, bool replace = false); /// Returns the custom Attribute builder for Attribute kind. std::optional diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index 2cc66277a..a1c8ab7a0 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -96,9 +96,9 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, } void PyGlobals::registerOperationImpl(const std::string &operationName, - py::object pyClass) { + py::object pyClass, bool replace) { py::object &found = operationClassMap[operationName]; - if (found) { + if (found && !replace) { throw std::runtime_error((llvm::Twine("Operation '") + operationName + "' is already registered.") .str()); diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index cdddfbe50..a936becf6 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -41,7 +41,7 @@ PYBIND11_MODULE(_mlir, m) { "dialect_namespace"_a, "dialect_class"_a, "Testing hook for directly registering a dialect") .def("_register_operation_impl", &PyGlobals::registerOperationImpl, - "operation_name"_a, "operation_class"_a, + "operation_name"_a, "operation_class"_a, "replace"_a = false, "Testing hook for directly registering an operation"); // Aside from making the globals accessible to python, having python manage @@ -63,12 +63,13 @@ PYBIND11_MODULE(_mlir, m) { "Class decorator for registering a custom Dialect wrapper"); m.def( "register_operation", - [](const py::object &dialectClass) -> py::cpp_function { + [](const py::object &dialectClass, bool replace) -> 
py::cpp_function { return py::cpp_function( - [dialectClass](py::object opClass) -> py::object { + [dialectClass, replace](py::object opClass) -> py::object { std::string operationName = opClass.attr("OPERATION_NAME").cast(); - PyGlobals::get().registerOperationImpl(operationName, opClass); + PyGlobals::get().registerOperationImpl(operationName, opClass, + replace); // Dict-stuff the new opClass by name onto the dialect class. py::object opClassName = opClass.attr("__name__"); @@ -76,7 +77,7 @@ PYBIND11_MODULE(_mlir, m) { return opClass; }); }, - "dialect_class"_a, + "dialect_class"_a, "replace"_a = false, "Produce a class decorator for registering an Operation class as part of " "a dialect"); m.def( diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index c7b3c283a..88e6e1360 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -52,7 +52,6 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/AffineOps.td SOURCES dialects/affine.py - dialects/_affine_ops_ext.py DIALECT_NAME affine GEN_ENUM_BINDINGS) @@ -78,7 +77,6 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/BufferizationOps.td SOURCES dialects/bufferization.py - dialects/_bufferization_ops_ext.py DIALECT_NAME bufferization GEN_ENUM_BINDINGS_TD_FILE "../../include/mlir/Dialect/Bufferization/IR/BufferizationEnums.td" @@ -90,7 +88,6 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/BuiltinOps.td SOURCES dialects/builtin.py - dialects/_builtin_ops_ext.py DIALECT_NAME builtin) declare_mlir_dialect_python_bindings( @@ -115,7 +112,6 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/FuncOps.td SOURCES dialects/func.py - dialects/_func_ops_ext.py DIALECT_NAME func) declare_mlir_dialect_python_bindings( @@ -131,7 +127,6 @@ declare_mlir_dialect_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/LinalgOps.td SOURCES - dialects/_linalg_ops_ext.py SOURCES_GLOB dialects/linalg/*.py DIALECT_NAME linalg @@ -152,7 +147,6 @@ 
ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/TransformPDLExtensionOps.td SOURCES - dialects/_transform_pdl_extension_ops_ext.py dialects/transform/pdl.py DIALECT_NAME transform EXTENSION_NAME transform_pdl_extension) @@ -162,7 +156,6 @@ declare_mlir_dialect_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/TransformOps.td SOURCES - dialects/_transform_ops_ext.py dialects/transform/__init__.py _mlir_libs/_mlir/dialects/transform/__init__.pyi DIALECT_NAME transform @@ -175,7 +168,6 @@ declare_mlir_dialect_extension_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/BufferizationTransformOps.td SOURCES - dialects/_bufferization_transform_ops_ext.py dialects/transform/bufferization.py DIALECT_NAME transform EXTENSION_NAME bufferization_transform) @@ -185,7 +177,6 @@ declare_mlir_dialect_extension_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/GPUTransformOps.td SOURCES - dialects/_gpu_transform_ops_ext.py dialects/transform/gpu.py DIALECT_NAME transform EXTENSION_NAME gpu_transform) @@ -195,7 +186,6 @@ declare_mlir_dialect_extension_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/SCFLoopTransformOps.td SOURCES - dialects/_loop_transform_ops_ext.py dialects/transform/loop.py DIALECT_NAME transform EXTENSION_NAME loop_transform) @@ -205,7 +195,6 @@ declare_mlir_dialect_extension_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/MemRefTransformOps.td SOURCES - dialects/_memref_transform_ops_ext.py dialects/transform/memref.py DIALECT_NAME transform EXTENSION_NAME memref_transform) @@ -224,7 +213,6 @@ declare_mlir_dialect_extension_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/LinalgStructuredTransformOps.td SOURCES - dialects/_structured_transform_ops_ext.py dialects/transform/structured.py DIALECT_NAME transform EXTENSION_NAME 
structured_transform @@ -246,7 +234,6 @@ declare_mlir_dialect_extension_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/TensorTransformOps.td SOURCES - dialects/_tensor_transform_ops_ext.py dialects/transform/tensor.py DIALECT_NAME transform EXTENSION_NAME tensor_transform) @@ -276,7 +263,6 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/ArithOps.td SOURCES dialects/arith.py - dialects/_arith_ops_ext.py DIALECT_NAME arith GEN_ENUM_BINDINGS) @@ -286,7 +272,6 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/MemRefOps.td SOURCES dialects/memref.py - dialects/_memref_ops_ext.py DIALECT_NAME memref) declare_mlir_dialect_python_bindings( @@ -295,7 +280,6 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/MLProgramOps.td SOURCES dialects/ml_program.py - dialects/_ml_program_ops_ext.py DIALECT_NAME ml_program) declare_mlir_dialect_python_bindings( @@ -339,7 +323,6 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/PDLOps.td SOURCES dialects/pdl.py - dialects/_pdl_ops_ext.py _mlir_libs/_mlir/dialects/pdl.pyi DIALECT_NAME pdl) @@ -357,7 +340,6 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/SCFOps.td SOURCES dialects/scf.py - dialects/_scf_ops_ext.py DIALECT_NAME scf) declare_mlir_dialect_python_bindings( @@ -383,7 +365,6 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/TensorOps.td SOURCES dialects/tensor.py - dialects/_tensor_ops_ext.py DIALECT_NAME tensor) declare_mlir_dialect_python_bindings( diff --git a/mlir/python/mlir/dialects/_affine_ops_ext.py b/mlir/python/mlir/dialects/_affine_ops_ext.py deleted file mode 100644 index dc465ce7a..000000000 --- a/mlir/python/mlir/dialects/_affine_ops_ext.py +++ /dev/null @@ -1,56 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ._ods_common import get_op_result_or_value as _get_op_result_or_value - from ._ods_common import get_op_results_or_values as _get_op_results_or_values -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Optional, Sequence, Union - - -class AffineStoreOp: - """Specialization for the Affine store operation.""" - - def __init__( - self, - value: Union[Operation, OpView, Value], - memref: Union[Operation, OpView, Value], - map: AffineMap=None, - *, - map_operands=None, - loc=None, - ip=None - ): - """Creates an affine store operation. - - - `value`: the value to store into the memref. - - `memref`: the buffer to store into. - - `map`: the affine map that maps the map_operands to the index of the - memref. - - `map_operands`: the list of arguments to substitute the dimensions, - then symbols in the affine map, in increasing order. - """ - map = map if map is not None else [] - map_operands = map_operands if map_operands is not None else [] - operands = [ - _get_op_result_or_value(value), - _get_op_result_or_value(memref), - *[_get_op_result_or_value(op) for op in map_operands] - ] - results = [] - attributes = {"map": AffineMapAttr.get(map)} - regions = None - _ods_successors = None - super().__init__(self.build_generic( - attributes=attributes, - results=results, - operands=operands, - successors=_ods_successors, - regions=regions, - loc=loc, - ip=ip - )) diff --git a/mlir/python/mlir/dialects/_arith_ops_ext.py b/mlir/python/mlir/dialects/_arith_ops_ext.py deleted file mode 100644 index df38f8717..000000000 --- a/mlir/python/mlir/dialects/_arith_ops_ext.py +++ /dev/null @@ -1,69 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ._ods_common import get_default_loc_context as _get_default_loc_context - - from typing import Any, List, Union -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - - -def _isa(obj: Any, cls: type): - try: - cls(obj) - except ValueError: - return False - return True - - -def _is_any_of(obj: Any, classes: List[type]): - return any(_isa(obj, cls) for cls in classes) - - -def _is_integer_like_type(type: Type): - return _is_any_of(type, [IntegerType, IndexType]) - - -def _is_float_type(type: Type): - return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type]) - - -class ConstantOp: - """Specialization for the constant op class.""" - - def __init__( - self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None - ): - if isinstance(value, int): - super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip) - elif isinstance(value, float): - super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip) - else: - super().__init__(value, loc=loc, ip=ip) - - @classmethod - def create_index(cls, value: int, *, loc=None, ip=None): - """Create an index-typed constant.""" - return cls( - IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip - ) - - @property - def type(self): - return self.results[0].type - - @property - def value(self): - return Attribute(self.operation.attributes["value"]) - - @property - def literal_value(self) -> Union[int, float]: - if _is_integer_like_type(self.type): - return IntegerAttr(self.value).value - elif _is_float_type(self.type): - return FloatAttr(self.value).value - else: - raise ValueError("only integer and float constants have literal values") diff --git a/mlir/python/mlir/dialects/_bufferization_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_ops_ext.py deleted file mode 100644 index 1066cb4c7..000000000 --- 
a/mlir/python/mlir/dialects/_bufferization_ops_ext.py +++ /dev/null @@ -1,41 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from typing import Sequence, Union - from ..ir import * - from ._ods_common import get_default_loc_context - - from typing import Any, List, Union -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - - -class AllocTensorOp: - """Extends the bufferization.alloc_tensor op.""" - - def __init__( - self, - tensor_type: Type, - dynamic_sizes: Sequence[Value], - copy: Value, - size_hint: Value, - escape: BoolAttr, - *, - loc=None, - ip=None - ): - """Constructs an `alloc_tensor` with static and/or dynamic sizes.""" - context = get_default_loc_context(loc) - attributes = {} - if escape: - attributes["escape"] = escape - op = self.build_generic( - results=[tensor_type], - operands=[dynamic_sizes, copy, size_hint], - attributes=attributes, - loc=loc, - ip=ip, - ) - OpView.__init__(self, op) diff --git a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py b/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py deleted file mode 100644 index 7e6c1b81c..000000000 --- a/mlir/python/mlir/dialects/_bufferization_transform_ops_ext.py +++ /dev/null @@ -1,128 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ..dialects import transform -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from enum import Enum -from typing import Optional, overload, Union - - -class EmptyTensorToAllocTensorOp: - """Specialization for EmptyTensorToAllocTensorOp class.""" - - @overload - def __init__( - self, - transformed_type: Type, - target: Union[Operation, OpView, Value], - *, - loc=None, - ip=None - ): - ... - - @overload - def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None): - ... - - def __init__( - self, - transformed_type_or_target: Type, - target_or_none: Optional[Union[Operation, OpView, Value]] = None, - *, - loc=None, - ip=None - ): - if isinstance(transformed_type_or_target, Type): - transformed_type = transformed_type_or_target - target = target_or_none - else: - transformed_type = transform.OperationType.get("bufferization.alloc_tensor") - target = transformed_type_or_target - - super().__init__( - transformed_type, - target, - loc=loc, - ip=ip, - ) - - -class OneShotBufferizeOp: - """Specialization for OneShotBufferizeOp class.""" - - @overload - def __init__( - self, - transformed_type: Type, - target: Union[Operation, OpView, Value], - *, - allow_return_allocs_from_loops: Optional[bool] = None, - allow_unknown_ops: Optional[bool] = None, - bufferize_function_boundaries: Optional[bool] = None, - function_boundary_type_conversion: Optional[Enum] = None, - memcpy_op: Optional[str] = None, - print_conflicts: Optional[bool] = None, - test_analysis_only: Optional[bool] = None, - loc=None, - ip=None - ): - ... 
- - @overload - def __init__( - self, - target: Union[Operation, OpView, Value], - *, - allow_return_allocs_from_loops: Optional[bool] = None, - allow_unknown_ops: Optional[bool] = None, - bufferize_function_boundaries: Optional[bool] = None, - function_boundary_type_conversion: Optional[Enum] = None, - memcpy_op: Optional[str] = None, - print_conflicts: Optional[bool] = None, - test_analysis_only: Optional[bool] = None, - loc=None, - ip=None - ): - ... - - def __init__( - self, - transformed_type_or_target: Type, - target_or_none: Optional[Union[Operation, OpView, Value]] = None, - *, - allow_return_allocs_from_loops: Optional[bool] = None, - allow_unknown_ops: Optional[bool] = None, - bufferize_function_boundaries: Optional[bool] = None, - function_boundary_type_conversion: Optional[Enum] = None, - memcpy_op: Optional[str] = None, - print_conflicts: Optional[bool] = None, - test_analysis_only: Optional[bool] = None, - loc=None, - ip=None - ): - if isinstance(transformed_type_or_target, Type): - transformed_type = transformed_type_or_target - target = target_or_none - else: - transformed_type = transform.AnyOpType.get() - target = transformed_type_or_target - - super().__init__( - transformed_type, - target, - allow_return_allocs_from_loops=allow_return_allocs_from_loops, - allow_unknown_ops=allow_unknown_ops, - bufferize_function_boundaries=bufferize_function_boundaries, - function_boundary_type_conversion=function_boundary_type_conversion, - memcpy_op=memcpy_op, - print_conflicts=print_conflicts, - test_analysis_only=test_analysis_only, - loc=loc, - ip=ip, - ) diff --git a/mlir/python/mlir/dialects/_builtin_ops_ext.py b/mlir/python/mlir/dialects/_builtin_ops_ext.py deleted file mode 100644 index 27a601230..000000000 --- a/mlir/python/mlir/dialects/_builtin_ops_ext.py +++ /dev/null @@ -1,20 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - - -class ModuleOp: - """Specialization for the module op class.""" - - def __init__(self, *, loc=None, ip=None): - super().__init__(self.build_generic(results=[], operands=[], loc=loc, ip=ip)) - body = self.regions[0].blocks.append() - - @property - def body(self): - return self.regions[0].blocks[0] diff --git a/mlir/python/mlir/dialects/_func_ops_ext.py b/mlir/python/mlir/dialects/_func_ops_ext.py deleted file mode 100644 index 6d264c33f..000000000 --- a/mlir/python/mlir/dialects/_func_ops_ext.py +++ /dev/null @@ -1,319 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ._ods_common import get_default_loc_context as _get_default_loc_context - - import inspect - - from typing import Any, List, Optional, Sequence, Union -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -ARGUMENT_ATTRIBUTE_NAME = "arg_attrs" -RESULT_ATTRIBUTE_NAME = "res_attrs" - - -class ConstantOp: - """Specialization for the constant op class.""" - - def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None): - super().__init__(result, value, loc=loc, ip=ip) - - @property - def type(self): - return self.results[0].type - - -class FuncOp: - """Specialization for the func op class.""" - - def __init__( - self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None - ): - """ - Create a FuncOp with the provided `name`, `type`, and `visibility`. - - `name` is a string representing the function name. - - `type` is either a FunctionType or a pair of list describing inputs and - results. 
- - `visibility` is a string matching `public`, `private`, or `nested`. None - implies private visibility. - - `body_builder` is an optional callback, when provided a new entry block - is created and the callback is invoked with the new op as argument within - an InsertionPoint context already set for the block. The callback is - expected to insert a terminator in the block. - """ - sym_name = StringAttr.get(str(name)) - - # If the type is passed as a tuple, build a FunctionType on the fly. - if isinstance(type, tuple): - type = FunctionType.get(inputs=type[0], results=type[1]) - - type = TypeAttr.get(type) - sym_visibility = ( - StringAttr.get(str(visibility)) if visibility is not None else None - ) - super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) - if body_builder: - entry_block = self.add_entry_block() - with InsertionPoint(entry_block): - body_builder(self) - - @property - def is_external(self): - return len(self.regions[0].blocks) == 0 - - @property - def body(self): - return self.regions[0] - - @property - def type(self): - return FunctionType(TypeAttr(self.attributes["function_type"]).value) - - @property - def visibility(self): - return self.attributes["sym_visibility"] - - @property - def name(self) -> StringAttr: - return StringAttr(self.attributes["sym_name"]) - - @property - def entry_block(self): - if self.is_external: - raise IndexError("External function does not have a body") - return self.regions[0].blocks[0] - - def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None): - """ - Add an entry block to the function body using the function signature to - infer block arguments. 
- Returns the newly created block - """ - if not self.is_external: - raise IndexError("The function already has an entry block!") - self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs) - return self.body.blocks[0] - - @property - def arg_attrs(self): - return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) - - @arg_attrs.setter - def arg_attrs(self, attribute: Union[ArrayAttr, list]): - if isinstance(attribute, ArrayAttr): - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute - else: - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( - attribute, context=self.context - ) - - @property - def arguments(self): - return self.entry_block.arguments - - @property - def result_attrs(self): - return self.attributes[RESULT_ATTRIBUTE_NAME] - - @result_attrs.setter - def result_attrs(self, attribute: ArrayAttr): - self.attributes[RESULT_ATTRIBUTE_NAME] = attribute - - @classmethod - def from_py_func( - FuncOp, - *inputs: Type, - results: Optional[Sequence[Type]] = None, - name: Optional[str] = None, - ): - """Decorator to define an MLIR FuncOp specified as a python function. - - Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are - active for the current thread (i.e. established in a `with` block). - - When applied as a decorator to a Python function, an entry block will - be constructed for the FuncOp with types as specified in `*inputs`. The - block arguments will be passed positionally to the Python function. In - addition, if the Python function accepts keyword arguments generally or - has a corresponding keyword argument, the following will be passed: - * `func_op`: The `func` op being defined. - - By default, the function name will be the Python function `__name__`. This - can be overriden by passing the `name` argument to the decorator. - - If `results` is not specified, then the decorator will implicitly - insert a `ReturnOp` with the `Value`'s returned from the decorated - function. 
It will also set the `FuncOp` type with the actual return - value types. If `results` is specified, then the decorated function - must return `None` and no implicit `ReturnOp` is added (nor are the result - types updated). The implicit behavior is intended for simple, single-block - cases, and users should specify result types explicitly for any complicated - cases. - - The decorated function can further be called from Python and will insert - a `CallOp` at the then-current insertion point, returning either None ( - if no return values), a unary Value (for one result), or a list of Values). - This mechanism cannot be used to emit recursive calls (by construction). - """ - - def decorator(f): - from . import func - - # Introspect the callable for optional features. - sig = inspect.signature(f) - has_arg_func_op = False - for param in sig.parameters.values(): - if param.kind == param.VAR_KEYWORD: - has_arg_func_op = True - if param.name == "func_op" and ( - param.kind == param.POSITIONAL_OR_KEYWORD - or param.kind == param.KEYWORD_ONLY - ): - has_arg_func_op = True - - # Emit the FuncOp. - implicit_return = results is None - symbol_name = name or f.__name__ - function_type = FunctionType.get( - inputs=inputs, results=[] if implicit_return else results - ) - func_op = FuncOp(name=symbol_name, type=function_type) - with InsertionPoint(func_op.add_entry_block()): - func_args = func_op.entry_block.arguments - func_kwargs = {} - if has_arg_func_op: - func_kwargs["func_op"] = func_op - return_values = f(*func_args, **func_kwargs) - if not implicit_return: - return_types = list(results) - assert return_values is None, ( - "Capturing a python function with explicit `results=` " - "requires that the wrapped function returns None." - ) - else: - # Coerce return values, add ReturnOp and rewrite func type. 
- if return_values is None: - return_values = [] - elif isinstance(return_values, tuple): - return_values = list(return_values) - elif isinstance(return_values, Value): - # Returning a single value is fine, coerce it into a list. - return_values = [return_values] - elif isinstance(return_values, OpView): - # Returning a single operation is fine, coerce its results a list. - return_values = return_values.operation.results - elif isinstance(return_values, Operation): - # Returning a single operation is fine, coerce its results a list. - return_values = return_values.results - else: - return_values = list(return_values) - func.ReturnOp(return_values) - # Recompute the function type. - return_types = [v.type for v in return_values] - function_type = FunctionType.get( - inputs=inputs, results=return_types - ) - func_op.attributes["function_type"] = TypeAttr.get(function_type) - - def emit_call_op(*call_args): - call_op = func.CallOp( - return_types, FlatSymbolRefAttr.get(symbol_name), call_args - ) - if return_types is None: - return None - elif len(return_types) == 1: - return call_op.result - else: - return call_op.results - - wrapped = emit_call_op - wrapped.__name__ = f.__name__ - wrapped.func_op = func_op - return wrapped - - return decorator - - -class CallOp: - """Specialization for the call op class.""" - - def __init__( - self, - calleeOrResults: Union[FuncOp, List[Type]], - argumentsOrCallee: Union[List, FlatSymbolRefAttr, str], - arguments: Optional[List] = None, - *, - loc=None, - ip=None, - ): - """Creates an call operation. - - The constructor accepts three different forms: - - 1. A function op to be called followed by a list of arguments. - 2. A list of result types, followed by the name of the function to be - called as string, following by a list of arguments. - 3. A list of result types, followed by the name of the function to be - called as symbol reference attribute, followed by a list of arguments. - - For example - - f = func.FuncOp("foo", ...) 
- func.CallOp(f, [args]) - func.CallOp([result_types], "foo", [args]) - - In all cases, the location and insertion point may be specified as keyword - arguments if not provided by the surrounding context managers. - """ - - # TODO: consider supporting constructor "overloads", e.g., through a custom - # or pybind-provided metaclass. - if isinstance(calleeOrResults, FuncOp): - if not isinstance(argumentsOrCallee, list): - raise ValueError( - "when constructing a call to a function, expected " - + "the second argument to be a list of call arguments, " - + f"got {type(argumentsOrCallee)}" - ) - if arguments is not None: - raise ValueError( - "unexpected third argument when constructing a call" - + "to a function" - ) - - super().__init__( - calleeOrResults.type.results, - FlatSymbolRefAttr.get( - calleeOrResults.name.value, context=_get_default_loc_context(loc) - ), - argumentsOrCallee, - loc=loc, - ip=ip, - ) - return - - if isinstance(argumentsOrCallee, list): - raise ValueError( - "when constructing a call to a function by name, " - + "expected the second argument to be a string or a " - + f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}" - ) - - if isinstance(argumentsOrCallee, FlatSymbolRefAttr): - super().__init__( - calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip - ) - elif isinstance(argumentsOrCallee, str): - super().__init__( - calleeOrResults, - FlatSymbolRefAttr.get( - argumentsOrCallee, context=_get_default_loc_context(loc) - ), - arguments, - loc=loc, - ip=ip, - ) diff --git a/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py b/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py deleted file mode 100644 index ba72bac3a..000000000 --- a/mlir/python/mlir/dialects/_gpu_transform_ops_ext.py +++ /dev/null @@ -1,124 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ..dialects import transform -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Optional, Sequence, Union, overload - - -class MapForallToBlocks: - """Specialization for MapForallToBlocks class.""" - - @overload - def __init__( - self, - result_type: Type, - target: Union[Operation, OpView, Value], - *, - grid_dims: Optional[Union[Sequence[int], Attribute]] = None, - generate_gpu_launch: Optional[Union[bool, Attribute]] = None, - loc=None, - ip=None - ): - ... - - @overload - def __init__( - self, - target: Union[Operation, OpView, Value], - *, - grid_dims: Optional[Union[Sequence[int], Attribute]] = None, - generate_gpu_launch: Optional[Union[bool, Attribute]] = None, - loc=None, - ip=None - ): - ... - - def __init__( - self, - result_type_or_target: Union[Operation, OpView, Type, Value], - target_or_none: Optional[Union[Operation, OpView, Value]] = None, - *, - grid_dims: Optional[Union[Sequence[int], Attribute]] = None, - generate_gpu_launch: Optional[Union[bool, Attribute]] = None, - loc=None, - ip=None - ): - if isinstance(result_type_or_target, Type): - result_type = result_type_or_target - target = target_or_none - else: - result_type = transform.AnyOpType.get() - target = result_type_or_target - - super().__init__( - result_type, - target, - grid_dims=grid_dims, - generate_gpu_launch=generate_gpu_launch, - loc=loc, - ip=ip, - ) - - -class MapNestedForallToThreads: - """Specialization for MapNestedForallToThreads class.""" - - @overload - def __init__( - self, - result_type: Type, - target: Union[Operation, OpView, Value], - *, - block_dims: Optional[Sequence[int]] = None, - warp_size: Optional[Sequence[int]] = None, - sync_after_distribute: Optional[bool] = None, - loc=None, - ip=None - ): - ... 
- - @overload - def __init__( - self, - target: Union[Operation, OpView, Value], - *, - block_dims: Optional[Sequence[int]] = None, - warp_size: Optional[Sequence[int]] = None, - sync_after_distribute: Optional[bool] = None, - loc=None, - ip=None - ): - ... - - def __init__( - self, - result_type_or_target: Union[Operation, OpView, Value, Type], - target_or_none: Optional[Union[Operation, OpView, Value]] = None, - *, - block_dims: Optional[Union[Sequence[int], Attribute]] = None, - warp_size: Optional[Union[Sequence[int], Attribute]] = None, - sync_after_distribute: Optional[bool] = None, - loc=None, - ip=None - ): - if isinstance(result_type_or_target, Type): - result_type = result_type_or_target - target = target_or_none - else: - result_type = result_type_or_target.type - target = result_type_or_target - super().__init__( - result_type, - target, - block_dims=block_dims, - warp_size=warp_size, - sync_after_distribute=sync_after_distribute, - loc=loc, - ip=ip, - ) diff --git a/mlir/python/mlir/dialects/_linalg_ops_ext.py b/mlir/python/mlir/dialects/_linalg_ops_ext.py deleted file mode 100644 index 3f6d854ca..000000000 --- a/mlir/python/mlir/dialects/_linalg_ops_ext.py +++ /dev/null @@ -1,47 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from typing import Optional, Sequence, Union - from ..ir import * - from ._ods_common import get_default_loc_context - from .._mlir_libs._mlirDialectsLinalg import fill_builtin_region -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from ._ods_common import get_op_result_or_value as _get_op_result_or_value - - -def isa(cls: Type, ty: Type): - try: - cls(ty) - return True - except ValueError: - return False - - -class StructuredOpMixin: - """All structured ops use the same mixin class.""" - - def __init__(self, inputs, outputs=(), results=(), loc=None, ip=None): - super().__init__( - self.build_generic( - results=list(results), - operands=[list(inputs), list(outputs)], - loc=loc, - ip=ip, - ) - ) - - -def select_opview_mixin(parent_opview_cls): - # TODO: This shouldn't be a heuristic: we should have a way to annotate - # the OpView to note that it is a structured op. - if ( - "__init__" not in parent_opview_cls.__dict__ - and hasattr(parent_opview_cls, "inputs") - and hasattr(parent_opview_cls, "outputs") - and hasattr(parent_opview_cls, "result_tensors") - ): - return StructuredOpMixin diff --git a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py b/mlir/python/mlir/dialects/_loop_transform_ops_ext.py deleted file mode 100644 index 1cdb2b9e7..000000000 --- a/mlir/python/mlir/dialects/_loop_transform_ops_ext.py +++ /dev/null @@ -1,134 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ._ods_common import get_op_result_or_value as _get_op_result_or_value -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Optional, Union - - -class GetParentForOp: - """Extension for GetParentForOp.""" - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - num_loops: Optional[int] = None, - ip=None, - loc=None, - ): - if num_loops is None: - num_loops = 1 - super().__init__( - result_type, - _get_op_result_or_value(target), - num_loops=num_loops, - ip=ip, - loc=loc, - ) - - -class LoopOutlineOp: - """Extension for LoopOutlineOp.""" - - def __init__( - self, - function_type: Type, - call_type: Type, - target: Union[Operation, Value], - *, - func_name: Union[str, StringAttr], - ip=None, - loc=None, - ): - super().__init__( - function_type, - call_type, - _get_op_result_or_value(target), - func_name=( - func_name - if isinstance(func_name, StringAttr) - else StringAttr.get(func_name) - ), - ip=ip, - loc=loc, - ) - - -class LoopPeelOp: - """Extension for LoopPeelOp.""" - - def __init__( - self, - main_loop_type: Type, - remainder_loop_type: Type, - target: Union[Operation, Value], - *, - fail_if_already_divisible: Union[bool, BoolAttr] = False, - ip=None, - loc=None, - ): - super().__init__( - main_loop_type, - remainder_loop_type, - _get_op_result_or_value(target), - fail_if_already_divisible=( - fail_if_already_divisible - if isinstance(fail_if_already_divisible, BoolAttr) - else BoolAttr.get(fail_if_already_divisible) - ), - ip=ip, - loc=loc, - ) - - -class LoopPipelineOp: - """Extension for LoopPipelineOp.""" - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - iteration_interval: Optional[Union[int, IntegerAttr]] = None, - read_latency: Optional[Union[int, IntegerAttr]] = None, - ip=None, - loc=None, - ): - if iteration_interval is 
None: - iteration_interval = 1 - if read_latency is None: - read_latency = 10 - super().__init__( - result_type, - _get_op_result_or_value(target), - iteration_interval=iteration_interval, - read_latency=read_latency, - ip=ip, - loc=loc, - ) - - -class LoopUnrollOp: - """Extension for LoopUnrollOp.""" - - def __init__( - self, - target: Union[Operation, Value], - *, - factor: Union[int, IntegerAttr], - ip=None, - loc=None, - ): - super().__init__( - _get_op_result_or_value(target), - factor=factor, - ip=ip, - loc=loc, - ) diff --git a/mlir/python/mlir/dialects/_memref_ops_ext.py b/mlir/python/mlir/dialects/_memref_ops_ext.py deleted file mode 100644 index 825f1a0a7..000000000 --- a/mlir/python/mlir/dialects/_memref_ops_ext.py +++ /dev/null @@ -1,36 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ._ods_common import get_op_result_or_value as _get_op_result_or_value - from ._ods_common import get_op_results_or_values as _get_op_results_or_values -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Optional, Sequence, Union - - -class LoadOp: - """Specialization for the MemRef load operation.""" - - def __init__( - self, - memref: Union[Operation, OpView, Value], - indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None, - *, - loc=None, - ip=None - ): - """Creates a memref load operation. - - Args: - memref: the buffer to load from. - indices: the list of subscripts, may be empty for zero-dimensional - buffers. - loc: user-visible location of the operation. - ip: insertion point. 
- """ - indices_resolved = [] if indices is None else _get_op_results_or_values(indices) - super().__init__(memref, indices_resolved, loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py b/mlir/python/mlir/dialects/_memref_transform_ops_ext.py deleted file mode 100644 index 1cc00bdcb..000000000 --- a/mlir/python/mlir/dialects/_memref_transform_ops_ext.py +++ /dev/null @@ -1,114 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ..dialects import transform -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Optional, overload, Union - - -class MemRefAllocaToGlobalOp: - """Specialization for MemRefAllocaToGlobalOp class.""" - - @overload - def __init__( - self, - get_global_type: Type, - global_type: Type, - alloca: Union[Operation, OpView, Value], - *, - loc=None, - ip=None - ): - ... - - @overload - def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None): - ... 
- - def __init__( - self, - get_global_type_or_alloca: Union[Operation, OpView, Type, Value], - global_type_or_none: Optional[Type] = None, - alloca_or_none: Optional[Union[Operation, OpView, Value]] = None, - *, - loc=None, - ip=None - ): - if isinstance(get_global_type_or_alloca, Type): - get_global_type = get_global_type_or_alloca - global_type = global_type_or_none - alloca = alloca_or_none - else: - get_global_type = transform.AnyOpType.get() - global_type = transform.AnyOpType.get() - alloca = get_global_type_or_alloca - - super().__init__( - get_global_type, - global_type, - alloca, - loc=loc, - ip=ip, - ) - - -class MemRefMultiBufferOp: - """Specialization for MemRefMultiBufferOp class.""" - - @overload - def __init__( - self, - transformed_type: Type, - target: Union[Operation, OpView, Value], - factor: Union[int, IntegerAttr], - *, - skip_analysis: Optional[bool] = None, - loc=None, - ip=None - ): - ... - - @overload - def __init__( - self, - target: Union[Operation, OpView, Value], - factor: Union[int, IntegerAttr], - *, - skip_analysis: Optional[bool] = None, - loc=None, - ip=None - ): - ... 
- - def __init__( - self, - transformed_type_or_target: Type, - target_or_factor: Union[int, IntegerAttr, Operation, OpView, Value] = None, - factor_or_none: Optional[Union[int, IntegerAttr]] = None, - *, - skip_analysis: Optional[bool] = None, - loc=None, - ip=None - ): - if isinstance(transformed_type_or_target, Type): - transformed_type = transformed_type_or_target - target = target_or_factor - factor = factor_or_none - else: - transformed_type = transform.AnyOpType.get() - target = transformed_type_or_target - factor = target_or_factor - - super().__init__( - transformed_type, - target, - factor, - skip_analysis=skip_analysis, - loc=loc, - ip=ip, - ) diff --git a/mlir/python/mlir/dialects/_ml_program_ops_ext.py b/mlir/python/mlir/dialects/_ml_program_ops_ext.py deleted file mode 100644 index c84d23c16..000000000 --- a/mlir/python/mlir/dialects/_ml_program_ops_ext.py +++ /dev/null @@ -1,113 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from typing import Union - from ..ir import * - from ._ods_common import get_default_loc_context as _get_default_loc_context -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from ._ml_program_ops_gen import * - - -ARGUMENT_ATTRIBUTE_NAME = "arg_attrs" -RESULT_ATTRIBUTE_NAME = "res_attrs" - - -class FuncOp: - """Specialization for the func op class.""" - - def __init__( - self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None - ): - """ - Create a FuncOp with the provided `name`, `type`, and `visibility`. - - `name` is a string representing the function name. - - `type` is either a FunctionType or a pair of list describing inputs and - results. - - `visibility` is a string matching `public`, `private`, or `nested`. None - implies private visibility. 
- - `body_builder` is an optional callback, when provided a new entry block - is created and the callback is invoked with the new op as argument within - an InsertionPoint context already set for the block. The callback is - expected to insert a terminator in the block. - """ - sym_name = StringAttr.get(str(name)) - - # If the type is passed as a tuple, build a FunctionType on the fly. - if isinstance(type, tuple): - type = FunctionType.get(inputs=type[0], results=type[1]) - - type = TypeAttr.get(type) - sym_visibility = ( - StringAttr.get(str(visibility)) if visibility is not None else None - ) - super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) - if body_builder: - entry_block = self.add_entry_block() - with InsertionPoint(entry_block): - body_builder(self) - - @property - def is_external(self): - return len(self.regions[0].blocks) == 0 - - @property - def body(self): - return self.regions[0] - - @property - def type(self): - return FunctionType(TypeAttr(self.attributes["function_type"]).value) - - @property - def visibility(self): - return self.attributes["sym_visibility"] - - @property - def name(self) -> StringAttr: - return StringAttr(self.attributes["sym_name"]) - - @property - def entry_block(self): - if self.is_external: - raise IndexError("External function does not have a body") - return self.regions[0].blocks[0] - - def add_entry_block(self): - """ - Add an entry block to the function body using the function signature to - infer block arguments. 
- Returns the newly created block - """ - if not self.is_external: - raise IndexError("The function already has an entry block!") - self.body.blocks.append(*self.type.inputs) - return self.body.blocks[0] - - @property - def arg_attrs(self): - return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) - - @arg_attrs.setter - def arg_attrs(self, attribute: Union[ArrayAttr, list]): - if isinstance(attribute, ArrayAttr): - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute - else: - self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( - attribute, context=self.context - ) - - @property - def arguments(self): - return self.entry_block.arguments - - @property - def result_attrs(self): - return self.attributes[RESULT_ATTRIBUTE_NAME] - - @result_attrs.setter - def result_attrs(self, attribute: ArrayAttr): - self.attributes[RESULT_ATTRIBUTE_NAME] = attribute diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 895c32281..9cca7d659 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -9,7 +9,6 @@ __all__ = [ "equally_sized_accessor", - "extend_opview_class", "get_default_loc_context", "get_op_result_or_value", "get_op_results_or_values", @@ -18,64 +17,6 @@ ] -def extend_opview_class(ext_module): - """Decorator to extend an OpView class from an extension module. - - Extension modules can expose various entry-points: - Stand-alone class with the same name as a parent OpView class (i.e. - "ReturnOp"). A name-based match is attempted first before falling back - to a below mechanism. - - def select_opview_mixin(parent_opview_cls): - If defined, allows an appropriate mixin class to be selected dynamically - based on the parent OpView class. Should return NotImplemented if a - decision is not made. - - Args: - ext_module: A module from which to locate extensions. Can be None if not - available. 
- - Returns: - A decorator that takes an OpView subclass and further extends it as - needed. - """ - - def class_decorator(parent_opview_cls: type): - if ext_module is None: - return parent_opview_cls - mixin_cls = NotImplemented - # First try to resolve by name. - try: - mixin_cls = getattr(ext_module, parent_opview_cls.__name__) - except AttributeError: - # Fall back to a select_opview_mixin hook. - try: - select_mixin = getattr(ext_module, "select_opview_mixin") - except AttributeError: - pass - else: - mixin_cls = select_mixin(parent_opview_cls) - - if mixin_cls is NotImplemented or mixin_cls is None: - return parent_opview_cls - - # Have a mixin_cls. Create an appropriate subclass. - try: - - class LocalOpView(mixin_cls, parent_opview_cls): - pass - - except TypeError as e: - raise TypeError( - f"Could not mixin {mixin_cls} into {parent_opview_cls}" - ) from e - LocalOpView.__name__ = parent_opview_cls.__name__ - LocalOpView.__qualname__ = parent_opview_cls.__qualname__ - return LocalOpView - - return class_decorator - - def segmented_accessor(elements, raw_segments, idx): """ Returns a slice of elements corresponding to the idx-th segment. diff --git a/mlir/python/mlir/dialects/_pdl_ops_ext.py b/mlir/python/mlir/dialects/_pdl_ops_ext.py deleted file mode 100644 index fc9de0b7f..000000000 --- a/mlir/python/mlir/dialects/_pdl_ops_ext.py +++ /dev/null @@ -1,271 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ..dialects import pdl -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Union, Optional, Sequence, Mapping -from ._ods_common import ( - get_op_result_or_value as _get_value, - get_op_results_or_values as _get_values, -) - - -class ApplyNativeConstraintOp: - """Specialization for PDL apply native constraint op class.""" - - def __init__( - self, - name: Union[str, StringAttr], - args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - *, - loc=None, - ip=None, - ): - if args is None: - args = [] - args = _get_values(args) - super().__init__(name, args, loc=loc, ip=ip) - - -class ApplyNativeRewriteOp: - """Specialization for PDL apply native rewrite op class.""" - - def __init__( - self, - results: Sequence[Type], - name: Union[str, StringAttr], - args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - *, - loc=None, - ip=None, - ): - if args is None: - args = [] - args = _get_values(args) - super().__init__(results, name, args, loc=loc, ip=ip) - - -class AttributeOp: - """Specialization for PDL attribute op class.""" - - def __init__( - self, - valueType: Optional[Union[OpView, Operation, Value]] = None, - value: Optional[Attribute] = None, - *, - loc=None, - ip=None, - ): - valueType = valueType if valueType is None else _get_value(valueType) - result = pdl.AttributeType.get() - super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip) - - -class EraseOp: - """Specialization for PDL erase op class.""" - - def __init__( - self, - operation: Optional[Union[OpView, Operation, Value]] = None, - *, - loc=None, - ip=None, - ): - operation = _get_value(operation) - super().__init__(operation, loc=loc, ip=ip) - - -class OperandOp: - """Specialization for PDL operand op class.""" - - def __init__( - self, - type: Optional[Union[OpView, Operation, Value]] = None, - *, - 
loc=None, - ip=None, - ): - type = type if type is None else _get_value(type) - result = pdl.ValueType.get() - super().__init__(result, valueType=type, loc=loc, ip=ip) - - -class OperandsOp: - """Specialization for PDL operands op class.""" - - def __init__( - self, - types: Optional[Union[OpView, Operation, Value]] = None, - *, - loc=None, - ip=None, - ): - types = types if types is None else _get_value(types) - result = pdl.RangeType.get(pdl.ValueType.get()) - super().__init__(result, valueType=types, loc=loc, ip=ip) - - -class OperationOp: - """Specialization for PDL operand op class.""" - - def __init__( - self, - name: Optional[Union[str, StringAttr]] = None, - args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None, - types: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - *, - loc=None, - ip=None, - ): - if types is None: - types = [] - if attributes is None: - attributes = {} - if args is None: - args = [] - args = _get_values(args) - attrNames = [] - attrValues = [] - for attrName, attrValue in attributes.items(): - attrNames.append(StringAttr.get(attrName)) - attrValues.append(_get_value(attrValue)) - attrNames = ArrayAttr.get(attrNames) - types = _get_values(types) - result = pdl.OperationType.get() - super().__init__( - result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip - ) - - -class PatternOp: - """Specialization for PDL pattern op class.""" - - def __init__( - self, - benefit: Union[IntegerAttr, int], - name: Optional[Union[StringAttr, str]] = None, - *, - loc=None, - ip=None, - ): - """Creates an PDL `pattern` operation.""" - super().__init__(benefit, sym_name=name, loc=loc, ip=ip) - self.regions[0].blocks.append() - - @property - def body(self): - """Return the body (block) of the pattern.""" - return self.regions[0].blocks[0] - - -class ReplaceOp: - """Specialization for PDL replace op class.""" - - def __init__( - self, - op: 
Union[OpView, Operation, Value], - *, - with_op: Optional[Union[OpView, Operation, Value]] = None, - with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - loc=None, - ip=None, - ): - if with_values is None: - with_values = [] - op = _get_value(op) - with_op = with_op if with_op is None else _get_value(with_op) - with_values = _get_values(with_values) - super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip) - - -class ResultOp: - """Specialization for PDL result op class.""" - - def __init__( - self, - parent: Union[OpView, Operation, Value], - index: Union[IntegerAttr, int], - *, - loc=None, - ip=None, - ): - parent = _get_value(parent) - result = pdl.ValueType.get() - super().__init__(result, parent, index, loc=loc, ip=ip) - - -class ResultsOp: - """Specialization for PDL results op class.""" - - def __init__( - self, - result: Type, - parent: Union[OpView, Operation, Value], - index: Optional[Union[IntegerAttr, int]] = None, - *, - loc=None, - ip=None, - ): - parent = _get_value(parent) - super().__init__(result, parent, index=index, loc=loc, ip=ip) - - -class RewriteOp: - """Specialization for PDL rewrite op class.""" - - def __init__( - self, - root: Optional[Union[OpView, Operation, Value]] = None, - name: Optional[Union[StringAttr, str]] = None, - args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - *, - loc=None, - ip=None, - ): - if args is None: - args = [] - root = root if root is None else _get_value(root) - args = _get_values(args) - super().__init__(args, root=root, name=name, loc=loc, ip=ip) - - def add_body(self): - """Add body (block) to the rewrite.""" - self.regions[0].blocks.append() - return self.body - - @property - def body(self): - """Return the body (block) of the rewrite.""" - return self.regions[0].blocks[0] - - -class TypeOp: - """Specialization for PDL type op class.""" - - def __init__( - self, constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None - ): - result = 
pdl.TypeType.get() - super().__init__(result, constantType=constantType, loc=loc, ip=ip) - - -class TypesOp: - """Specialization for PDL types op class.""" - - def __init__( - self, - constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None, - *, - loc=None, - ip=None, - ): - if constantTypes is None: - constantTypes = [] - result = pdl.RangeType.get(pdl.TypeType.get()) - super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/_scf_ops_ext.py b/mlir/python/mlir/dialects/_scf_ops_ext.py deleted file mode 100644 index 89cc8a198..000000000 --- a/mlir/python/mlir/dialects/_scf_ops_ext.py +++ /dev/null @@ -1,107 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Optional, Sequence, Union - -from ._ods_common import ( - get_op_result_or_value as _get_op_result_or_value, - get_op_results_or_values as _get_op_results_or_values, -) - - -class ForOp: - """Specialization for the SCF for op class.""" - - def __init__( - self, - lower_bound, - upper_bound, - step, - iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None, - *, - loc=None, - ip=None, - ): - """Creates an SCF `for` operation. - - - `lower_bound` is the value to use as lower bound of the loop. - - `upper_bound` is the value to use as upper bound of the loop. - - `step` is the value to use as loop step. - - `iter_args` is a list of additional loop-carried arguments or an operation - producing them as results. 
- """ - if iter_args is None: - iter_args = [] - iter_args = _get_op_results_or_values(iter_args) - - results = [arg.type for arg in iter_args] - super().__init__( - self.build_generic( - regions=1, - results=results, - operands=[ - _get_op_result_or_value(o) for o in [lower_bound, upper_bound, step] - ] - + list(iter_args), - loc=loc, - ip=ip, - ) - ) - self.regions[0].blocks.append(self.operands[0].type, *results) - - @property - def body(self): - """Returns the body (block) of the loop.""" - return self.regions[0].blocks[0] - - @property - def induction_variable(self): - """Returns the induction variable of the loop.""" - return self.body.arguments[0] - - @property - def inner_iter_args(self): - """Returns the loop-carried arguments usable within the loop. - - To obtain the loop-carried operands, use `iter_args`. - """ - return self.body.arguments[1:] - - -class IfOp: - """Specialization for the SCF if op class.""" - - def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None): - """Creates an SCF `if` operation. - - - `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed. - - `hasElse` determines whether the if operation has the else branch. 
- """ - operands = [] - operands.append(cond) - results = [] - results.extend(results_) - super().__init__( - self.build_generic( - regions=2, results=results, operands=operands, loc=loc, ip=ip - ) - ) - self.regions[0].blocks.append(*[]) - if hasElse: - self.regions[1].blocks.append(*[]) - - @property - def then_block(self): - """Returns the then block of the if operation.""" - return self.regions[0].blocks[0] - - @property - def else_block(self): - """Returns the else block of the if operation.""" - return self.regions[1].blocks[0] diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py deleted file mode 100644 index 3757a3d3b..000000000 --- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py +++ /dev/null @@ -1,759 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ..dialects import transform -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import List, Optional, Sequence, Tuple, Union, overload - -StaticIntLike = Union[int, IntegerAttr] -ValueLike = Union[Operation, OpView, Value] -MixedInt = Union[StaticIntLike, ValueLike] - -IntOrAttrList = Sequence[Union[IntegerAttr, int]] -OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]] - -BoolOrAttrList = Sequence[Union[BoolAttr, bool]] -OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]] - -MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike] - -DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]] - - -def _dispatch_dynamic_index_list( - indices: Union[DynamicIndexList, ArrayAttr], -) -> Tuple[List[ValueLike], Union[List[int], ArrayAttr], List[bool]]: - """Dispatches a list of indices to the appropriate form. 
- - This is similar to the custom `DynamicIndexList` directive upstream: - provided indices may be in the form of dynamic SSA values or static values, - and they may be scalable (i.e., as a singleton list) or not. This function - dispatches each index into its respective form. It also extracts the SSA - values and static indices from various similar structures, respectively. - """ - dynamic_indices = [] - static_indices = [ShapedType.get_dynamic_size()] * len(indices) - scalable_indices = [False] * len(indices) - - # ArrayAttr: Extract index values. - if isinstance(indices, ArrayAttr): - indices = [idx for idx in indices] - - def process_nonscalable_index(i, index): - """Processes any form of non-scalable index. - - Returns False if the given index was scalable and thus remains - unprocessed; True otherwise. - """ - if isinstance(index, int): - static_indices[i] = index - elif isinstance(index, IntegerAttr): - static_indices[i] = index.value # pytype: disable=attribute-error - elif isinstance(index, (Operation, Value, OpView)): - dynamic_indices.append(index) - else: - return False - return True - - # Process each index at a time. - for i, index in enumerate(indices): - if not process_nonscalable_index(i, index): - # If it wasn't processed, it must be a scalable index, which is - # provided as a Sequence of one value, so extract and process that. 
- scalable_indices[i] = True - assert len(index) == 1 - ret = process_nonscalable_index(i, index[0]) - assert ret - - return dynamic_indices, static_indices, scalable_indices - - -# Dispatches `MixedValues` that all represents integers in various forms into -# the following three categories: -# - `dynamic_values`: a list of `Value`s, potentially from op results; -# - `packed_values`: a value handle, potentially from an op result, associated -# to one or more payload operations of integer type; -# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python -# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`. -# The input is in the form for `packed_values`, only that result is set and the -# other two are empty. Otherwise, the input can be a mix of the other two forms, -# and for each dynamic value, a special value is added to the `static_values`. -def _dispatch_mixed_values( - values: MixedValues, -) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]: - dynamic_values = [] - packed_values = None - static_values = None - if isinstance(values, ArrayAttr): - static_values = values - elif isinstance(values, (Operation, Value, OpView)): - packed_values = values - else: - static_values = [] - for size in values or []: - if isinstance(size, int): - static_values.append(size) - else: - static_values.append(ShapedType.get_dynamic_size()) - dynamic_values.append(size) - static_values = DenseI64ArrayAttr.get(static_values) - - return (dynamic_values, packed_values, static_values) - - -def _get_value_or_attribute_value( - value_or_attr: Union[any, Attribute, ArrayAttr] -) -> any: - if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"): - return value_or_attr.value - if isinstance(value_or_attr, ArrayAttr): - return _get_value_list(value_or_attr) - return value_or_attr - - -def _get_value_list( - sequence_or_array_attr: Union[Sequence[any], ArrayAttr] -) -> Sequence[any]: - return [_get_value_or_attribute_value(v) 
for v in sequence_or_array_attr] - - -def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr: - if values is None: - return None - - # Turn into a Python list of Python ints. - values = _get_value_list(values) - - # Make an ArrayAttr of IntegerAttrs out of it. - return ArrayAttr.get( - [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values] - ) - - -def _get_int_array_array_attr( - values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]] -) -> ArrayAttr: - """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs. - - The input has to be a collection of collection of integers, where any - Python Sequence and ArrayAttr are admissible collections and Python ints and - any IntegerAttr are admissible integers. Both levels of collections are - turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s. - If the input is None, an empty ArrayAttr is returned. - """ - if values is None: - return None - - # Make sure the outer level is a list. - values = _get_value_list(values) - - # The inner level is now either invalid or a mixed sequence of ArrayAttrs and - # Sequences. Make sure the nested values are all lists. - values = [_get_value_list(nested) for nested in values] - - # Turn each nested list into an ArrayAttr. - values = [_get_int_array_attr(nested) for nested in values] - - # Turn the outer list into an ArrayAttr. - return ArrayAttr.get(values) - - -class BufferizeToAllocationOp: - """Specialization for BufferizeToAllocationOp class.""" - - def __init__( - self, - target: Union[Operation, OpView, Value], - *, - memory_space: Optional[Union[int, str, Attribute]] = None, - memcpy_op: Optional[str] = None, - alloc_op: Optional[str] = None, - bufferize_destination_only: Optional[bool] = None, - loc=None, - ip=None, - ): - # No other types are allowed, so hard-code those here. 
- allocated_buffer_type = transform.AnyValueType.get() - new_ops_type = transform.AnyOpType.get() - - if isinstance(memory_space, int): - memory_space = str(memory_space) - if isinstance(memory_space, str): - memory_space = Attribute.parse(memory_space) - - super().__init__( - allocated_buffer_type, - new_ops_type, - target, - memory_space=memory_space, - memcpy_op=memcpy_op, - alloc_op=alloc_op, - bufferize_destination_only=bufferize_destination_only, - loc=loc, - ip=ip, - ) - - -class DecomposeOp: - """Specialization for DecomposeOp class.""" - - def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - transformed_type = transform.AnyOpType.get() - super().__init__(transformed_type, target, loc=loc, ip=ip) - - -class FuseIntoContainingOp: - """Specialization for FuseIntoContainingOp class.""" - - @overload - def __init__( - self, - fused_op_type: Type, - new_containing_op_type: Type, - producer_op: Union[Operation, OpView, Value], - containing_op: Union[Operation, OpView, Value], - *, - loc=None, - ip=None, - ): - ... - - @overload - def __init__( - self, - producer_op: Union[Operation, OpView, Value], - containing_op: Union[Operation, OpView, Value], - *, - loc=None, - ip=None, - ): - ... - - def __init__( - self, - fused_op_type_or_producer_op: Union[Operation, OpView, Type, Value], - new_containing_op_type_or_containing_op: Union[Operation, OpView, Type, Value], - producer_op_or_none: Optional[Union[Operation, OpView, Value]] = None, - containing_op_or_none: Optional[Union[Operation, OpView, Value]] = None, - *, - loc=None, - ip=None, - ): - if isinstance(fused_op_type_or_producer_op, Type): - if not isinstance(new_containing_op_type_or_containing_op, Type): - raise TypeError( - "If 'fused_op_type_or_producer_op' is a type, then " - "'new_containing_op_type_or_containing_op' is expected " - "to be one as well." 
- ) - fused_op_type = fused_op_type_or_producer_op - new_containing_op_type = new_containing_op_type_or_containing_op - producer_op = producer_op_or_none - containing_op = containing_op_or_none - else: - fused_op_type = transform.AnyOpType.get() - new_containing_op_type = transform.AnyOpType.get() - producer_op = fused_op_type_or_producer_op - containing_op = new_containing_op_type_or_containing_op - - super().__init__( - fused_op_type, - new_containing_op_type, - producer_op, - containing_op, - loc=loc, - ip=ip, - ) - - -class GeneralizeOp: - """Specialization for GeneralizeOp class.""" - - def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - transformed_type = transform.AnyOpType.get() - super().__init__(transformed_type, target, loc=loc, ip=ip) - - -class InterchangeOp: - """Specialization for InterchangeOp class.""" - - def __init__( - self, - target: Union[Operation, Value], - *, - iterator_interchange: OptionalIntList = None, - loc=None, - ip=None, - ): - transformed_type = transform.AnyOpType.get() - super().__init__( - transformed_type, - target, - iterator_interchange=iterator_interchange, - loc=loc, - ip=ip, - ) - - -class MapCopyToThreadsOp: - """Specialization for MapCopyToThreadsOp class.""" - - @overload - def __init__( - self, - forall_op_type: Type, - tiled_op_type: Type, - target: Union[Operation, OpView, Value], - *, - total_num_threads: Union[int, IntegerAttr], - desired_bit_alignment: Union[int, IntegerAttr], - loc=None, - ip=None, - ): - ... - - @overload - def __init__( - self, - target: Union[Operation, OpView, Value], - *, - total_num_threads: Union[int, IntegerAttr], - desired_bit_alignment: Union[int, IntegerAttr], - loc=None, - ip=None, - ): - ... 
- - def __init__( - self, - forall_op_type_or_target: Union[Operation, OpView, Type, Value], - tiled_op_type_or_none: Optional[Type] = None, - target_or_none: Optional[Union[Operation, OpView, Value]] = None, - *, - total_num_threads: Union[int, IntegerAttr], - desired_bit_alignment: Union[int, IntegerAttr], - loc=None, - ip=None, - ): - if isinstance(forall_op_type_or_target, Type): - forall_op_type = forall_op_type_or_target - tiled_op_type = tiled_op_type_or_none - target = target_or_none - else: - forall_op_type = transform.AnyOpType.get() - tiled_op_type = transform.AnyOpType.get() - target = forall_op_type_or_target - - super().__init__( - forall_op_type, - tiled_op_type, - target, - total_num_threads=total_num_threads, - desired_bit_alignment=desired_bit_alignment, - loc=loc, - ip=ip, - ) - - -class VectorizeOp: - """Specialization for VectorizeOp class.""" - - def __init__( - self, - target: Union[Operation, OpView, Value], - vector_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, - *, - vectorize_nd_extract: Optional[bool] = None, - scalable_sizes: OptionalBoolList = None, - static_vector_sizes: OptionalIntList = None, - loc=None, - ip=None, - ): - if ( - scalable_sizes is None - and static_vector_sizes is None - and vector_sizes is None - ): - dynamic_vector_sizes = [] - elif scalable_sizes is None and static_vector_sizes is None: - ( - dynamic_vector_sizes, - static_vector_sizes, - scalable_sizes, - ) = _dispatch_dynamic_index_list(vector_sizes) - elif scalable_sizes is None or static_vector_sizes is None: - raise TypeError( - "'scalable_sizes' and 'static_vector_sizes' must either both " - "be given explicitly or both be given as part of 'vector_sizes'." 
- ) - else: - dynamic_vector_sizes = vector_sizes - - super().__init__( - target, - vector_sizes=dynamic_vector_sizes, - static_vector_sizes=static_vector_sizes, - scalable_sizes=scalable_sizes, - vectorize_nd_extract=vectorize_nd_extract, - loc=loc, - ip=ip, - ) - - -class MatchOp: - """Specialization for MatchOp class.""" - - @overload - @classmethod - def match_op_names( - cls, - target: Union[Operation, Value], - names: Union[str, Sequence[str]], - *, - loc=None, - ip=None, - ): - ... - - @overload - @classmethod - def match_op_names( - cls, - result_type: Type, - target: Union[Operation, Value], - names: Union[str, Sequence[str]], - *, - loc=None, - ip=None, - ): - ... - - @classmethod - def match_op_names( - cls, - result_type_or_target: Union[Type, Operation, Value], - target_or_names: Union[Operation, Value, Sequence[str], str], - names_or_none: Optional[Union[Sequence[str], str]] = None, - *, - loc=None, - ip=None, - ): - if isinstance(result_type_or_target, Type): - result_type = result_type_or_target - target = target_or_names - names = names_or_none - else: - result_type = transform.AnyOpType.get() - target = result_type_or_target - names = target_or_names - - if isinstance(names, str): - names = [names] - - return cls( - result_type, - target, - ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))), - loc=loc, - ip=ip, - ) - - -class MultiTileSizesOp: - """Specialization for MultiTileSizesOp class.""" - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - dimension: Union[int, IntegerAttr], - target_size: Union[int, IntegerAttr], - divisor: Optional[Optional[Union[int, IntegerAttr]]] = None, - loc=None, - ip=None, - ): - super().__init__( - result_type, - result_type, - result_type, - target, - dimension=dimension, - target_size=target_size, - divisor=divisor, - loc=loc, - ip=ip, - ) - - -class PadOp: - """Specialization for PadOp class.""" - - def __init__( - self, - target: Union[Operation, OpView, Value], 
- *, - padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, - padding_dimensions: OptionalIntList = None, - pad_to_multiple_of: OptionalIntList = None, - pack_paddings: OptionalIntList = None, - transpose_paddings: Optional[ - Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]] - ] = None, - copy_back_op: Optional[Union[str, StringAttr]] = None, - loc=None, - ip=None, - ): - transpose_paddings = _get_int_array_array_attr(transpose_paddings) - - any_op_type = transform.AnyOpType.get() - super().__init__( - any_op_type, - any_op_type, - any_op_type, - target, - padding_values=padding_values, - padding_dimensions=padding_dimensions, - pad_to_multiple_of=pad_to_multiple_of, - pack_paddings=pack_paddings, - transpose_paddings=transpose_paddings, - copy_back_op=copy_back_op, - loc=loc, - ip=ip, - ) - - -class ScalarizeOp: - """Specialization for ScalarizeOp class.""" - - def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): - result_type = transform.AnyOpType.get() - super().__init__(result_type, target, loc=loc, ip=ip) - - -class SplitOp: - """Specialization for SplitOp class.""" - - def __init__( - self, - target: Union[Operation, Value], - dimension: Union[int, Attribute], - split_point: Union[int, Operation, Value, Attribute], - *, - loc=None, - ip=None, - ): - if isinstance(split_point, int): - static_split_point = split_point - dynamic_split_point = None - else: - static_split_point = ShapedType.get_dynamic_size() - dynamic_split_point = split_point - - super().__init__( - target.type, - target.type, - target, - dimension=dimension, - static_split_point=static_split_point, - dynamic_split_point=dynamic_split_point, - loc=loc, - ip=ip, - ) - - -class TileUsingForOp: - """Specialization for TileUsingForOp class.""" - - @overload - def __init__( - self, - loop_types: Union[Type, List[Type]], - target: Union[Operation, Value], - *, - sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, - interchange: 
OptionalIntList = None, - loc=None, - ip=None, - ): - ... - - @overload - def __init__( - self, - target: Union[Operation, Value, OpView], - *, - sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, - interchange: OptionalIntList = None, - loc=None, - ip=None, - ): - ... - - def __init__( - self, - loop_types_or_target: Union[Type, List[Type], Operation, Value], - target_or_none: Optional[Union[Operation, Value, OpView]] = None, - *, - sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, - interchange: OptionalIntList = None, - loc=None, - ip=None, - ): - ( - dynamic_sizes, - static_sizes, - scalable_sizes, - ) = _dispatch_dynamic_index_list(sizes) - - num_loops = sum(v if v == 0 else 1 for v in static_sizes) - - if isinstance(loop_types_or_target, (Operation, Value, OpView)): - loop_types = [transform.AnyOpType.get()] * num_loops - target = loop_types_or_target - assert ( - target_or_none is None - ), "Cannot construct TileUsingForOp with two targets." - else: - loop_types = ( - ([loop_types_or_target] * num_loops) - if isinstance(loop_types_or_target, Type) - else loop_types_or_target - ) - target = target_or_none - - super().__init__( - target.type, - loop_types, - target, - dynamic_sizes=dynamic_sizes, - static_sizes=static_sizes, - interchange=interchange, - scalable_sizes=scalable_sizes, - loc=loc, - ip=ip, - ) - - -class TileUsingForallOp: - """Specialization for TileUsingForallOp class.""" - - @overload - def __init__( - self, - loops_type: Type, - tiled_op_type: Type, - target: Union[Operation, Value, OpView], - *, - num_threads: Optional[MixedValues] = None, - tile_sizes: MixedValues = None, - mapping=None, - loc=None, - ip=None, - ): - ... - - @overload - def __init__( - self, - target: Union[Operation, Value, OpView], - *, - num_threads: Optional[MixedValues] = None, - tile_sizes: MixedValues = None, - mapping=None, - loc=None, - ip=None, - ): - ... 
- - def __init__( - self, - loops_type_or_target: Union[ - Type, Union[Operation, Value, OpView] # loops_type - ], # target - tiled_op_type_or_none: Optional[Type] = None, - target_or_none: Optional[Union[Operation, Value, OpView]] = None, - *, - num_threads: MixedValues = None, - tile_sizes: MixedValues = None, - mapping=None, - loc=None, - ip=None, - ): - # `Type` arguments in the front are optional: add default values to front. - if isinstance(loops_type_or_target, Type): - # First overload: type arguments provided. - if not isinstance(tiled_op_type_or_none, Type): - raise TypeError( - "If 'loops_type_or_target' is a type, then " - "'tiled_op_type_or_none' is expected to be one as well." - ) - loops_type = loops_type_or_target - tiled_op_type = tiled_op_type_or_none - target = target_or_none - else: - # Last overload: type arguments missing. - loops_type = transform.AnyOpType.get() - tiled_op_type = transform.AnyOpType.get() - target = loops_type_or_target - - # Unpack mixed num_threads. - ( - dynamic_num_threads, - packed_num_threads, - num_threads_attr, - ) = _dispatch_mixed_values(num_threads) - - # Unpack mixed tile_sizes. 
- ( - dynamic_tile_sizes, - packed_tile_sizes, - tile_sizes_attr, - ) = _dispatch_mixed_values(tile_sizes) - - super().__init__( - loops_type, - tiled_op_type, - target=target, - tile_sizes=dynamic_tile_sizes, - packed_tile_sizes=packed_tile_sizes, - static_tile_sizes=tile_sizes_attr, - num_threads=dynamic_num_threads, - packed_num_threads=packed_num_threads, - static_num_threads=num_threads_attr, - mapping=mapping, - loc=loc, - ip=ip, - ) - - -class VectorizeChildrenAndApplyPatternsOp: - """Specialization for VectorizeChildrenAndApplyPatternsOp class.""" - - def __init__( - self, - target: Union[Operation, Value], - *, - disable_multi_reduction_to_contract_patterns: bool = False, - disable_transfer_permutation_map_lowering_patterns: bool = False, - vectorize_nd_extract: bool = False, - vectorize_padding: bool = False, - loc=None, - ip=None, - ): - transformed_type = transform.AnyOpType.get() - super().__init__( - transformed_type, - target, - disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns, - disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns, - vectorize_nd_extract=vectorize_nd_extract, - vectorize_padding=vectorize_padding, - loc=loc, - ip=ip, - ) diff --git a/mlir/python/mlir/dialects/_tensor_ops_ext.py b/mlir/python/mlir/dialects/_tensor_ops_ext.py deleted file mode 100644 index 09b9ec68d..000000000 --- a/mlir/python/mlir/dialects/_tensor_ops_ext.py +++ /dev/null @@ -1,44 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Any, Optional, Sequence, Union -from ._ods_common import ( - get_op_result_or_value as _get_op_result_or_value, - get_op_results_or_values as _get_op_results_or_values, -) - - -class EmptyOp: - """Extends the tensor.empty op.""" - - def __init__( - self, - sizes: Sequence[Union[int, Value]], - element_type: Type, - *, - loc=None, - ip=None - ): - """Constructs an `empty` with mixed static/dynamic sizes.""" - # TODO: Refactor the EmptyOp to take an element type attribute and - # then use normal result type inference, unifying the Python and C++ side - # with a standard mechanism (versus stashing that in builders). - dynamic_sizes = [] - static_sizes = [] - for s in sizes: - if isinstance(s, int): - static_sizes.append(s) - else: - static_sizes.append(ShapedType.get_dynamic_size()) - dynamic_sizes.append(s) - result_type = RankedTensorType.get(static_sizes, element_type) - op = self.build_generic( - results=[result_type], operands=dynamic_sizes, attributes={}, loc=loc, ip=ip - ) - OpView.__init__(self, op) diff --git a/mlir/python/mlir/dialects/_tensor_transform_ops_ext.py b/mlir/python/mlir/dialects/_tensor_transform_ops_ext.py deleted file mode 100644 index 996093fbc..000000000 --- a/mlir/python/mlir/dialects/_tensor_transform_ops_ext.py +++ /dev/null @@ -1,64 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ..dialects import transform -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Optional, overload, Union - - -class MakeLoopIndependentOp: - """Specialization for MakeLoopIndependentOp class.""" - - @overload - def __init__( - self, - transformed_type: Type, - target: Union[Operation, OpView, Value], - num_loops: Union[int, IntegerAttr], - *, - loc=None, - ip=None - ): - ... - - @overload - def __init__( - self, - target: Union[Operation, OpView, Value], - num_loops: Union[int, IntegerAttr], - *, - loc=None, - ip=None - ): - ... - - def __init__( - self, - transformed_type_or_target: Type, - target_or_num_loops: Union[int, IntegerAttr, Operation, OpView, Value] = None, - num_loops_or_none: Optional[Union[int, IntegerAttr]] = None, - *, - loc=None, - ip=None - ): - if isinstance(transformed_type_or_target, Type): - transformed_type = transformed_type_or_target - target = target_or_num_loops - num_loops = num_loops_or_none - else: - transformed_type = transform.AnyOpType.get() - target = transformed_type_or_target - num_loops = target_or_num_loops - - super().__init__( - transformed_type, - target, - num_loops, - loc=loc, - ip=ip, - ) diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py deleted file mode 100644 index b1e7b8925..000000000 --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ /dev/null @@ -1,176 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. 
-# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ._ods_common import ( - get_op_result_or_value as _get_op_result_or_value, - get_op_results_or_values as _get_op_results_or_values, - ) -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Optional, Sequence, Union - - -class CastOp: - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - loc=None, - ip=None, - ): - super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip) - - -class ApplyPatternsOp: - def __init__( - self, - target: Union[Operation, Value, OpView], - *, - loc=None, - ip=None, - ): - operands = [] - operands.append(_get_op_result_or_value(target)) - super().__init__( - self.build_generic( - attributes={}, - results=[], - operands=operands, - successors=None, - regions=None, - loc=loc, - ip=ip, - ) - ) - self.regions[0].blocks.append() - - @property - def patterns(self) -> Block: - return self.regions[0].blocks[0] - - -class testGetParentOp: - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - isolated_from_above: bool = False, - op_name: Optional[str] = None, - deduplicate: bool = False, - loc=None, - ip=None, - ): - super().__init__( - result_type, - _get_op_result_or_value(target), - isolated_from_above=isolated_from_above, - op_name=op_name, - deduplicate=deduplicate, - loc=loc, - ip=ip, - ) - - -class MergeHandlesOp: - def __init__( - self, - handles: Sequence[Union[Operation, Value]], - *, - deduplicate: bool = False, - loc=None, - ip=None, - ): - super().__init__( - [_get_op_result_or_value(h) for h in handles], - deduplicate=deduplicate, - loc=loc, - ip=ip, - ) - - -class ReplicateOp: - def __init__( - self, - pattern: Union[Operation, Value], - handles: Sequence[Union[Operation, Value]], - *, - loc=None, - ip=None, - ): - super().__init__( - [_get_op_result_or_value(h).type for h in handles], - 
_get_op_result_or_value(pattern), - [_get_op_result_or_value(h) for h in handles], - loc=loc, - ip=ip, - ) - - -class SequenceOp: - def __init__( - self, - failure_propagation_mode, - results: Sequence[Type], - target: Union[Operation, Value, Type], - extra_bindings: Optional[ - Union[Sequence[Value], Sequence[Type], Operation, OpView] - ] = None, - ): - root = ( - _get_op_result_or_value(target) - if isinstance(target, (Operation, Value)) - else None - ) - root_type = root.type if not isinstance(target, Type) else target - - if extra_bindings is None: - extra_bindings = [] - if isinstance(extra_bindings, (Operation, OpView)): - extra_bindings = _get_op_results_or_values(extra_bindings) - - extra_binding_types = [] - if len(extra_bindings) != 0: - if isinstance(extra_bindings[0], Type): - extra_binding_types = extra_bindings - extra_bindings = [] - else: - extra_binding_types = [v.type for v in extra_bindings] - - super().__init__( - results_=results, - failure_propagation_mode=failure_propagation_mode, - root=root, - extra_bindings=extra_bindings, - ) - self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types)) - - @property - def body(self) -> Block: - return self.regions[0].blocks[0] - - @property - def bodyTarget(self) -> Value: - return self.body.arguments[0] - - @property - def bodyExtraArgs(self) -> BlockArgumentList: - return self.body.arguments[1:] - - -class YieldOp: - def __init__( - self, - operands: Optional[Union[Operation, Sequence[Value]]] = None, - *, - loc=None, - ip=None, - ): - if operands is None: - operands = [] - super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py b/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py deleted file mode 100644 index c4e4b4b42..000000000 --- a/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py +++ /dev/null @@ -1,55 +0,0 @@ -# Part of the LLVM Project, under the Apache License v2.0 with 
LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -try: - from ..ir import * - from ._ods_common import ( - get_op_result_or_value as _get_op_result_or_value, - get_op_results_or_values as _get_op_results_or_values, - ) -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Union - -class PDLMatchOp: - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - pattern_name: Union[Attribute, str], - *, - loc=None, - ip=None, - ): - super().__init__( - result_type, - _get_op_result_or_value(target), - pattern_name, - loc=loc, - ip=ip, - ) - - -class WithPDLPatternsOp: - - def __init__(self, - target: Union[Operation, Value, Type], - *, - loc=None, - ip=None): - root = _get_op_result_or_value(target) if not isinstance(target, - Type) else None - root_type = target if isinstance(target, Type) else root.type - super().__init__(root=root, loc=loc, ip=ip) - self.regions[0].blocks.append(root_type) - - @property - def body(self) -> Block: - return self.regions[0].blocks[0] - - @property - def bodyTarget(self) -> Value: - return self.body.arguments[0] diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py index 8a2a64c7c..1eaccfa73 100644 --- a/mlir/python/mlir/dialects/affine.py +++ b/mlir/python/mlir/dialects/affine.py @@ -1,5 +1,50 @@ -# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._affine_ops_gen import * +from ._affine_ops_gen import _Dialect + +try: + from ..ir import * + from ._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + _cext as _ods_cext, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Sequence, Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class AffineStoreOp(AffineStoreOp): + """Specialization for the Affine store operation.""" + + def __init__( + self, + value: Union[Operation, OpView, Value], + memref: Union[Operation, OpView, Value], + map: AffineMap = None, + *, + map_operands=None, + loc=None, + ip=None, + ): + """Creates an affine store operation. + + - `value`: the value to store into the memref. + - `memref`: the buffer to store into. + - `map`: the affine map that maps the map_operands to the index of the + memref. + - `map_operands`: the list of arguments to substitute the dimensions, + then symbols in the affine map, in increasing order. 
+ """ + map = map if map is not None else [] + map_operands = map_operands if map_operands is not None else [] + indicies = [_get_op_result_or_value(op) for op in map_operands] + _ods_successors = None + super().__init__( + value, memref, indicies, AffineMapAttr.get(map), loc=loc, ip=ip + ) diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py index fb13beb63..83aca0d58 100644 --- a/mlir/python/mlir/dialects/arith.py +++ b/mlir/python/mlir/dialects/arith.py @@ -3,4 +3,75 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._arith_ops_gen import * +from ._arith_ops_gen import _Dialect from ._arith_enum_gen import * + +try: + from ..ir import * + from ._ods_common import ( + get_default_loc_context as _get_default_loc_context, + _cext as _ods_cext, + ) + + from typing import Any, List, Union +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + + +def _isa(obj: Any, cls: type): + try: + cls(obj) + except ValueError: + return False + return True + + +def _is_any_of(obj: Any, classes: List[type]): + return any(_isa(obj, cls) for cls in classes) + + +def _is_integer_like_type(type: Type): + return _is_any_of(type, [IntegerType, IndexType]) + + +def _is_float_type(type: Type): + return _is_any_of(type, [BF16Type, F16Type, F32Type, F64Type]) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ConstantOp(ConstantOp): + """Specialization for the constant op class.""" + + def __init__( + self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None + ): + if isinstance(value, int): + super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip) + elif isinstance(value, float): + super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip) + else: + super().__init__(value, loc=loc, ip=ip) + + @classmethod + def create_index(cls, value: int, *, loc=None, ip=None): + """Create an index-typed constant.""" + return cls( + 
IndexType.get(context=_get_default_loc_context(loc)), value, loc=loc, ip=ip + ) + + @property + def type(self): + return self.results[0].type + + @property + def value(self): + return Attribute(self.operation.attributes["value"]) + + @property + def literal_value(self) -> Union[int, float]: + if _is_integer_like_type(self.type): + return IntegerAttr(self.value).value + elif _is_float_type(self.type): + return FloatAttr(self.value).value + else: + raise ValueError("only integer and float constants have literal values") diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py index 759b6aa24..0ce5448ac 100644 --- a/mlir/python/mlir/dialects/bufferization.py +++ b/mlir/python/mlir/dialects/bufferization.py @@ -3,4 +3,40 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._bufferization_ops_gen import * +from ._bufferization_ops_gen import _Dialect from ._bufferization_enum_gen import * + +try: + from typing import Sequence, Union + from ..ir import * + from ._ods_common import get_default_loc_context, _cext as _ods_cext + + from typing import Any, List, Union +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + + +@_ods_cext.register_operation(_Dialect, replace=True) +class AllocTensorOp(AllocTensorOp): + """Extends the bufferization.alloc_tensor op.""" + + def __init__( + self, + tensor_type: Type, + dynamic_sizes: Sequence[Value], + copy: Value, + size_hint: Value, + escape: BoolAttr, + *, + loc=None, + ip=None, + ): + """Constructs an `alloc_tensor` with static and/or dynamic sizes.""" + super().__init__( + tensor_type, + dynamic_sizes, + copy=copy, + size_hint=size_hint, + loc=loc, + ip=ip, + ) diff --git a/mlir/python/mlir/dialects/builtin.py b/mlir/python/mlir/dialects/builtin.py index 30279e161..b71cc2466 100644 --- a/mlir/python/mlir/dialects/builtin.py +++ b/mlir/python/mlir/dialects/builtin.py @@ -3,3 +3,23 @@ # SPDX-License-Identifier: Apache-2.0 
WITH LLVM-exception from ._builtin_ops_gen import * +from ._builtin_ops_gen import _Dialect + +try: + from ..ir import * + from ._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ModuleOp(ModuleOp): + """Specialization for the module op class.""" + + def __init__(self, *, loc=None, ip=None): + super().__init__(loc=loc, ip=ip) + body = self.regions[0].blocks.append() + + @property + def body(self): + return self.regions[0].blocks[0] diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py index dc554c221..9c6c4c909 100644 --- a/mlir/python/mlir/dialects/func.py +++ b/mlir/python/mlir/dialects/func.py @@ -3,3 +3,326 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._func_ops_gen import * +from ._func_ops_gen import _Dialect + +try: + from ..ir import * + from ._ods_common import ( + get_default_loc_context as _get_default_loc_context, + _cext as _ods_cext, + ) + + import inspect + + from typing import Any, List, Optional, Sequence, Union +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +ARGUMENT_ATTRIBUTE_NAME = "arg_attrs" +RESULT_ATTRIBUTE_NAME = "res_attrs" + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ConstantOp(ConstantOp): + """Specialization for the constant op class.""" + + def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None): + super().__init__(result, value, loc=loc, ip=ip) + + @property + def type(self): + return self.results[0].type + + +@_ods_cext.register_operation(_Dialect, replace=True) +class FuncOp(FuncOp): + """Specialization for the func op class.""" + + def __init__( + self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None + ): + """ + Create a FuncOp with the provided `name`, `type`, and `visibility`. 
+ - `name` is a string representing the function name. + - `type` is either a FunctionType or a pair of list describing inputs and + results. + - `visibility` is a string matching `public`, `private`, or `nested`. None + implies private visibility. + - `body_builder` is an optional callback, when provided a new entry block + is created and the callback is invoked with the new op as argument within + an InsertionPoint context already set for the block. The callback is + expected to insert a terminator in the block. + """ + sym_name = StringAttr.get(str(name)) + + # If the type is passed as a tuple, build a FunctionType on the fly. + if isinstance(type, tuple): + type = FunctionType.get(inputs=type[0], results=type[1]) + + type = TypeAttr.get(type) + sym_visibility = ( + StringAttr.get(str(visibility)) if visibility is not None else None + ) + super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) + if body_builder: + entry_block = self.add_entry_block() + with InsertionPoint(entry_block): + body_builder(self) + + @property + def is_external(self): + return len(self.regions[0].blocks) == 0 + + @property + def body(self): + return self.regions[0] + + @property + def type(self): + return FunctionType(TypeAttr(self.attributes["function_type"]).value) + + @property + def visibility(self): + return self.attributes["sym_visibility"] + + @property + def name(self) -> StringAttr: + return StringAttr(self.attributes["sym_name"]) + + @property + def entry_block(self): + if self.is_external: + raise IndexError("External function does not have a body") + return self.regions[0].blocks[0] + + def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None): + """ + Add an entry block to the function body using the function signature to + infer block arguments. 
+ Returns the newly created block + """ + if not self.is_external: + raise IndexError("The function already has an entry block!") + self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs) + return self.body.blocks[0] + + @property + def arg_attrs(self): + return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) + + @arg_attrs.setter + def arg_attrs(self, attribute: Union[ArrayAttr, list]): + if isinstance(attribute, ArrayAttr): + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute + else: + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( + attribute, context=self.context + ) + + @property + def arguments(self): + return self.entry_block.arguments + + @property + def result_attrs(self): + return self.attributes[RESULT_ATTRIBUTE_NAME] + + @result_attrs.setter + def result_attrs(self, attribute: ArrayAttr): + self.attributes[RESULT_ATTRIBUTE_NAME] = attribute + + @classmethod + def from_py_func( + FuncOp, + *inputs: Type, + results: Optional[Sequence[Type]] = None, + name: Optional[str] = None, + ): + """Decorator to define an MLIR FuncOp specified as a python function. + + Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are + active for the current thread (i.e. established in a `with` block). + + When applied as a decorator to a Python function, an entry block will + be constructed for the FuncOp with types as specified in `*inputs`. The + block arguments will be passed positionally to the Python function. In + addition, if the Python function accepts keyword arguments generally or + has a corresponding keyword argument, the following will be passed: + * `func_op`: The `func` op being defined. + + By default, the function name will be the Python function `__name__`. This + can be overriden by passing the `name` argument to the decorator. + + If `results` is not specified, then the decorator will implicitly + insert a `ReturnOp` with the `Value`'s returned from the decorated + function. 
It will also set the `FuncOp` type with the actual return + value types. If `results` is specified, then the decorated function + must return `None` and no implicit `ReturnOp` is added (nor are the result + types updated). The implicit behavior is intended for simple, single-block + cases, and users should specify result types explicitly for any complicated + cases. + + The decorated function can further be called from Python and will insert + a `CallOp` at the then-current insertion point, returning either None ( + if no return values), a unary Value (for one result), or a list of Values). + This mechanism cannot be used to emit recursive calls (by construction). + """ + + def decorator(f): + from . import func + + # Introspect the callable for optional features. + sig = inspect.signature(f) + has_arg_func_op = False + for param in sig.parameters.values(): + if param.kind == param.VAR_KEYWORD: + has_arg_func_op = True + if param.name == "func_op" and ( + param.kind == param.POSITIONAL_OR_KEYWORD + or param.kind == param.KEYWORD_ONLY + ): + has_arg_func_op = True + + # Emit the FuncOp. + implicit_return = results is None + symbol_name = name or f.__name__ + function_type = FunctionType.get( + inputs=inputs, results=[] if implicit_return else results + ) + func_op = FuncOp(name=symbol_name, type=function_type) + with InsertionPoint(func_op.add_entry_block()): + func_args = func_op.entry_block.arguments + func_kwargs = {} + if has_arg_func_op: + func_kwargs["func_op"] = func_op + return_values = f(*func_args, **func_kwargs) + if not implicit_return: + return_types = list(results) + assert return_values is None, ( + "Capturing a python function with explicit `results=` " + "requires that the wrapped function returns None." + ) + else: + # Coerce return values, add ReturnOp and rewrite func type. 
+ if return_values is None: + return_values = [] + elif isinstance(return_values, tuple): + return_values = list(return_values) + elif isinstance(return_values, Value): + # Returning a single value is fine, coerce it into a list. + return_values = [return_values] + elif isinstance(return_values, OpView): + # Returning a single operation is fine, coerce its results a list. + return_values = return_values.operation.results + elif isinstance(return_values, Operation): + # Returning a single operation is fine, coerce its results a list. + return_values = return_values.results + else: + return_values = list(return_values) + func.ReturnOp(return_values) + # Recompute the function type. + return_types = [v.type for v in return_values] + function_type = FunctionType.get( + inputs=inputs, results=return_types + ) + func_op.attributes["function_type"] = TypeAttr.get(function_type) + + def emit_call_op(*call_args): + call_op = func.CallOp( + return_types, FlatSymbolRefAttr.get(symbol_name), call_args + ) + if return_types is None: + return None + elif len(return_types) == 1: + return call_op.result + else: + return call_op.results + + wrapped = emit_call_op + wrapped.__name__ = f.__name__ + wrapped.func_op = func_op + return wrapped + + return decorator + + +@_ods_cext.register_operation(_Dialect, replace=True) +class CallOp(CallOp): + """Specialization for the call op class.""" + + def __init__( + self, + calleeOrResults: Union[FuncOp, List[Type]], + argumentsOrCallee: Union[List, FlatSymbolRefAttr, str], + arguments: Optional[List] = None, + *, + loc=None, + ip=None, + ): + """Creates an call operation. + + The constructor accepts three different forms: + + 1. A function op to be called followed by a list of arguments. + 2. A list of result types, followed by the name of the function to be + called as string, following by a list of arguments. + 3. 
A list of result types, followed by the name of the function to be + called as symbol reference attribute, followed by a list of arguments. + + For example + + f = func.FuncOp("foo", ...) + func.CallOp(f, [args]) + func.CallOp([result_types], "foo", [args]) + + In all cases, the location and insertion point may be specified as keyword + arguments if not provided by the surrounding context managers. + """ + + # TODO: consider supporting constructor "overloads", e.g., through a custom + # or pybind-provided metaclass. + if isinstance(calleeOrResults, FuncOp): + if not isinstance(argumentsOrCallee, list): + raise ValueError( + "when constructing a call to a function, expected " + + "the second argument to be a list of call arguments, " + + f"got {type(argumentsOrCallee)}" + ) + if arguments is not None: + raise ValueError( + "unexpected third argument when constructing a call" + + "to a function" + ) + + super().__init__( + calleeOrResults.type.results, + FlatSymbolRefAttr.get( + calleeOrResults.name.value, context=_get_default_loc_context(loc) + ), + argumentsOrCallee, + loc=loc, + ip=ip, + ) + return + + if isinstance(argumentsOrCallee, list): + raise ValueError( + "when constructing a call to a function by name, " + + "expected the second argument to be a string or a " + + f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}" + ) + + if isinstance(argumentsOrCallee, FlatSymbolRefAttr): + super().__init__( + calleeOrResults, argumentsOrCallee, arguments, loc=loc, ip=ip + ) + elif isinstance(argumentsOrCallee, str): + super().__init__( + calleeOrResults, + FlatSymbolRefAttr.get( + argumentsOrCallee, context=_get_default_loc_context(loc) + ), + arguments, + loc=loc, + ip=ip, + ) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 6f9d72164..f91fc8b71 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -310,7 
+310,7 @@ def emit_named_structured_op( ) # Set the index attributes used to compute the indexing maps. - named_op = getattr(linalg, op_class_name)(ins, outs, result_types) + named_op = getattr(linalg, op_class_name)(result_types, ins, outs) for name, value in index_attrs.items(): named_op.operation.attributes[name] = value diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index a8f8f8e0f..19734a80a 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -296,35 +296,39 @@ def quantized_matmul( @linalg_structured_op -def matmul_transpose_a(A=TensorDef(T1, S.K, S.N), - B=TensorDef(T2, S.K, S.M), - C=TensorDef(U, S.M, S.N, output=True), - cast=TypeFnAttrDef(default=TypeFn.cast_signed)): - """Performs a matrix multiplication of two 2D inputs with lhs operand - transposed. +def matmul_transpose_a( + A=TensorDef(T1, S.K, S.N), + B=TensorDef(T2, S.K, S.M), + C=TensorDef(U, S.M, S.N, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + """Performs a matrix multiplication of two 2D inputs with lhs operand + transposed. - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n]) + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + domain(D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n]) @linalg_structured_op -def matmul_transpose_b(A=TensorDef(T1, S.M, S.K), - B=TensorDef(T2, S.N, S.K), - C=TensorDef(U, S.M, S.N, output=True), - cast=TypeFnAttrDef(default=TypeFn.cast_signed)): - """Performs a matrix multiplication of two 2D inputs with rhs operand - transposed. +def matmul_transpose_b( + A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.N, S.K), + C=TensorDef(U, S.M, S.N, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + """Performs a matrix multiplication of two 2D inputs with rhs operand + transposed. - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k]) + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k]) @linalg_structured_op @@ -390,36 +394,41 @@ def batch_matmul( @linalg_structured_op -def batch_matmul_transpose_a(A=TensorDef(T1, Batch, S.K, S.M), - B=TensorDef(T2, Batch, S.K, S.N), - C=TensorDef(U, Batch, S.M, S.N, output=True)): - """Performs a batched matrix multiplication of two 3D inputs where lhs operand - has its non-batch dimensions transposed. +def batch_matmul_transpose_a( + A=TensorDef(T1, Batch, S.K, S.M), + B=TensorDef(T2, Batch, S.K, S.N), + C=TensorDef(U, Batch, S.M, S.N, output=True), +): + """Performs a batched matrix multiplication of two 3D inputs where lhs operand + has its non-batch dimensions transposed. - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. 
- """ - domain(D.b, D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) \ - * TypeFn.cast_signed(U, B[D.b, D.k, D.n]) + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.b, D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed( + U, B[D.b, D.k, D.n] + ) @linalg_structured_op -def batch_matmul_transpose_b(A=TensorDef(T1, Batch, S.M, S.K), - B=TensorDef(T2, Batch, S.N, S.K), - C=TensorDef(U, Batch, S.M, S.N, output=True)): - """Performs a batched matrix multiplication of two 3D inputs where rhs operand - has its non-batch dimensions transposed. +def batch_matmul_transpose_b( + A=TensorDef(T1, Batch, S.M, S.K), + B=TensorDef(T2, Batch, S.N, S.K), + C=TensorDef(U, Batch, S.M, S.N, output=True), +): + """Performs a batched matrix multiplication of two 3D inputs where rhs operand + has its non-batch dimensions transposed. - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.b, D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.b, D.m, - D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( - U, B[D.b, D.n, D.k]) + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + domain(D.b, D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( + U, B[D.b, D.n, D.k] + ) @linalg_structured_op diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py index 3afb6a70c..111ad2178 100644 --- a/mlir/python/mlir/dialects/memref.py +++ b/mlir/python/mlir/dialects/memref.py @@ -3,3 +3,41 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._memref_ops_gen import * +from ._memref_ops_gen import _Dialect + +try: + from ..ir import * + from ._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + _cext as _ods_cext, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Sequence, Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class LoadOp(LoadOp): + """Specialization for the MemRef load operation.""" + + def __init__( + self, + memref: Union[Operation, OpView, Value], + indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None, + *, + loc=None, + ip=None, + ): + """Creates a memref load operation. + + Args: + memref: the buffer to load from. + indices: the list of subscripts, may be empty for zero-dimensional + buffers. + loc: user-visible location of the operation. + ip: insertion point. + """ + indices_resolved = [] if indices is None else _get_op_results_or_values(indices) + super().__init__(memref, indices_resolved, loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/ml_program.py b/mlir/python/mlir/dialects/ml_program.py index a654529b4..dfb6d7f2c 100644 --- a/mlir/python/mlir/dialects/ml_program.py +++ b/mlir/python/mlir/dialects/ml_program.py @@ -2,4 +2,118 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from typing import Union + from ._ml_program_ops_gen import * +from ._ml_program_ops_gen import _Dialect + +try: + from ..ir import * + from ._ods_common import ( + get_default_loc_context as _get_default_loc_context, + _cext as _ods_cext, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + + +ARGUMENT_ATTRIBUTE_NAME = "arg_attrs" +RESULT_ATTRIBUTE_NAME = "res_attrs" + + +@_ods_cext.register_operation(_Dialect, replace=True) +class FuncOp(FuncOp): + """Specialization for the func op class.""" + + def __init__( + self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None + ): + """ + Create a FuncOp with the provided `name`, `type`, and `visibility`. + - `name` is a string representing the function name. + - `type` is either a FunctionType or a pair of list describing inputs and + results. + - `visibility` is a string matching `public`, `private`, or `nested`. None + implies private visibility. + - `body_builder` is an optional callback, when provided a new entry block + is created and the callback is invoked with the new op as argument within + an InsertionPoint context already set for the block. The callback is + expected to insert a terminator in the block. + """ + sym_name = StringAttr.get(str(name)) + + # If the type is passed as a tuple, build a FunctionType on the fly. 
+ if isinstance(type, tuple): + type = FunctionType.get(inputs=type[0], results=type[1]) + + type = TypeAttr.get(type) + sym_visibility = ( + StringAttr.get(str(visibility)) if visibility is not None else None + ) + super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip) + if body_builder: + entry_block = self.add_entry_block() + with InsertionPoint(entry_block): + body_builder(self) + + @property + def is_external(self): + return len(self.regions[0].blocks) == 0 + + @property + def body(self): + return self.regions[0] + + @property + def type(self): + return FunctionType(TypeAttr(self.attributes["function_type"]).value) + + @property + def visibility(self): + return self.attributes["sym_visibility"] + + @property + def name(self) -> StringAttr: + return StringAttr(self.attributes["sym_name"]) + + @property + def entry_block(self): + if self.is_external: + raise IndexError("External function does not have a body") + return self.regions[0].blocks[0] + + def add_entry_block(self): + """ + Add an entry block to the function body using the function signature to + infer block arguments. 
+ Returns the newly created block + """ + if not self.is_external: + raise IndexError("The function already has an entry block!") + self.body.blocks.append(*self.type.inputs) + return self.body.blocks[0] + + @property + def arg_attrs(self): + return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) + + @arg_attrs.setter + def arg_attrs(self, attribute: Union[ArrayAttr, list]): + if isinstance(attribute, ArrayAttr): + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute + else: + self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get( + attribute, context=self.context + ) + + @property + def arguments(self): + return self.entry_block.arguments + + @property + def result_attrs(self): + return self.attributes[RESULT_ATTRIBUTE_NAME] + + @result_attrs.setter + def result_attrs(self, attribute: ArrayAttr): + self.attributes[RESULT_ATTRIBUTE_NAME] = attribute diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py index dda2b7d65..a8d9c56f4 100644 --- a/mlir/python/mlir/dialects/pdl.py +++ b/mlir/python/mlir/dialects/pdl.py @@ -3,4 +3,289 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._pdl_ops_gen import * +from ._pdl_ops_gen import _Dialect from .._mlir_libs._mlirDialectsPDL import * + + +try: + from ..ir import * + from ..dialects import pdl +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Union, Optional, Sequence, Mapping +from ._ods_common import ( + get_op_result_or_value as _get_value, + get_op_results_or_values as _get_values, + _cext as _ods_cext, +) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ApplyNativeConstraintOp(ApplyNativeConstraintOp): + """Specialization for PDL apply native constraint op class.""" + + def __init__( + self, + name: Union[str, StringAttr], + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if args is None: + args = [] + args = _get_values(args) + 
super().__init__(name, args, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ApplyNativeRewriteOp(ApplyNativeRewriteOp): + """Specialization for PDL apply native rewrite op class.""" + + def __init__( + self, + results: Sequence[Type], + name: Union[str, StringAttr], + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if args is None: + args = [] + args = _get_values(args) + super().__init__(results, name, args, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class AttributeOp(AttributeOp): + """Specialization for PDL attribute op class.""" + + def __init__( + self, + valueType: Optional[Union[OpView, Operation, Value]] = None, + value: Optional[Attribute] = None, + *, + loc=None, + ip=None, + ): + valueType = valueType if valueType is None else _get_value(valueType) + result = pdl.AttributeType.get() + super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class EraseOp(EraseOp): + """Specialization for PDL erase op class.""" + + def __init__( + self, + operation: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None, + ): + operation = _get_value(operation) + super().__init__(operation, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class OperandOp(OperandOp): + """Specialization for PDL operand op class.""" + + def __init__( + self, + type: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None, + ): + type = type if type is None else _get_value(type) + result = pdl.ValueType.get() + super().__init__(result, valueType=type, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class OperandsOp(OperandsOp): + """Specialization for PDL operands op class.""" + + def __init__( + self, + types: Optional[Union[OpView, Operation, Value]] = None, + *, + loc=None, + ip=None, + ): + types 
= types if types is None else _get_value(types) + result = pdl.RangeType.get(pdl.ValueType.get()) + super().__init__(result, valueType=types, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class OperationOp(OperationOp): + """Specialization for PDL operand op class.""" + + def __init__( + self, + name: Optional[Union[str, StringAttr]] = None, + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + attributes: Optional[Mapping[str, Union[OpView, Operation, Value]]] = None, + types: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if types is None: + types = [] + if attributes is None: + attributes = {} + if args is None: + args = [] + args = _get_values(args) + attrNames = [] + attrValues = [] + for attrName, attrValue in attributes.items(): + attrNames.append(StringAttr.get(attrName)) + attrValues.append(_get_value(attrValue)) + attrNames = ArrayAttr.get(attrNames) + types = _get_values(types) + result = pdl.OperationType.get() + super().__init__( + result, args, attrValues, attrNames, types, opName=name, loc=loc, ip=ip + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class PatternOp(PatternOp): + """Specialization for PDL pattern op class.""" + + def __init__( + self, + benefit: Union[IntegerAttr, int], + name: Optional[Union[StringAttr, str]] = None, + *, + loc=None, + ip=None, + ): + """Creates an PDL `pattern` operation.""" + super().__init__(benefit, sym_name=name, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def body(self): + """Return the body (block) of the pattern.""" + return self.regions[0].blocks[0] + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ReplaceOp(ReplaceOp): + """Specialization for PDL replace op class.""" + + def __init__( + self, + op: Union[OpView, Operation, Value], + *, + with_op: Optional[Union[OpView, Operation, Value]] = None, + with_values: Optional[Sequence[Union[OpView, Operation, Value]]] = 
None, + loc=None, + ip=None, + ): + if with_values is None: + with_values = [] + op = _get_value(op) + with_op = with_op if with_op is None else _get_value(with_op) + with_values = _get_values(with_values) + super().__init__(op, with_values, replOperation=with_op, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ResultOp(ResultOp): + """Specialization for PDL result op class.""" + + def __init__( + self, + parent: Union[OpView, Operation, Value], + index: Union[IntegerAttr, int], + *, + loc=None, + ip=None, + ): + parent = _get_value(parent) + result = pdl.ValueType.get() + super().__init__(result, parent, index, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ResultsOp(ResultsOp): + """Specialization for PDL results op class.""" + + def __init__( + self, + result: Type, + parent: Union[OpView, Operation, Value], + index: Optional[Union[IntegerAttr, int]] = None, + *, + loc=None, + ip=None, + ): + parent = _get_value(parent) + super().__init__(result, parent, index=index, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class RewriteOp(RewriteOp): + """Specialization for PDL rewrite op class.""" + + def __init__( + self, + root: Optional[Union[OpView, Operation, Value]] = None, + name: Optional[Union[StringAttr, str]] = None, + args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, + *, + loc=None, + ip=None, + ): + if args is None: + args = [] + root = root if root is None else _get_value(root) + args = _get_values(args) + super().__init__(args, root=root, name=name, loc=loc, ip=ip) + + def add_body(self): + """Add body (block) to the rewrite.""" + self.regions[0].blocks.append() + return self.body + + @property + def body(self): + """Return the body (block) of the rewrite.""" + return self.regions[0].blocks[0] + + +@_ods_cext.register_operation(_Dialect, replace=True) +class TypeOp(TypeOp): + """Specialization for PDL type op class.""" + + def __init__( + 
self, constantType: Optional[Union[TypeAttr, Type]] = None, *, loc=None, ip=None + ): + result = pdl.TypeType.get() + super().__init__(result, constantType=constantType, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class TypesOp(TypesOp): + """Specialization for PDL types op class.""" + + def __init__( + self, + constantTypes: Optional[Sequence[Union[TypeAttr, Type]]] = None, + *, + loc=None, + ip=None, + ): + if constantTypes is None: + constantTypes = [] + result = pdl.RangeType.get(pdl.TypeType.get()) + super().__init__(result, constantTypes=constantTypes, loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py index 8465af048..6579e02d8 100644 --- a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -3,7 +3,12 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._python_test_ops_gen import * -from .._mlir_libs._mlirPythonTest import TestAttr, TestType, TestTensorValue, TestIntegerRankedTensorType +from .._mlir_libs._mlirPythonTest import ( + TestAttr, + TestType, + TestTensorValue, + TestIntegerRankedTensorType, +) def register_python_test_dialect(context, load=True): diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py index 49685ca22..43ad9f4e2 100644 --- a/mlir/python/mlir/dialects/scf.py +++ b/mlir/python/mlir/dialects/scf.py @@ -2,11 +2,122 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Optional, Sequence from ._scf_ops_gen import * +from ._scf_ops_gen import _Dialect from .arith import constant -from ..ir import * + +try: + from ..ir import * + from ._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + _cext as _ods_cext, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Sequence, Union + + +_ForOp = ForOp + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ForOp(_ForOp): + """Specialization for the SCF for op class.""" + + def __init__( + self, + lower_bound, + upper_bound, + step, + iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None, + *, + loc=None, + ip=None, + ): + """Creates an SCF `for` operation. + + - `lower_bound` is the value to use as lower bound of the loop. + - `upper_bound` is the value to use as upper bound of the loop. + - `step` is the value to use as loop step. + - `iter_args` is a list of additional loop-carried arguments or an operation + producing them as results. + """ + if iter_args is None: + iter_args = [] + iter_args = _get_op_results_or_values(iter_args) + + results = [arg.type for arg in iter_args] + super(_ForOp, self).__init__( + self.build_generic( + regions=1, + results=results, + operands=[ + _get_op_result_or_value(o) for o in [lower_bound, upper_bound, step] + ] + + list(iter_args), + loc=loc, + ip=ip, + ) + ) + self.regions[0].blocks.append(self.operands[0].type, *results) + + @property + def body(self): + """Returns the body (block) of the loop.""" + return self.regions[0].blocks[0] + + @property + def induction_variable(self): + """Returns the induction variable of the loop.""" + return self.body.arguments[0] + + @property + def inner_iter_args(self): + """Returns the loop-carried arguments usable within the loop. 
+ + To obtain the loop-carried operands, use `iter_args`. + """ + return self.body.arguments[1:] + + +_IfOp = IfOp + + +@_ods_cext.register_operation(_Dialect, replace=True) +class IfOp(_IfOp): + """Specialization for the SCF if op class.""" + + def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None): + """Creates an SCF `if` operation. + + - `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed. + - `hasElse` determines whether the if operation has the else branch. + """ + operands = [] + operands.append(cond) + results = [] + results.extend(results_) + super(_IfOp, self).__init__( + self.build_generic( + regions=2, results=results, operands=operands, loc=loc, ip=ip + ) + ) + self.regions[0].blocks.append(*[]) + if hasElse: + self.regions[1].blocks.append(*[]) + + @property + def then_block(self): + """Returns the then block of the if operation.""" + return self.regions[0].blocks[0] + + @property + def else_block(self): + """Returns the else block of the if operation.""" + return self.regions[1].blocks[0] def for_( diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py index 26edf6b64..67248748e 100644 --- a/mlir/python/mlir/dialects/tensor.py +++ b/mlir/python/mlir/dialects/tensor.py @@ -3,3 +3,40 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._tensor_ops_gen import * +from ._tensor_ops_gen import _Dialect + +try: + from ..ir import * +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Sequence, Union +from ._ods_common import _cext as _ods_cext + + +@_ods_cext.register_operation(_Dialect, replace=True) +class EmptyOp(EmptyOp): + """Extends the tensor.empty op.""" + + def __init__( + self, + sizes: Sequence[Union[int, Value]], + element_type: Type, + *, + loc=None, + ip=None, + ): + """Constructs an `empty` with mixed static/dynamic sizes.""" + # TODO: Refactor the EmptyOp to take an 
element type attribute and + # then use normal result type inference, unifying the Python and C++ side + # with a standard mechanism (versus stashing that in builders). + dynamic_sizes = [] + static_sizes = [] + for s in sizes: + if isinstance(s, int): + static_sizes.append(s) + else: + static_sizes.append(ShapedType.get_dynamic_size()) + dynamic_sizes.append(s) + result_type = RankedTensorType.get(static_sizes, element_type) + super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index b020ad35f..f7a2026e8 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -4,4 +4,174 @@ from .._transform_enum_gen import * from .._transform_ops_gen import * +from .._transform_ops_gen import _Dialect from ..._mlir_libs._mlirDialectsTransform import * + +try: + from ...ir import * + from .._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + _cext as _ods_cext, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Sequence, Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class CastOp(CastOp): + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + loc=None, + ip=None, + ): + super().__init__(result_type, _get_op_result_or_value(target), loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ApplyPatternsOp(ApplyPatternsOp): + def __init__( + self, + target: Union[Operation, Value, OpView], + *, + loc=None, + ip=None, + ): + super().__init__(target, loc=loc, ip=ip) + self.regions[0].blocks.append() + + @property + def patterns(self) -> Block: + return self.regions[0].blocks[0] + + +@_ods_cext.register_operation(_Dialect, replace=True) +class GetParentOp(GetParentOp): 
+ def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + isolated_from_above: bool = False, + op_name: Optional[str] = None, + deduplicate: bool = False, + loc=None, + ip=None, + ): + super().__init__( + result_type, + _get_op_result_or_value(target), + isolated_from_above=isolated_from_above, + op_name=op_name, + deduplicate=deduplicate, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class MergeHandlesOp(MergeHandlesOp): + def __init__( + self, + handles: Sequence[Union[Operation, Value]], + *, + deduplicate: bool = False, + loc=None, + ip=None, + ): + super().__init__( + [_get_op_result_or_value(h) for h in handles], + deduplicate=deduplicate, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ReplicateOp(ReplicateOp): + def __init__( + self, + pattern: Union[Operation, Value], + handles: Sequence[Union[Operation, Value]], + *, + loc=None, + ip=None, + ): + super().__init__( + [_get_op_result_or_value(h).type for h in handles], + _get_op_result_or_value(pattern), + [_get_op_result_or_value(h) for h in handles], + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class SequenceOp(SequenceOp): + def __init__( + self, + failure_propagation_mode, + results: Sequence[Type], + target: Union[Operation, Value, Type], + extra_bindings: Optional[ + Union[Sequence[Value], Sequence[Type], Operation, OpView] + ] = None, + ): + root = ( + _get_op_result_or_value(target) + if isinstance(target, (Operation, Value)) + else None + ) + root_type = root.type if not isinstance(target, Type) else target + + if extra_bindings is None: + extra_bindings = [] + if isinstance(extra_bindings, (Operation, OpView)): + extra_bindings = _get_op_results_or_values(extra_bindings) + + extra_binding_types = [] + if len(extra_bindings) != 0: + if isinstance(extra_bindings[0], Type): + extra_binding_types = extra_bindings + extra_bindings = [] + else: + 
extra_binding_types = [v.type for v in extra_bindings] + + super().__init__( + results_=results, + failure_propagation_mode=failure_propagation_mode, + root=root, + extra_bindings=extra_bindings, + ) + self.regions[0].blocks.append(*tuple([root_type] + extra_binding_types)) + + @property + def body(self) -> Block: + return self.regions[0].blocks[0] + + @property + def bodyTarget(self) -> Value: + return self.body.arguments[0] + + @property + def bodyExtraArgs(self) -> BlockArgumentList: + return self.body.arguments[1:] + + +@_ods_cext.register_operation(_Dialect, replace=True) +class YieldOp(YieldOp): + def __init__( + self, + operands: Optional[Union[Operation, Sequence[Value]]] = None, + *, + loc=None, + ip=None, + ): + if operands is None: + operands = [] + super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/transform/bufferization.py b/mlir/python/mlir/dialects/transform/bufferization.py index eb77b746c..485a8a36b 100644 --- a/mlir/python/mlir/dialects/transform/bufferization.py +++ b/mlir/python/mlir/dialects/transform/bufferization.py @@ -3,3 +3,132 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._bufferization_transform_ops_gen import * +from .._bufferization_transform_ops_gen import _Dialect + +try: + from ...ir import * + from ...dialects import transform + from .._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from enum import Enum +from typing import Optional, overload, Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class EmptyTensorToAllocTensorOp(EmptyTensorToAllocTensorOp): + """Specialization for EmptyTensorToAllocTensorOp class.""" + + @overload + def __init__( + self, + transformed_type: Type, + target: Union[Operation, OpView, Value], + *, + loc=None, + ip=None, + ): + ... 
+ + @overload + def __init__(self, target: Union[Operation, OpView, Value], *, loc=None, ip=None): + ... + + def __init__( + self, + transformed_type_or_target: Type, + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + loc=None, + ip=None, + ): + if isinstance(transformed_type_or_target, Type): + transformed_type = transformed_type_or_target + target = target_or_none + else: + transformed_type = transform.OperationType.get("bufferization.alloc_tensor") + target = transformed_type_or_target + + super().__init__( + transformed_type, + target, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class OneShotBufferizeOp(OneShotBufferizeOp): + """Specialization for OneShotBufferizeOp class.""" + + @overload + def __init__( + self, + transformed_type: Type, + target: Union[Operation, OpView, Value], + *, + allow_return_allocs_from_loops: Optional[bool] = None, + allow_unknown_ops: Optional[bool] = None, + bufferize_function_boundaries: Optional[bool] = None, + function_boundary_type_conversion: Optional[Enum] = None, + memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + allow_return_allocs_from_loops: Optional[bool] = None, + allow_unknown_ops: Optional[bool] = None, + bufferize_function_boundaries: Optional[bool] = None, + function_boundary_type_conversion: Optional[Enum] = None, + memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, + loc=None, + ip=None, + ): + ... 
+ + def __init__( + self, + transformed_type_or_target: Type, + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + allow_return_allocs_from_loops: Optional[bool] = None, + allow_unknown_ops: Optional[bool] = None, + bufferize_function_boundaries: Optional[bool] = None, + function_boundary_type_conversion: Optional[Enum] = None, + memcpy_op: Optional[str] = None, + print_conflicts: Optional[bool] = None, + test_analysis_only: Optional[bool] = None, + loc=None, + ip=None, + ): + if isinstance(transformed_type_or_target, Type): + transformed_type = transformed_type_or_target + target = target_or_none + else: + transformed_type = transform.AnyOpType.get() + target = transformed_type_or_target + + super().__init__( + transformed_type, + target, + allow_return_allocs_from_loops=allow_return_allocs_from_loops, + allow_unknown_ops=allow_unknown_ops, + bufferize_function_boundaries=bufferize_function_boundaries, + function_boundary_type_conversion=function_boundary_type_conversion, + memcpy_op=memcpy_op, + print_conflicts=print_conflicts, + test_analysis_only=test_analysis_only, + loc=loc, + ip=ip, + ) diff --git a/mlir/python/mlir/dialects/transform/gpu.py b/mlir/python/mlir/dialects/transform/gpu.py index 8c3de0de7..00cf0840e 100644 --- a/mlir/python/mlir/dialects/transform/gpu.py +++ b/mlir/python/mlir/dialects/transform/gpu.py @@ -3,3 +3,128 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._gpu_transform_ops_gen import * +from .._gpu_transform_ops_gen import _Dialect + +try: + from ...ir import * + from ...dialects import transform + from .._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Sequence, Union, overload + + +@_ods_cext.register_operation(_Dialect, replace=True) +class MapForallToBlocks(MapForallToBlocks): + """Specialization for MapForallToBlocks class.""" + + @overload + def __init__( + self, + 
result_type: Type, + target: Union[Operation, OpView, Value], + *, + grid_dims: Optional[Union[Sequence[int], Attribute]] = None, + generate_gpu_launch: Optional[Union[bool, Attribute]] = None, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + grid_dims: Optional[Union[Sequence[int], Attribute]] = None, + generate_gpu_launch: Optional[Union[bool, Attribute]] = None, + loc=None, + ip=None, + ): + ... + + def __init__( + self, + result_type_or_target: Union[Operation, OpView, Type, Value], + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + grid_dims: Optional[Union[Sequence[int], Attribute]] = None, + generate_gpu_launch: Optional[Union[bool, Attribute]] = None, + loc=None, + ip=None, + ): + if isinstance(result_type_or_target, Type): + result_type = result_type_or_target + target = target_or_none + else: + result_type = transform.AnyOpType.get() + target = result_type_or_target + + super().__init__( + result_type, + target, + grid_dims=grid_dims, + generate_gpu_launch=generate_gpu_launch, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class MapNestedForallToThreads(MapNestedForallToThreads): + """Specialization for MapNestedForallToThreads class.""" + + @overload + def __init__( + self, + result_type: Type, + target: Union[Operation, OpView, Value], + *, + block_dims: Optional[Sequence[int]] = None, + warp_size: Optional[Sequence[int]] = None, + sync_after_distribute: Optional[bool] = None, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + block_dims: Optional[Sequence[int]] = None, + warp_size: Optional[Sequence[int]] = None, + sync_after_distribute: Optional[bool] = None, + loc=None, + ip=None, + ): + ... 
+ + def __init__( + self, + result_type_or_target: Union[Operation, OpView, Value, Type], + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + block_dims: Optional[Union[Sequence[int], Attribute]] = None, + warp_size: Optional[Union[Sequence[int], Attribute]] = None, + sync_after_distribute: Optional[bool] = None, + loc=None, + ip=None, + ): + if isinstance(result_type_or_target, Type): + result_type = result_type_or_target + target = target_or_none + else: + result_type = result_type_or_target.type + target = result_type_or_target + super().__init__( + result_type, + target, + block_dims=block_dims, + warp_size=warp_size, + sync_after_distribute=sync_after_distribute, + loc=loc, + ip=ip, + ) diff --git a/mlir/python/mlir/dialects/transform/loop.py b/mlir/python/mlir/dialects/transform/loop.py index 86f72788d..6c89025f4 100644 --- a/mlir/python/mlir/dialects/transform/loop.py +++ b/mlir/python/mlir/dialects/transform/loop.py @@ -3,3 +3,143 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._loop_transform_ops_gen import * +from .._loop_transform_ops_gen import _Dialect + +try: + from ...ir import * + from .._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + _cext as _ods_cext, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class GetParentForOp(GetParentForOp): + """Extension for GetParentForOp.""" + + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + num_loops: Optional[int] = None, + ip=None, + loc=None, + ): + if num_loops is None: + num_loops = 1 + super().__init__( + result_type, + _get_op_result_or_value(target), + num_loops=num_loops, + ip=ip, + loc=loc, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class LoopOutlineOp(LoopOutlineOp): + """Extension for LoopOutlineOp.""" + + def __init__( + 
self, + function_type: Type, + call_type: Type, + target: Union[Operation, Value], + *, + func_name: Union[str, StringAttr], + ip=None, + loc=None, + ): + super().__init__( + function_type, + call_type, + _get_op_result_or_value(target), + func_name=( + func_name + if isinstance(func_name, StringAttr) + else StringAttr.get(func_name) + ), + ip=ip, + loc=loc, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class LoopPeelOp(LoopPeelOp): + """Extension for LoopPeelOp.""" + + def __init__( + self, + main_loop_type: Type, + remainder_loop_type: Type, + target: Union[Operation, Value], + *, + fail_if_already_divisible: Union[bool, BoolAttr] = False, + ip=None, + loc=None, + ): + super().__init__( + main_loop_type, + remainder_loop_type, + _get_op_result_or_value(target), + fail_if_already_divisible=( + fail_if_already_divisible + if isinstance(fail_if_already_divisible, BoolAttr) + else BoolAttr.get(fail_if_already_divisible) + ), + ip=ip, + loc=loc, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class LoopPipelineOp(LoopPipelineOp): + """Extension for LoopPipelineOp.""" + + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + iteration_interval: Optional[Union[int, IntegerAttr]] = None, + read_latency: Optional[Union[int, IntegerAttr]] = None, + ip=None, + loc=None, + ): + if iteration_interval is None: + iteration_interval = 1 + if read_latency is None: + read_latency = 10 + super().__init__( + result_type, + _get_op_result_or_value(target), + iteration_interval=iteration_interval, + read_latency=read_latency, + ip=ip, + loc=loc, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class LoopUnrollOp(LoopUnrollOp): + """Extension for LoopUnrollOp.""" + + def __init__( + self, + target: Union[Operation, Value], + *, + factor: Union[int, IntegerAttr], + ip=None, + loc=None, + ): + super().__init__( + _get_op_result_or_value(target), + factor=factor, + ip=ip, + loc=loc, + ) diff --git 
a/mlir/python/mlir/dialects/transform/memref.py b/mlir/python/mlir/dialects/transform/memref.py index 1ff04ef6a..56ea61eb8 100644 --- a/mlir/python/mlir/dialects/transform/memref.py +++ b/mlir/python/mlir/dialects/transform/memref.py @@ -3,3 +3,118 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._memref_transform_ops_gen import * +from .._memref_transform_ops_gen import _Dialect + +try: + from ...ir import * + from ...dialects import transform + from .._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, overload, Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class MemRefAllocaToGlobalOp(MemRefAllocaToGlobalOp): + """Specialization for MemRefAllocaToGlobalOp class.""" + + @overload + def __init__( + self, + get_global_type: Type, + global_type: Type, + alloca: Union[Operation, OpView, Value], + *, + loc=None, + ip=None, + ): + ... + + @overload + def __init__(self, alloca: Union[Operation, OpView, Value], *, loc=None, ip=None): + ... 
+ + def __init__( + self, + get_global_type_or_alloca: Union[Operation, OpView, Type, Value], + global_type_or_none: Optional[Type] = None, + alloca_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + loc=None, + ip=None, + ): + if isinstance(get_global_type_or_alloca, Type): + get_global_type = get_global_type_or_alloca + global_type = global_type_or_none + alloca = alloca_or_none + else: + get_global_type = transform.AnyOpType.get() + global_type = transform.AnyOpType.get() + alloca = get_global_type_or_alloca + + super().__init__( + get_global_type, + global_type, + alloca, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class MemRefMultiBufferOp(MemRefMultiBufferOp): + """Specialization for MemRefMultiBufferOp class.""" + + @overload + def __init__( + self, + transformed_type: Type, + target: Union[Operation, OpView, Value], + factor: Union[int, IntegerAttr], + *, + skip_analysis: Optional[bool] = None, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + factor: Union[int, IntegerAttr], + *, + skip_analysis: Optional[bool] = None, + loc=None, + ip=None, + ): + ... 
+ + def __init__( + self, + transformed_type_or_target: Type, + target_or_factor: Union[int, IntegerAttr, Operation, OpView, Value] = None, + factor_or_none: Optional[Union[int, IntegerAttr]] = None, + *, + skip_analysis: Optional[bool] = None, + loc=None, + ip=None, + ): + if isinstance(transformed_type_or_target, Type): + transformed_type = transformed_type_or_target + target = target_or_factor + factor = factor_or_none + else: + transformed_type = transform.AnyOpType.get() + target = transformed_type_or_target + factor = target_or_factor + + super().__init__( + transformed_type, + target, + factor, + skip_analysis=skip_analysis, + loc=loc, + ip=ip, + ) diff --git a/mlir/python/mlir/dialects/transform/pdl.py b/mlir/python/mlir/dialects/transform/pdl.py index b1515287a..bb5fa7ffd 100644 --- a/mlir/python/mlir/dialects/transform/pdl.py +++ b/mlir/python/mlir/dialects/transform/pdl.py @@ -3,3 +3,53 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._transform_pdl_extension_ops_gen import * +from .._transform_pdl_extension_ops_gen import _Dialect + +try: + from ...ir import * + from .._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + _cext as _ods_cext, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class PDLMatchOp(PDLMatchOp): + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + pattern_name: Union[Attribute, str], + *, + loc=None, + ip=None, + ): + super().__init__( + result_type, + _get_op_result_or_value(target), + pattern_name, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class WithPDLPatternsOp(WithPDLPatternsOp): + def __init__(self, target: Union[Operation, Value, Type], *, loc=None, ip=None): + root = _get_op_result_or_value(target) if not 
isinstance(target, Type) else None + root_type = target if isinstance(target, Type) else root.type + super().__init__(root=root, loc=loc, ip=ip) + self.regions[0].blocks.append(root_type) + + @property + def body(self) -> Block: + return self.regions[0].blocks[0] + + @property + def bodyTarget(self) -> Value: + return self.body.arguments[0] diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py index cb3812301..284c93823 100644 --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -3,4 +3,777 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._structured_transform_ops_gen import * +from .._structured_transform_ops_gen import _Dialect from .._structured_transform_enum_gen import * + +try: + from ...ir import * + from ...dialects import transform + from .._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import List, Optional, Sequence, Tuple, Union, overload + +StaticIntLike = Union[int, IntegerAttr] +ValueLike = Union[Operation, OpView, Value] +MixedInt = Union[StaticIntLike, ValueLike] + +IntOrAttrList = Sequence[Union[IntegerAttr, int]] +OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]] + +BoolOrAttrList = Sequence[Union[BoolAttr, bool]] +OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]] + +MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike] + +DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]] + + +def _dispatch_dynamic_index_list( + indices: Union[DynamicIndexList, ArrayAttr], +) -> Tuple[List[ValueLike], Union[List[int], ArrayAttr], List[bool]]: + """Dispatches a list of indices to the appropriate form. 
+ + This is similar to the custom `DynamicIndexList` directive upstream: + provided indices may be in the form of dynamic SSA values or static values, + and they may be scalable (i.e., as a singleton list) or not. This function + dispatches each index into its respective form. It also extracts the SSA + values and static indices from various similar structures, respectively. + """ + dynamic_indices = [] + static_indices = [ShapedType.get_dynamic_size()] * len(indices) + scalable_indices = [False] * len(indices) + + # ArrayAttr: Extract index values. + if isinstance(indices, ArrayAttr): + indices = [idx for idx in indices] + + def process_nonscalable_index(i, index): + """Processes any form of non-scalable index. + + Returns False if the given index was scalable and thus remains + unprocessed; True otherwise. + """ + if isinstance(index, int): + static_indices[i] = index + elif isinstance(index, IntegerAttr): + static_indices[i] = index.value # pytype: disable=attribute-error + elif isinstance(index, (Operation, Value, OpView)): + dynamic_indices.append(index) + else: + return False + return True + + # Process each index at a time. + for i, index in enumerate(indices): + if not process_nonscalable_index(i, index): + # If it wasn't processed, it must be a scalable index, which is + # provided as a Sequence of one value, so extract and process that. 
+ scalable_indices[i] = True + assert len(index) == 1 + ret = process_nonscalable_index(i, index[0]) + assert ret + + return dynamic_indices, static_indices, scalable_indices + + +# Dispatches `MixedValues` that all represent integers in various forms into +# the following three categories: +# - `dynamic_values`: a list of `Value`s, potentially from op results; +# - `packed_values`: a value handle, potentially from an op result, associated +# to one or more payload operations of integer type; +# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python +# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`. +# If the input is in the form for `packed_values`, only that result is set and the +# other two are empty. Otherwise, the input can be a mix of the other two forms, +# and for each dynamic value, a special value is added to the `static_values`. +def _dispatch_mixed_values( + values: MixedValues, +) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]: + dynamic_values = [] + packed_values = None + static_values = None + if isinstance(values, ArrayAttr): + static_values = values + elif isinstance(values, (Operation, Value, OpView)): + packed_values = values + else: + static_values = [] + for size in values or []: + if isinstance(size, int): + static_values.append(size) + else: + static_values.append(ShapedType.get_dynamic_size()) + dynamic_values.append(size) + static_values = DenseI64ArrayAttr.get(static_values) + + return (dynamic_values, packed_values, static_values) + + +def _get_value_or_attribute_value( + value_or_attr: Union[any, Attribute, ArrayAttr] +) -> any: + if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"): + return value_or_attr.value + if isinstance(value_or_attr, ArrayAttr): + return _get_value_list(value_or_attr) + return value_or_attr + + +def _get_value_list( + sequence_or_array_attr: Union[Sequence[any], ArrayAttr] +) -> Sequence[any]: + return [_get_value_or_attribute_value(v)
for v in sequence_or_array_attr]
+
+
+def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr:
+    if values is None:
+        return None
+
+    # Turn into a Python list of Python ints.
+    values = _get_value_list(values)
+
+    # Make an ArrayAttr of IntegerAttrs out of it.
+    return ArrayAttr.get(
+        [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values]
+    )
+
+
+def _get_int_array_array_attr(
+    values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]]
+) -> ArrayAttr:
+    """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs.
+
+    The input has to be a collection of collections of integers, where any
+    Python Sequence and ArrayAttr are admissible collections and Python ints and
+    any IntegerAttr are admissible integers. Both levels of collections are
+    turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s.
+    If the input is None, None is returned.
+    """
+    if values is None:
+        return None
+
+    # Make sure the outer level is a list.
+    values = _get_value_list(values)
+
+    # The inner level is now either invalid or a mixed sequence of ArrayAttrs and
+    # Sequences. Make sure the nested values are all lists.
+    values = [_get_value_list(nested) for nested in values]
+
+    # Turn each nested list into an ArrayAttr.
+    values = [_get_int_array_attr(nested) for nested in values]
+
+    # Turn the outer list into an ArrayAttr.
+    return ArrayAttr.get(values)
+
+
+@_ods_cext.register_operation(_Dialect, replace=True)
+class BufferizeToAllocationOp(BufferizeToAllocationOp):
+    """Specialization for BufferizeToAllocationOp class."""
+
+    def __init__(
+        self,
+        target: Union[Operation, OpView, Value],
+        *,
+        memory_space: Optional[Union[int, str, Attribute]] = None,
+        memcpy_op: Optional[str] = None,
+        alloc_op: Optional[str] = None,
+        bufferize_destination_only: Optional[bool] = None,
+        loc=None,
+        ip=None,
+    ):
+        # No other types are allowed, so hard-code those here.
+ allocated_buffer_type = transform.AnyValueType.get() + new_ops_type = transform.AnyOpType.get() + + if isinstance(memory_space, int): + memory_space = str(memory_space) + if isinstance(memory_space, str): + memory_space = Attribute.parse(memory_space) + + super().__init__( + allocated_buffer_type, + new_ops_type, + target, + memory_space=memory_space, + memcpy_op=memcpy_op, + alloc_op=alloc_op, + bufferize_destination_only=bufferize_destination_only, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class DecomposeOp(DecomposeOp): + """Specialization for DecomposeOp class.""" + + def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): + transformed_type = transform.AnyOpType.get() + super().__init__(transformed_type, target, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class FuseIntoContainingOp(FuseIntoContainingOp): + """Specialization for FuseIntoContainingOp class.""" + + @overload + def __init__( + self, + fused_op_type: Type, + new_containing_op_type: Type, + producer_op: Union[Operation, OpView, Value], + containing_op: Union[Operation, OpView, Value], + *, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + producer_op: Union[Operation, OpView, Value], + containing_op: Union[Operation, OpView, Value], + *, + loc=None, + ip=None, + ): + ... 
+ + def __init__( + self, + fused_op_type_or_producer_op: Union[Operation, OpView, Type, Value], + new_containing_op_type_or_containing_op: Union[Operation, OpView, Type, Value], + producer_op_or_none: Optional[Union[Operation, OpView, Value]] = None, + containing_op_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + loc=None, + ip=None, + ): + if isinstance(fused_op_type_or_producer_op, Type): + if not isinstance(new_containing_op_type_or_containing_op, Type): + raise TypeError( + "If 'fused_op_type_or_producer_op' is a type, then " + "'new_containing_op_type_or_containing_op' is expected " + "to be one as well." + ) + fused_op_type = fused_op_type_or_producer_op + new_containing_op_type = new_containing_op_type_or_containing_op + producer_op = producer_op_or_none + containing_op = containing_op_or_none + else: + fused_op_type = transform.AnyOpType.get() + new_containing_op_type = transform.AnyOpType.get() + producer_op = fused_op_type_or_producer_op + containing_op = new_containing_op_type_or_containing_op + + super().__init__( + fused_op_type, + new_containing_op_type, + producer_op, + containing_op, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class GeneralizeOp(GeneralizeOp): + """Specialization for GeneralizeOp class.""" + + def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): + transformed_type = transform.AnyOpType.get() + super().__init__(transformed_type, target, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class InterchangeOp(InterchangeOp): + """Specialization for InterchangeOp class.""" + + def __init__( + self, + target: Union[Operation, Value], + *, + iterator_interchange: OptionalIntList = None, + loc=None, + ip=None, + ): + transformed_type = transform.AnyOpType.get() + super().__init__( + transformed_type, + target, + iterator_interchange=iterator_interchange, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) 
+class MapCopyToThreadsOp(MapCopyToThreadsOp): + """Specialization for MapCopyToThreadsOp class.""" + + @overload + def __init__( + self, + forall_op_type: Type, + tiled_op_type: Type, + target: Union[Operation, OpView, Value], + *, + total_num_threads: Union[int, IntegerAttr], + desired_bit_alignment: Union[int, IntegerAttr], + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + total_num_threads: Union[int, IntegerAttr], + desired_bit_alignment: Union[int, IntegerAttr], + loc=None, + ip=None, + ): + ... + + def __init__( + self, + forall_op_type_or_target: Union[Operation, OpView, Type, Value], + tiled_op_type_or_none: Optional[Type] = None, + target_or_none: Optional[Union[Operation, OpView, Value]] = None, + *, + total_num_threads: Union[int, IntegerAttr], + desired_bit_alignment: Union[int, IntegerAttr], + loc=None, + ip=None, + ): + if isinstance(forall_op_type_or_target, Type): + forall_op_type = forall_op_type_or_target + tiled_op_type = tiled_op_type_or_none + target = target_or_none + else: + forall_op_type = transform.AnyOpType.get() + tiled_op_type = transform.AnyOpType.get() + target = forall_op_type_or_target + + super().__init__( + forall_op_type, + tiled_op_type, + target, + total_num_threads=total_num_threads, + desired_bit_alignment=desired_bit_alignment, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class VectorizeOp(VectorizeOp): + """Specialization for VectorizeOp class.""" + + def __init__( + self, + target: Union[Operation, OpView, Value], + vector_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, + *, + vectorize_nd_extract: Optional[bool] = None, + scalable_sizes: OptionalBoolList = None, + static_vector_sizes: OptionalIntList = None, + loc=None, + ip=None, + ): + if ( + scalable_sizes is None + and static_vector_sizes is None + and vector_sizes is None + ): + dynamic_vector_sizes = [] + elif scalable_sizes is None and 
static_vector_sizes is None: + ( + dynamic_vector_sizes, + static_vector_sizes, + scalable_sizes, + ) = _dispatch_dynamic_index_list(vector_sizes) + elif scalable_sizes is None or static_vector_sizes is None: + raise TypeError( + "'scalable_sizes' and 'static_vector_sizes' must either both " + "be given explicitly or both be given as part of 'vector_sizes'." + ) + else: + dynamic_vector_sizes = vector_sizes + + super().__init__( + target, + vector_sizes=dynamic_vector_sizes, + static_vector_sizes=static_vector_sizes, + scalable_sizes=scalable_sizes, + vectorize_nd_extract=vectorize_nd_extract, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class MatchOp(MatchOp): + """Specialization for MatchOp class.""" + + @overload + @classmethod + def match_op_names( + cls, + target: Union[Operation, Value], + names: Union[str, Sequence[str]], + *, + loc=None, + ip=None, + ): + ... + + @overload + @classmethod + def match_op_names( + cls, + result_type: Type, + target: Union[Operation, Value], + names: Union[str, Sequence[str]], + *, + loc=None, + ip=None, + ): + ... 
+ + @classmethod + def match_op_names( + cls, + result_type_or_target: Union[Type, Operation, Value], + target_or_names: Union[Operation, Value, Sequence[str], str], + names_or_none: Optional[Union[Sequence[str], str]] = None, + *, + loc=None, + ip=None, + ): + if isinstance(result_type_or_target, Type): + result_type = result_type_or_target + target = target_or_names + names = names_or_none + else: + result_type = transform.AnyOpType.get() + target = result_type_or_target + names = target_or_names + + if isinstance(names, str): + names = [names] + + return cls( + result_type, + target, + ops=ArrayAttr.get(list(map(lambda s: StringAttr.get(s), names))), + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class MultiTileSizesOp(MultiTileSizesOp): + """Specialization for MultiTileSizesOp class.""" + + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + dimension: Union[int, IntegerAttr], + target_size: Union[int, IntegerAttr], + divisor: Optional[Optional[Union[int, IntegerAttr]]] = None, + loc=None, + ip=None, + ): + super().__init__( + result_type, + result_type, + result_type, + target, + dimension=dimension, + target_size=target_size, + divisor=divisor, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class PadOp(PadOp): + """Specialization for PadOp class.""" + + def __init__( + self, + target: Union[Operation, OpView, Value], + *, + padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, + padding_dimensions: OptionalIntList = None, + pad_to_multiple_of: OptionalIntList = None, + pack_paddings: OptionalIntList = None, + transpose_paddings: Optional[ + Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]] + ] = None, + copy_back_op: Optional[Union[str, StringAttr]] = None, + loc=None, + ip=None, + ): + transpose_paddings = _get_int_array_array_attr(transpose_paddings) + + any_op_type = transform.AnyOpType.get() + super().__init__( + 
any_op_type, + any_op_type, + any_op_type, + target, + padding_values=padding_values, + padding_dimensions=padding_dimensions, + pad_to_multiple_of=pad_to_multiple_of, + pack_paddings=pack_paddings, + transpose_paddings=transpose_paddings, + copy_back_op=copy_back_op, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ScalarizeOp(ScalarizeOp): + """Specialization for ScalarizeOp class.""" + + def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None): + result_type = transform.AnyOpType.get() + super().__init__(result_type, target, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class SplitOp(SplitOp): + """Specialization for SplitOp class.""" + + def __init__( + self, + target: Union[Operation, Value], + dimension: Union[int, Attribute], + split_point: Union[int, Operation, Value, Attribute], + *, + loc=None, + ip=None, + ): + if isinstance(split_point, int): + static_split_point = split_point + dynamic_split_point = None + else: + static_split_point = ShapedType.get_dynamic_size() + dynamic_split_point = split_point + + super().__init__( + target.type, + target.type, + target, + dimension=dimension, + static_split_point=static_split_point, + dynamic_split_point=dynamic_split_point, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class TileUsingForOp(TileUsingForOp): + """Specialization for TileUsingForOp class.""" + + @overload + def __init__( + self, + loop_types: Union[Type, List[Type]], + target: Union[Operation, Value], + *, + sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, Value, OpView], + *, + sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None, + ): + ... 
+ + def __init__( + self, + loop_types_or_target: Union[Type, List[Type], Operation, Value], + target_or_none: Optional[Union[Operation, Value, OpView]] = None, + *, + sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, + interchange: OptionalIntList = None, + loc=None, + ip=None, + ): + ( + dynamic_sizes, + static_sizes, + scalable_sizes, + ) = _dispatch_dynamic_index_list(sizes) + + num_loops = sum(v if v == 0 else 1 for v in static_sizes) + + if isinstance(loop_types_or_target, (Operation, Value, OpView)): + loop_types = [transform.AnyOpType.get()] * num_loops + target = loop_types_or_target + assert ( + target_or_none is None + ), "Cannot construct TileUsingForOp with two targets." + else: + loop_types = ( + ([loop_types_or_target] * num_loops) + if isinstance(loop_types_or_target, Type) + else loop_types_or_target + ) + target = target_or_none + + super().__init__( + target.type, + loop_types, + target, + dynamic_sizes=dynamic_sizes, + static_sizes=static_sizes, + interchange=interchange, + scalable_sizes=scalable_sizes, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class TileUsingForallOp(TileUsingForallOp): + """Specialization for TileUsingForallOp class.""" + + @overload + def __init__( + self, + loops_type: Type, + tiled_op_type: Type, + target: Union[Operation, Value, OpView], + *, + num_threads: Optional[MixedValues] = None, + tile_sizes: MixedValues = None, + mapping=None, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, Value, OpView], + *, + num_threads: Optional[MixedValues] = None, + tile_sizes: MixedValues = None, + mapping=None, + loc=None, + ip=None, + ): + ... 
+ + def __init__( + self, + loops_type_or_target: Union[ + Type, Union[Operation, Value, OpView] # loops_type + ], # target + tiled_op_type_or_none: Optional[Type] = None, + target_or_none: Optional[Union[Operation, Value, OpView]] = None, + *, + num_threads: MixedValues = None, + tile_sizes: MixedValues = None, + mapping=None, + loc=None, + ip=None, + ): + # `Type` arguments in the front are optional: add default values to front. + if isinstance(loops_type_or_target, Type): + # First overload: type arguments provided. + if not isinstance(tiled_op_type_or_none, Type): + raise TypeError( + "If 'loops_type_or_target' is a type, then " + "'tiled_op_type_or_none' is expected to be one as well." + ) + loops_type = loops_type_or_target + tiled_op_type = tiled_op_type_or_none + target = target_or_none + else: + # Last overload: type arguments missing. + loops_type = transform.AnyOpType.get() + tiled_op_type = transform.AnyOpType.get() + target = loops_type_or_target + + # Unpack mixed num_threads. + ( + dynamic_num_threads, + packed_num_threads, + num_threads_attr, + ) = _dispatch_mixed_values(num_threads) + + # Unpack mixed tile_sizes. 
+ ( + dynamic_tile_sizes, + packed_tile_sizes, + tile_sizes_attr, + ) = _dispatch_mixed_values(tile_sizes) + + super().__init__( + loops_type, + tiled_op_type, + target=target, + tile_sizes=dynamic_tile_sizes, + packed_tile_sizes=packed_tile_sizes, + static_tile_sizes=tile_sizes_attr, + num_threads=dynamic_num_threads, + packed_num_threads=packed_num_threads, + static_num_threads=num_threads_attr, + mapping=mapping, + loc=loc, + ip=ip, + ) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class VectorizeChildrenAndApplyPatternsOp(VectorizeChildrenAndApplyPatternsOp): + """Specialization for VectorizeChildrenAndApplyPatternsOp class.""" + + def __init__( + self, + target: Union[Operation, Value], + *, + disable_multi_reduction_to_contract_patterns: bool = False, + disable_transfer_permutation_map_lowering_patterns: bool = False, + vectorize_nd_extract: bool = False, + vectorize_padding: bool = False, + loc=None, + ip=None, + ): + transformed_type = transform.AnyOpType.get() + super().__init__( + transformed_type, + target, + disable_multi_reduction_to_contract_patterns=disable_multi_reduction_to_contract_patterns, + disable_transfer_permutation_map_lowering_patterns=disable_transfer_permutation_map_lowering_patterns, + vectorize_nd_extract=vectorize_nd_extract, + vectorize_padding=vectorize_padding, + loc=loc, + ip=ip, + ) diff --git a/mlir/python/mlir/dialects/transform/tensor.py b/mlir/python/mlir/dialects/transform/tensor.py index bf52255b3..4eb30398f 100644 --- a/mlir/python/mlir/dialects/transform/tensor.py +++ b/mlir/python/mlir/dialects/transform/tensor.py @@ -3,3 +3,67 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from .._tensor_transform_ops_gen import * +from .._tensor_transform_ops_gen import _Dialect + +try: + from ...ir import * + from ...dialects import transform + from .._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing 
import Optional, overload, Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class MakeLoopIndependentOp(MakeLoopIndependentOp): + """Specialization for MakeLoopIndependentOp class.""" + + @overload + def __init__( + self, + transformed_type: Type, + target: Union[Operation, OpView, Value], + num_loops: Union[int, IntegerAttr], + *, + loc=None, + ip=None, + ): + ... + + @overload + def __init__( + self, + target: Union[Operation, OpView, Value], + num_loops: Union[int, IntegerAttr], + *, + loc=None, + ip=None, + ): + ... + + def __init__( + self, + transformed_type_or_target: Type, + target_or_num_loops: Union[int, IntegerAttr, Operation, OpView, Value] = None, + num_loops_or_none: Optional[Union[int, IntegerAttr]] = None, + *, + loc=None, + ip=None, + ): + if isinstance(transformed_type_or_target, Type): + transformed_type = transformed_type_or_target + target = target_or_num_loops + num_loops = num_loops_or_none + else: + transformed_type = transform.AnyOpType.get() + target = transformed_type_or_target + num_loops = target_or_num_loops + + super().__init__( + transformed_type, + target, + num_loops, + loc=loc, + ip=ip, + ) diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py index 0a3b41104..f6b706f9b 100644 --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -114,6 +114,7 @@ def get_unranked_memref_descriptor(nparray): d.descriptor = ctypes.cast(ctypes.pointer(x), ctypes.c_void_p) return d + def move_aligned_ptr_by_offset(aligned_ptr, offset): """Moves the supplied ctypes pointer ahead by `offset` elements.""" aligned_addr = ctypes.addressof(aligned_ptr.contents) @@ -122,6 +123,7 @@ def move_aligned_ptr_by_offset(aligned_ptr, offset): content_ptr = ctypes.cast(aligned_addr + shift, type(aligned_ptr)) return content_ptr + def unranked_memref_to_numpy(unranked_memref, np_dtype): """Converts unranked memrefs to numpy arrays.""" ctp = as_ctype(np_dtype) @@ 
-139,10 +141,10 @@ def unranked_memref_to_numpy(unranked_memref, np_dtype): def ranked_memref_to_numpy(ranked_memref): """Converts ranked memrefs to numpy arrays.""" - content_ptr = move_aligned_ptr_by_offset(ranked_memref[0].aligned, ranked_memref[0].offset) - np_arr = np.ctypeslib.as_array( - content_ptr, shape=ranked_memref[0].shape + content_ptr = move_aligned_ptr_by_offset( + ranked_memref[0].aligned, ranked_memref[0].offset ) + np_arr = np.ctypeslib.as_array(content_ptr, shape=ranked_memref[0].shape) strided_arr = np.lib.stride_tricks.as_strided( np_arr, np.ctypeslib.as_array(ranked_memref[0].shape), From 4e568c443b3a6b0c8aef756200793813b72c3a77 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Thu, 19 Oct 2023 18:07:06 -0500 Subject: [PATCH 602/915] [mlir][python] simplify extensions (#69642) https://github.com/llvm/llvm-project/pull/68853 enabled a lot of nice cleanup. Note, I made sure each of the touched extensions had tests. --- mlir/python/mlir/dialects/affine.py | 45 -------------- mlir/python/mlir/dialects/bufferization.py | 36 ----------- mlir/python/mlir/dialects/func.py | 3 - mlir/python/mlir/dialects/memref.py | 38 ------------ mlir/python/mlir/dialects/pdl.py | 69 ---------------------- mlir/python/mlir/dialects/scf.py | 33 +++-------- 6 files changed, 8 insertions(+), 216 deletions(-) diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py index 1eaccfa73..80d3873e1 100644 --- a/mlir/python/mlir/dialects/affine.py +++ b/mlir/python/mlir/dialects/affine.py @@ -3,48 +3,3 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._affine_ops_gen import * -from ._affine_ops_gen import _Dialect - -try: - from ..ir import * - from ._ods_common import ( - get_op_result_or_value as _get_op_result_or_value, - get_op_results_or_values as _get_op_results_or_values, - _cext as _ods_cext, - ) -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import 
Optional, Sequence, Union - - -@_ods_cext.register_operation(_Dialect, replace=True) -class AffineStoreOp(AffineStoreOp): - """Specialization for the Affine store operation.""" - - def __init__( - self, - value: Union[Operation, OpView, Value], - memref: Union[Operation, OpView, Value], - map: AffineMap = None, - *, - map_operands=None, - loc=None, - ip=None, - ): - """Creates an affine store operation. - - - `value`: the value to store into the memref. - - `memref`: the buffer to store into. - - `map`: the affine map that maps the map_operands to the index of the - memref. - - `map_operands`: the list of arguments to substitute the dimensions, - then symbols in the affine map, in increasing order. - """ - map = map if map is not None else [] - map_operands = map_operands if map_operands is not None else [] - indicies = [_get_op_result_or_value(op) for op in map_operands] - _ods_successors = None - super().__init__( - value, memref, indicies, AffineMapAttr.get(map), loc=loc, ip=ip - ) diff --git a/mlir/python/mlir/dialects/bufferization.py b/mlir/python/mlir/dialects/bufferization.py index 0ce5448ac..759b6aa24 100644 --- a/mlir/python/mlir/dialects/bufferization.py +++ b/mlir/python/mlir/dialects/bufferization.py @@ -3,40 +3,4 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._bufferization_ops_gen import * -from ._bufferization_ops_gen import _Dialect from ._bufferization_enum_gen import * - -try: - from typing import Sequence, Union - from ..ir import * - from ._ods_common import get_default_loc_context, _cext as _ods_cext - - from typing import Any, List, Union -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - - -@_ods_cext.register_operation(_Dialect, replace=True) -class AllocTensorOp(AllocTensorOp): - """Extends the bufferization.alloc_tensor op.""" - - def __init__( - self, - tensor_type: Type, - dynamic_sizes: Sequence[Value], - copy: Value, - size_hint: Value, - escape: BoolAttr, - *, - 
loc=None, - ip=None, - ): - """Constructs an `alloc_tensor` with static and/or dynamic sizes.""" - super().__init__( - tensor_type, - dynamic_sizes, - copy=copy, - size_hint=size_hint, - loc=loc, - ip=ip, - ) diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py index 9c6c4c909..6599f67b7 100644 --- a/mlir/python/mlir/dialects/func.py +++ b/mlir/python/mlir/dialects/func.py @@ -26,9 +26,6 @@ class ConstantOp(ConstantOp): """Specialization for the constant op class.""" - def __init__(self, result: Type, value: Attribute, *, loc=None, ip=None): - super().__init__(result, value, loc=loc, ip=ip) - @property def type(self): return self.results[0].type diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py index 111ad2178..3afb6a70c 100644 --- a/mlir/python/mlir/dialects/memref.py +++ b/mlir/python/mlir/dialects/memref.py @@ -3,41 +3,3 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._memref_ops_gen import * -from ._memref_ops_gen import _Dialect - -try: - from ..ir import * - from ._ods_common import ( - get_op_result_or_value as _get_op_result_or_value, - get_op_results_or_values as _get_op_results_or_values, - _cext as _ods_cext, - ) -except ImportError as e: - raise RuntimeError("Error loading imports from extension module") from e - -from typing import Optional, Sequence, Union - - -@_ods_cext.register_operation(_Dialect, replace=True) -class LoadOp(LoadOp): - """Specialization for the MemRef load operation.""" - - def __init__( - self, - memref: Union[Operation, OpView, Value], - indices: Optional[Union[Operation, OpView, Sequence[Value]]] = None, - *, - loc=None, - ip=None, - ): - """Creates a memref load operation. - - Args: - memref: the buffer to load from. - indices: the list of subscripts, may be empty for zero-dimensional - buffers. - loc: user-visible location of the operation. - ip: insertion point. 
- """ - indices_resolved = [] if indices is None else _get_op_results_or_values(indices) - super().__init__(memref, indices_resolved, loc=loc, ip=ip) diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py index a8d9c56f4..90d7d7062 100644 --- a/mlir/python/mlir/dialects/pdl.py +++ b/mlir/python/mlir/dialects/pdl.py @@ -21,43 +21,6 @@ ) -@_ods_cext.register_operation(_Dialect, replace=True) -class ApplyNativeConstraintOp(ApplyNativeConstraintOp): - """Specialization for PDL apply native constraint op class.""" - - def __init__( - self, - name: Union[str, StringAttr], - args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - *, - loc=None, - ip=None, - ): - if args is None: - args = [] - args = _get_values(args) - super().__init__(name, args, loc=loc, ip=ip) - - -@_ods_cext.register_operation(_Dialect, replace=True) -class ApplyNativeRewriteOp(ApplyNativeRewriteOp): - """Specialization for PDL apply native rewrite op class.""" - - def __init__( - self, - results: Sequence[Type], - name: Union[str, StringAttr], - args: Optional[Sequence[Union[OpView, Operation, Value]]] = None, - *, - loc=None, - ip=None, - ): - if args is None: - args = [] - args = _get_values(args) - super().__init__(results, name, args, loc=loc, ip=ip) - - @_ods_cext.register_operation(_Dialect, replace=True) class AttributeOp(AttributeOp): """Specialization for PDL attribute op class.""" @@ -75,21 +38,6 @@ def __init__( super().__init__(result, valueType=valueType, value=value, loc=loc, ip=ip) -@_ods_cext.register_operation(_Dialect, replace=True) -class EraseOp(EraseOp): - """Specialization for PDL erase op class.""" - - def __init__( - self, - operation: Optional[Union[OpView, Operation, Value]] = None, - *, - loc=None, - ip=None, - ): - operation = _get_value(operation) - super().__init__(operation, loc=loc, ip=ip) - - @_ods_cext.register_operation(_Dialect, replace=True) class OperandOp(OperandOp): """Specialization for PDL operand op class.""" @@ -216,23 
+164,6 @@ def __init__( super().__init__(result, parent, index, loc=loc, ip=ip) -@_ods_cext.register_operation(_Dialect, replace=True) -class ResultsOp(ResultsOp): - """Specialization for PDL results op class.""" - - def __init__( - self, - result: Type, - parent: Union[OpView, Operation, Value], - index: Optional[Union[IntegerAttr, int]] = None, - *, - loc=None, - ip=None, - ): - parent = _get_value(parent) - super().__init__(result, parent, index=index, loc=loc, ip=ip) - - @_ods_cext.register_operation(_Dialect, replace=True) class RewriteOp(RewriteOp): """Specialization for PDL rewrite op class.""" diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py index 43ad9f4e2..71c80cab7 100644 --- a/mlir/python/mlir/dialects/scf.py +++ b/mlir/python/mlir/dialects/scf.py @@ -20,11 +20,8 @@ from typing import Optional, Sequence, Union -_ForOp = ForOp - - @_ods_cext.register_operation(_Dialect, replace=True) -class ForOp(_ForOp): +class ForOp(ForOp): """Specialization for the SCF for op class.""" def __init__( @@ -50,17 +47,8 @@ def __init__( iter_args = _get_op_results_or_values(iter_args) results = [arg.type for arg in iter_args] - super(_ForOp, self).__init__( - self.build_generic( - regions=1, - results=results, - operands=[ - _get_op_result_or_value(o) for o in [lower_bound, upper_bound, step] - ] - + list(iter_args), - loc=loc, - ip=ip, - ) + super().__init__( + results, lower_bound, upper_bound, step, iter_args, loc=loc, ip=ip ) self.regions[0].blocks.append(self.operands[0].type, *results) @@ -83,28 +71,23 @@ def inner_iter_args(self): return self.body.arguments[1:] -_IfOp = IfOp - - @_ods_cext.register_operation(_Dialect, replace=True) -class IfOp(_IfOp): +class IfOp(IfOp): """Specialization for the SCF if op class.""" - def __init__(self, cond, results_=[], *, hasElse=False, loc=None, ip=None): + def __init__(self, cond, results_=None, *, hasElse=False, loc=None, ip=None): """Creates an SCF `if` operation. 
- `cond` is a MLIR value of 'i1' type to determine which regions of code will be executed. - `hasElse` determines whether the if operation has the else branch. """ + if results_ is None: + results_ = [] operands = [] operands.append(cond) results = [] results.extend(results_) - super(_IfOp, self).__init__( - self.build_generic( - regions=2, results=results, operands=operands, loc=loc, ip=ip - ) - ) + super().__init__(results, cond) self.regions[0].blocks.append(*[]) if hasElse: self.regions[1].blocks.append(*[]) From 140e3009a04b500bf86f885f408a127a29a31843 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Fri, 20 Oct 2023 16:14:46 -0500 Subject: [PATCH 603/915] [mlir][linalg] regionBuilder for transpose, broadcast (#69742) Currently, `linalg.transpose` and `linalg.broadcast` can't be emitted through either the C API or the python bindings (which of course go through the C API). See https://discourse.llvm.org/t/how-to-build-linalg-transposeop-in-mlir-pybind/73989/10. The reason is even though they're named ops, there is no opdsl `@linalg_structured_op` for them and thus while they can be instantiated they cannot be passed to [`mlirLinalgFillBuiltinNamedOpRegion`](https://github.com/llvm/llvm-project/blob/9b231a28f6fbf6f7c1c41c40fec3489eb4dcb544/mlir/lib/CAPI/Dialect/Linalg.cpp#L18). I believe the issue is they both take a `IndexAttrDef` but `IndexAttrDef` cannot represent dynamic rank. Note, if I'm mistaken and there is a way to write the `@linalg_structured_op` let me know. The solution here simply implements the `regionBuilder` interface which is then picked up by [`LinalgDialect::addNamedOpBuilders`](https://github.com/llvm/llvm-project/blob/0c66d118e2bb286bd37dac75adaf58d6397fcd6e/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp#L116). Extension classes are added "by hand" that mirror the API of the `@linalg_structured_op`s. 
Note, the extension classes are added to to `dialects/linalg/__init__.py` instead of `dialects/linalg/opdsl/ops/core_named_ops.py` in order that they're not confused for opdsl generators/emitters. --- mlir/python/mlir/dialects/linalg/__init__.py | 48 ++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index 1353870ec..6e4cb1bd6 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -55,3 +55,51 @@ # TODO: guard against surprises and fail create Runtime Custom Ops with # the same name as existing Core Named Ops. from .opdsl.ops.core_named_ops import * +from .opdsl.lang.emitter import isa + +from ...ir import * +from .._ods_common import get_op_result_or_value as _get_op_result_or_value + + +def transpose( + input: Union[Operation, OpView, Sequence[Value]], + *, + outs: List[Union[Operation, OpView, Sequence[Value]]], + permutation: Union[DenseI64ArrayAttr, List[int]], +): + input = _get_op_result_or_value(input) + if len(outs) > 1: + raise ValueError(f"{outs=} must have length 1.") + init = _get_op_result_or_value(outs[0]) + result_types = [init.type] if isa(RankedTensorType, init.type) else [] + + op = TransposeOp( + result=result_types, + input=input, + init=init, + permutation=permutation, + ) + fill_builtin_region(op.operation) + return op + + +def broadcast( + input: Union[Operation, OpView, Sequence[Value]], + *, + outs: List[Union[Operation, OpView, Sequence[Value]]], + dimensions: Union[DenseI64ArrayAttr, List[int]], +): + input = _get_op_result_or_value(input) + if len(outs) > 1: + raise ValueError(f"{outs=} must have length 1.") + init = _get_op_result_or_value(outs[0]) + result_types = [init.type] if isa(RankedTensorType, init.type) else [] + + op = BroadcastOp( + result=result_types, + input=input, + init=init, + dimensions=dimensions, + ) + fill_builtin_region(op.operation) + return op From 
7e79c8fa39da1c46f0fd9f86acf413072b6d5792 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Fri, 20 Oct 2023 20:28:32 -0500 Subject: [PATCH 604/915] [MLIR][python bindings] invalidate ops after PassManager run (#69746) Fixes https://github.com/llvm/llvm-project/issues/69730 (also see https://reviews.llvm.org/D155543). There are two things outstanding (why I didn't land before): 1. add some C API tests for `mlirOperationWalk`; 2. potentially refactor how the invalidation in `run` works; the first version of the code looked like this: ```cpp if (invalidateOps) { auto *context = op.getOperation().getContext().get(); MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, void *userData) { PyMlirContext *context = static_cast(userData); context->setOperationInvalid(op); }; auto numRegions = mlirOperationGetNumRegions(op.getOperation().get()); for (int i = 0; i < numRegions; ++i) { MlirRegion region = mlirOperationGetRegion(op.getOperation().get(), i); for (MlirBlock block = mlirRegionGetFirstBlock(region); !mlirBlockIsNull(block); block = mlirBlockGetNextInRegion(block)) for (MlirOperation childOp = mlirBlockGetFirstOperation(block); !mlirOperationIsNull(childOp); childOp = mlirOperationGetNextInBlock(childOp)) mlirOperationWalk(childOp, invalidatingCallback, context, MlirWalkPostOrder); } } ``` This is verbose and ugly but it has the important benefit of not executing `mlirOperationEqual(rootOp->get(), op)` for every op underneath the root op. Supposing there's no desire for the slightly more efficient but highly convoluted approach, I can land this "posthaste". But, since we have eyes on this now, any suggestions or approaches (or needs/concerns) are welcome. 
--- mlir/include/mlir-c/IR.h | 19 ++++++++++++++- mlir/lib/Bindings/Python/IRCore.cpp | 5 ++++ mlir/lib/Bindings/Python/IRModule.h | 5 ++++ mlir/lib/Bindings/Python/Pass.cpp | 37 +++++++++++++++++++++++------ mlir/lib/CAPI/IR/IR.cpp | 15 ++++++++++++ 5 files changed, 73 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index e361f33a0..7b121d4df 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -73,7 +73,6 @@ DEFINE_C_API_STRUCT(MlirValue, const void); /// /// A named attribute is essentially a (name, attribute) pair where the name is /// a string. - struct MlirNamedAttribute { MlirIdentifier name; MlirAttribute attribute; @@ -698,6 +697,24 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, /// ownership is transferred to the block of the other operation. MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, MlirOperation other); + +/// Traversal order for operation walk. +typedef enum MlirWalkOrder { + MlirWalkPreOrder, + MlirWalkPostOrder +} MlirWalkOrder; + +/// Operation walker type. The handler is passed an (opaque) reference to an +/// operation a pointer to a `userData`. +typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData); + +/// Walks operation `op` in `walkOrder` and calls `callback` on that operation. +/// `*userData` is passed to the callback as well and can be used to tunnel some +/// some context or other data into the callback. +MLIR_CAPI_EXPORTED +void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, + void *userData, MlirWalkOrder walkOrder); + //===----------------------------------------------------------------------===// // Region API. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 389a4621c..a8ea1a381 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -635,6 +635,11 @@ size_t PyMlirContext::clearLiveOperations() { return numInvalidated; } +void PyMlirContext::setOperationInvalid(MlirOperation op) { + if (liveOperations.contains(op.ptr)) + liveOperations[op.ptr].second->setInvalid(); +} + size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } pybind11::object PyMlirContext::contextEnter() { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index c5412e735..262928857 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -209,6 +209,11 @@ class PyMlirContext { /// place. size_t clearLiveOperations(); + /// Sets an operation invalid. This is useful for when some non-bindings + /// code destroys the operation and the bindings need to made aware. For + /// example, in the case when pass manager is run. + void setOperationInvalid(MlirOperation op); + /// Gets the count of live modules associated with this context. /// Used for testing. 
size_t getLiveModuleCount(); diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index cdbfcfbc2..2175cea79 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -13,6 +13,7 @@ #include "mlir-c/Pass.h" namespace py = pybind11; +using namespace py::literals; using namespace mlir; using namespace mlir::python; @@ -63,8 +64,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { mlirStringRefCreate(anchorOp.data(), anchorOp.size())); return new PyPassManager(passManager); }), - py::arg("anchor_op") = py::str("any"), - py::arg("context") = py::none(), + "anchor_op"_a = py::str("any"), "context"_a = py::none(), "Create a new PassManager for the current (or provided) Context.") .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule) @@ -82,7 +82,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { [](PyPassManager &passManager, bool enable) { mlirPassManagerEnableVerifier(passManager.get(), enable); }, - py::arg("enable"), "Enable / disable verify-each.") + "enable"_a, "Enable / disable verify-each.") .def_static( "parse", [](const std::string &pipeline, DefaultingPyMlirContext context) { @@ -96,7 +96,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { throw py::value_error(std::string(errorMsg.join())); return new PyPassManager(passManager); }, - py::arg("pipeline"), py::arg("context") = py::none(), + "pipeline"_a, "context"_a = py::none(), "Parse a textual pass-pipeline and return a top-level PassManager " "that can be applied on a Module. Throw a ValueError if the pipeline " "can't be parsed") @@ -111,12 +111,35 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { if (mlirLogicalResultIsFailure(status)) throw py::value_error(std::string(errorMsg.join())); }, - py::arg("pipeline"), + "pipeline"_a, "Add textual pipeline elements to the pass manager. 
Throws a " "ValueError if the pipeline can't be parsed.") .def( "run", - [](PyPassManager &passManager, PyOperationBase &op) { + [](PyPassManager &passManager, PyOperationBase &op, + bool invalidateOps) { + if (invalidateOps) { + typedef struct { + PyOperation &rootOp; + bool rootSeen; + } callBackData; + callBackData data{op.getOperation(), false}; + // Mark all ops below the op that the passmanager will be rooted + // at (but not op itself - note the preorder) as invalid. + MlirOperationWalkCallback invalidatingCallback = + [](MlirOperation op, void *userData) { + callBackData *data = static_cast<callBackData *>(userData); + if (LLVM_LIKELY(data->rootSeen)) + data->rootOp.getOperation() + .getContext() + ->setOperationInvalid(op); + else + data->rootSeen = true; + }; + mlirOperationWalk(op.getOperation(), invalidatingCallback, + static_cast<void *>(&data), MlirWalkPreOrder); + } + // Actually run the pass manager. PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); MlirLogicalResult status = mlirPassManagerRunOnOp( passManager.get(), op.getOperation().get()); @@ -124,7 +147,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { throw MLIRError("Failure while executing pass pipeline", errors.take()); }, - py::arg("operation"), + "operation"_a, "invalidate_ops"_a = true, "Run the pass manager on the provided operation, raising an " "MLIRError on failure.") .def( diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index c1abbbe36..0a5151751 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -25,6 +25,7 @@ #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" +#include "mlir/IR/Visitors.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Parser/Parser.h" @@ -705,6 +706,20 @@ void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) { return unwrap(op)->moveBefore(unwrap(other)); } +void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, + void *userData, 
MlirWalkOrder walkOrder) { + switch (walkOrder) { + + case MlirWalkPreOrder: + unwrap(op)->walk<mlir::WalkOrder::PreOrder>( + [callback, userData](Operation *op) { callback(wrap(op), userData); }); + break; + case MlirWalkPostOrder: + unwrap(op)->walk( + [callback, userData](Operation *op) { callback(wrap(op), userData); }); + } +} + //===----------------------------------------------------------------------===// // Region API. //===----------------------------------------------------------------------===// From bd467589415044035f6aed1f94b9da908358ea71 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 25 Oct 2023 07:17:56 +0200 Subject: [PATCH 605/915] [mlir][python] Clear PyOperations instead of invalidating them. (#70044) `PyOperations` are Python-level handles to `Operation *` instances. When the latter are modified by C++, the former need to be invalidated. #69746 implements such invalidation mechanism by setting all `PyReferences` to `invalid`. However, that is not enough: they also need to be removed from the `liveOperations` map since other parts of the code (such as `PyOperation::createDetached`) assume that that map only contains valid refs. This is required to actually solve the issue in #69730. --- mlir/include/mlir-c/IR.h | 4 ++-- mlir/lib/Bindings/Python/IRCore.cpp | 29 ++++++++++++++++++++++++++--- mlir/lib/Bindings/Python/IRModule.h | 14 ++++++++++---- mlir/lib/Bindings/Python/Pass.cpp | 20 +------------------- 4 files changed, 39 insertions(+), 28 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 7b121d4df..5659230a0 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -705,12 +705,12 @@ typedef enum MlirWalkOrder { } MlirWalkOrder; /// Operation walker type. The handler is passed an (opaque) reference to an -/// operation a pointer to a `userData`. +/// operation and a pointer to a `userData`. 
typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData); /// Walks operation `op` in `walkOrder` and calls `callback` on that operation. /// `*userData` is passed to the callback as well and can be used to tunnel some -/// some context or other data into the callback. +/// context or other data into the callback. MLIR_CAPI_EXPORTED void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, void *userData, MlirWalkOrder walkOrder); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index a8ea1a381..7cfea31db 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -635,9 +635,32 @@ size_t PyMlirContext::clearLiveOperations() { return numInvalidated; } -void PyMlirContext::setOperationInvalid(MlirOperation op) { - if (liveOperations.contains(op.ptr)) - liveOperations[op.ptr].second->setInvalid(); +void PyMlirContext::clearOperation(MlirOperation op) { + auto it = liveOperations.find(op.ptr); + if (it != liveOperations.end()) { + it->second.second->setInvalid(); + liveOperations.erase(it); + } +} + +void PyMlirContext::clearOperationsInside(PyOperationBase &op) { + typedef struct { + PyOperation &rootOp; + bool rootSeen; + } callBackData; + callBackData data{op.getOperation(), false}; + // Mark all ops below the op that the passmanager will be rooted + // at (but not op itself - note the preorder) as invalid. 
+ MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, + void *userData) { + callBackData *data = static_cast<callBackData *>(userData); + if (LLVM_LIKELY(data->rootSeen)) + data->rootOp.getOperation().getContext()->clearOperation(op); + else + data->rootSeen = true; + }; + mlirOperationWalk(op.getOperation(), invalidatingCallback, + static_cast<void *>(&data), MlirWalkPreOrder); } size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 262928857..01ee4975d 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -37,6 +37,7 @@ class PyMlirContext; class DefaultingPyMlirContext; class PyModule; class PyOperation; +class PyOperationBase; class PyType; class PySymbolTable; class PyValue; @@ -209,10 +210,15 @@ class PyMlirContext { /// place. size_t clearLiveOperations(); - /// Sets an operation invalid. This is useful for when some non-bindings - /// code destroys the operation and the bindings need to made aware. For - /// example, in the case when pass manager is run. - void setOperationInvalid(MlirOperation op); + /// Removes an operation from the live operations map and sets it invalid. + /// This is useful for when some non-bindings code destroys the operation and + /// the bindings need to made aware. For example, in the case when pass + /// manager is run. + void clearOperation(MlirOperation op); + + /// Clears all operations nested inside the given op using + /// `clearOperation(MlirOperation)`. + void clearOperationsInside(PyOperationBase &op); /// Gets the count of live modules associated with this context. /// Used for testing. 
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 2175cea79..588a8e254 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -119,25 +119,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { [](PyPassManager &passManager, PyOperationBase &op, bool invalidateOps) { if (invalidateOps) { - typedef struct { - PyOperation &rootOp; - bool rootSeen; - } callBackData; - callBackData data{op.getOperation(), false}; - // Mark all ops below the op that the passmanager will be rooted - // at (but not op itself - note the preorder) as invalid. - MlirOperationWalkCallback invalidatingCallback = - [](MlirOperation op, void *userData) { - callBackData *data = static_cast(userData); - if (LLVM_LIKELY(data->rootSeen)) - data->rootOp.getOperation() - .getContext() - ->setOperationInvalid(op); - else - data->rootSeen = true; - }; - mlirOperationWalk(op.getOperation(), invalidatingCallback, - static_cast(&data), MlirWalkPreOrder); + op.getOperation().getContext()->clearOperationsInside(op); } // Actually run the pass manager. PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); From 0717653346641253653510b51599b0add2a1b66a Mon Sep 17 00:00:00 2001 From: bjacob Date: Wed, 25 Oct 2023 11:41:24 -0400 Subject: [PATCH 606/915] Add missing `linalg.batch_vecmat` named op (#70218) Linalg currently has these named ops: * `matmul` * `matvec` * `vecmat` * `batch_matmul` * `batch_matvec` But it does not have: * `batch_vecmat` This PRs adds that for consistency, and I have a short-term need for it ( https://github.com/openxla/iree/issues/15158 ), so not having this would cause some contortion on my end. 
--- .../linalg/opdsl/ops/core_named_ops.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 19734a80a..62b7da2ae 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -517,6 +517,24 @@ def batch_matvec( ) +@linalg_structured_op +def batch_vecmat( + A=TensorDef(T1, Batch, S.K), + B=TensorDef(T2, Batch, S.K, S.N), + C=TensorDef(U, Batch, S.N, output=True), +): + """Performs a batched matrix-vector multiplication. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + domain(D.b, D.n, D.k) + implements(ContractionOpInterface) + C[D.b, D.n] += TypeFn.cast_signed(U, A[D.b, D.k]) * TypeFn.cast_signed( + U, B[D.b, D.k, D.n] + ) + + @linalg_structured_op def dot(A=TensorDef(T1, S.M), B=TensorDef(T2, S.M), C=TensorDef(U, output=True)): """Performs a dot product of two vectors to a scalar result. From 8983367568131bb03b6b9b599d83b0ecd2af8a95 Mon Sep 17 00:00:00 2001 From: Christian Ulmann Date: Mon, 30 Oct 2023 12:50:37 +0100 Subject: [PATCH 607/915] [MLIR][LLVM] Change CAPI pointer factory to create opaque pointers (#70572) This commit changes the LLVM dialect's CAPI pointer getters to drop support for typed pointers. Typed pointers are deprecated and should no longer be generated. --- mlir/include/mlir-c/Dialect/LLVM.h | 2 +- mlir/lib/CAPI/Dialect/LLVM.cpp | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index ba98c33fd..72701a822 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -19,7 +19,7 @@ extern "C" { MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(LLVM, llvm); /// Creates an llvm.ptr type. 
-MLIR_CAPI_EXPORTED MlirType mlirLLVMPointerTypeGet(MlirType pointee, +MLIR_CAPI_EXPORTED MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace); /// Creates an llmv.void type. diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index d023bf5d6..b4405f7aa 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -16,8 +16,8 @@ using namespace mlir::LLVM; MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(LLVM, llvm, LLVMDialect) -MlirType mlirLLVMPointerTypeGet(MlirType pointee, unsigned addressSpace) { - return wrap(LLVMPointerType::get(unwrap(pointee), addressSpace)); +MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace) { + return wrap(LLVMPointerType::get(unwrap(ctx), addressSpace)); } MlirType mlirLLVMVoidTypeGet(MlirContext ctx) { From c078d4172b2d3566fe81cc75f7ad1fbe91aaf245 Mon Sep 17 00:00:00 2001 From: Jungwook Park Date: Mon, 30 Oct 2023 21:46:21 +0000 Subject: [PATCH 608/915] [mlir][python] Register LLVM translations in the RegisterEverything for python (#70428) Added missing register_translations in python to replicate the same in the C-API Cleaned up the current calls to register passes where the other calls are already embedded in the mlirRegisterAllPasses. 
found here, https://discourse.llvm.org/t/opencl-example/74187 --- mlir/lib/Bindings/Python/RegisterEverything.cpp | 9 +++------ mlir/python/mlir/_mlir_libs/__init__.py | 8 +++++++- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Bindings/Python/RegisterEverything.cpp b/mlir/lib/Bindings/Python/RegisterEverything.cpp index fed5c36a6..6b2f6b0a6 100644 --- a/mlir/lib/Bindings/Python/RegisterEverything.cpp +++ b/mlir/lib/Bindings/Python/RegisterEverything.cpp @@ -7,20 +7,17 @@ //===----------------------------------------------------------------------===// #include "mlir-c/RegisterEverything.h" -#include "mlir-c/Conversion.h" -#include "mlir-c/Transforms.h" - #include "mlir/Bindings/Python/PybindAdaptors.h" PYBIND11_MODULE(_mlirRegisterEverything, m) { - m.doc() = "MLIR All Upstream Dialects and Passes Registration"; + m.doc() = "MLIR All Upstream Dialects, Translations and Passes Registration"; m.def("register_dialects", [](MlirDialectRegistry registry) { mlirRegisterAllDialects(registry); }); + m.def("register_llvm_translations", + [](MlirContext context) { mlirRegisterAllLLVMTranslations(context); }); // Register all passes on load. mlirRegisterAllPasses(); - mlirRegisterConversionPasses(); - mlirRegisterTransformsPasses(); } diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index 03fcb1013..71c074bc9 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -83,7 +83,8 @@ def process_initializer_module(module_name): # If _mlirRegisterEverything is built, then include it as an initializer # module. - process_initializer_module("_mlirRegisterEverything") + if process_initializer_module("_mlirRegisterEverything"): + init_module = importlib.import_module(f"._mlirRegisterEverything", __name__) # Load all _site_initialize_{i} modules, where 'i' is a number starting # at 0. @@ -102,6 +103,11 @@ def __init__(self, *args, **kwargs): # all dialects. 
It is being done here in order to preserve existing # behavior. See: https://github.com/llvm/llvm-project/issues/56037 self.load_all_available_dialects() + if init_module: + logger.debug( + "Registering translations from initializer %r", init_module + ) + init_module.register_llvm_translations(self) ir.Context = Context From d8d6654db2cbb5e1c80274abd1e80da141d7894f Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 30 Oct 2023 20:22:27 -0500 Subject: [PATCH 609/915] [mlir][python] fix `replace=True` for `register_operation` and `register_type_caster` (#70264)

So turns out that none of the `replace=True` things actually work because of the map caches (except for `register_attribute_builder(replace=True)`, which doesn't use such a cache). This was hidden by a series of unfortunate events: 1. `register_type_caster` failure was hidden because it was the same `TestIntegerRankedTensorType` being replaced with itself (d'oh). 2. `register_operation` failure was hidden behind the "order of events" in the lifecycle of typical extension import/use. Since extensions are loaded/registered almost immediately after generated builders are registered, there is no opportunity for the `operationClassMapCache` to be populated (through e.g., `module.body.operations[2]` or `module.body.operations[2].opview` or something). Of course as soon as you actually do "late-bind/late-register" the extension, you see it's not successfully replacing the stale one in `operationClassMapCache`. I'll take this opportunity to propose we ditch the caches all together. I've been cargo-culting them but I really don't understand how they work. There's this comment above `operationClassMapCache` ```cpp /// Cache of operation name to external operation class object. This is /// maintained on lookup as a shadow of operationClassMap in order for repeat /// lookups of the classes to only incur the cost of one hashtable lookup. llvm::StringMap<pybind11::object> operationClassMapCache; ``` But I don't understand how that's true given that the canonical thing `operationClassMap` is already a map: ```cpp /// Map of full operation name to external operation class object. llvm::StringMap<pybind11::object> operationClassMap; ``` Maybe it wasn't always the case? Anyway things work now but it seems like an unnecessary layer of complexity for not much gain? But maybe I'm wrong. 
--- mlir/lib/Bindings/Python/IRModule.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index a1c8ab7a0..f8e22f7bb 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -82,6 +82,10 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, if (found && !found.is_none() && !replace) throw std::runtime_error("Type caster is already registered"); found = std::move(typeCaster); + const auto foundIt = typeCasterMapCache.find(mlirTypeID); + if (foundIt != typeCasterMapCache.end() && !foundIt->second.is_none()) { + typeCasterMapCache[mlirTypeID] = found; + } } void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, @@ -104,6 +108,10 @@ void PyGlobals::registerOperationImpl(const std::string &operationName, .str()); } found = std::move(pyClass); + auto foundIt = operationClassMapCache.find(operationName); + if (foundIt != operationClassMapCache.end() && !foundIt->second.is_none()) { + operationClassMapCache[operationName] = found; + } } std::optional From ba222c434b0671dd0bb34e2042347c62233c82d8 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 31 Oct 2023 18:36:40 +0900 Subject: [PATCH 610/915] [mlir][SCF] Use `transform.get_parent_op` instead of `transform.loop.get_parent_for` (#70757) Add a new attribute to `get_parent_op` to get the n-th parent. Remove `transform.loop.get_parent_for`, which is no longer needed. 
--- .../mlir/dialects/transform/__init__.py | 42 ++++++++++--------- mlir/python/mlir/dialects/transform/loop.py | 24 ----------- 2 files changed, 22 insertions(+), 44 deletions(-) diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index f7a2026e8..166c5c5ca 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -52,26 +52,28 @@ def patterns(self) -> Block: @_ods_cext.register_operation(_Dialect, replace=True) class GetParentOp(GetParentOp): - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - isolated_from_above: bool = False, - op_name: Optional[str] = None, - deduplicate: bool = False, - loc=None, - ip=None, - ): - super().__init__( - result_type, - _get_op_result_or_value(target), - isolated_from_above=isolated_from_above, - op_name=op_name, - deduplicate=deduplicate, - loc=loc, - ip=ip, - ) + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + isolated_from_above: bool = False, + op_name: Optional[str] = None, + deduplicate: bool = False, + nth_parent: int = 1, + loc=None, + ip=None, + ): + super().__init__( + result_type, + _get_op_result_or_value(target), + isolated_from_above=isolated_from_above, + op_name=op_name, + deduplicate=deduplicate, + nth_parent=nth_parent, + loc=loc, + ip=ip, + ) @_ods_cext.register_operation(_Dialect, replace=True) diff --git a/mlir/python/mlir/dialects/transform/loop.py b/mlir/python/mlir/dialects/transform/loop.py index 6c89025f4..3bdd9ca3b 100644 --- a/mlir/python/mlir/dialects/transform/loop.py +++ b/mlir/python/mlir/dialects/transform/loop.py @@ -17,30 +17,6 @@ from typing import Optional, Union -@_ods_cext.register_operation(_Dialect, replace=True) -class GetParentForOp(GetParentForOp): - """Extension for GetParentForOp.""" - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - num_loops: Optional[int] = 
None, - ip=None, - loc=None, - ): - if num_loops is None: - num_loops = 1 - super().__init__( - result_type, - _get_op_result_or_value(target), - num_loops=num_loops, - ip=ip, - loc=loc, - ) - - @_ods_cext.register_operation(_Dialect, replace=True) class LoopOutlineOp(LoopOutlineOp): """Extension for LoopOutlineOp.""" From ea69b729fd9e7f109a83601867bb6103cfb0fba8 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Tue, 31 Oct 2023 10:29:29 -0700 Subject: [PATCH 611/915] [mlir][python] Fix possible use of variable use before set The _mlirRegisterEverything symbol may not be built by some customers. The code here was intended to support this, but didn't properly initialize the init_module variable. This would break JAX with: NameError: free variable 'init_module' referenced before assignment in enclosing scope --- mlir/python/mlir/_mlir_libs/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index 71c074bc9..d5fc447e4 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -83,6 +83,7 @@ def process_initializer_module(module_name): # If _mlirRegisterEverything is built, then include it as an initializer # module. + init_module = None if process_initializer_module("_mlirRegisterEverything"): init_module = importlib.import_module(f"._mlirRegisterEverything", __name__) From 5e669bfd0cf5304c482f2177b1ac73f77cd6a0a6 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Tue, 31 Oct 2023 19:55:42 -0500 Subject: [PATCH 612/915] [mlir][python] fix python_test dialect and I32/I64ElementsBuilder (#70871) This PR fixes the `I32ElementsAttr` and `I64ElementsAttr` builders and tests them through the `python_test` dialect. 
--- mlir/python/mlir/ir.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 43553f311..cf4228c2a 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -277,7 +277,7 @@ def _f64ElementsAttr(x, context): def _i32ElementsAttr(x, context): return DenseElementsAttr.get( np.array(x, dtype=np.int32), - type=IntegerType.get_signed(32, context=context), + type=IntegerType.get_signless(32, context=context), context=context, ) @@ -285,7 +285,7 @@ def _i32ElementsAttr(x, context): def _i64ElementsAttr(x, context): return DenseElementsAttr.get( np.array(x, dtype=np.int64), - type=IntegerType.get_signed(64, context=context), + type=IntegerType.get_signless(64, context=context), context=context, ) From 43dae63646255ec3c2049bba917d733ed5a8547c Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Fri, 3 Nov 2023 13:28:20 -0500 Subject: [PATCH 613/915] [mlir][python] remove various caching mechanisms (#70831) This PR removes the various caching mechanisms currently in the python bindings - both positive caching and negative caching. 
--- mlir/lib/Bindings/Python/Globals.h | 24 +--- mlir/lib/Bindings/Python/IRModule.cpp | 131 +++++------------- mlir/lib/Bindings/Python/MainModule.cpp | 11 +- .../python/mlir/_mlir_libs/_mlir/__init__.pyi | 1 + 4 files changed, 54 insertions(+), 113 deletions(-) diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 21899bdce..976297257 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -9,10 +9,6 @@ #ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H #define MLIR_BINDINGS_PYTHON_GLOBALS_H -#include -#include -#include - #include "PybindUtils.h" #include "mlir-c/IR.h" @@ -21,6 +17,10 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" +#include +#include +#include + namespace mlir { namespace python { @@ -45,17 +45,13 @@ class PyGlobals { dialectSearchPrefixes.swap(newValues); } - /// Clears positive and negative caches regarding what implementations are - /// available. Future lookups will do more expensive existence checks. - void clearImportCache(); - /// Loads a python module corresponding to the given dialect namespace. /// No-ops if the module has already been loaded or is not found. Raises /// an error on any evaluation issues. /// Note that this returns void because it is expected that the module /// contains calls to decorators and helpers that register the salient - /// entities. - void loadDialectModule(llvm::StringRef dialectNamespace); + /// entities. Returns true if dialect is successfully loaded. + bool loadDialectModule(llvm::StringRef dialectNamespace); /// Adds a user-friendly Attribute builder. /// Raises an exception if the mapping already exists and replace == false. @@ -113,16 +109,10 @@ class PyGlobals { llvm::StringMap attributeBuilderMap; /// Map of MlirTypeID to custom type caster. llvm::DenseMap typeCasterMap; - /// Cache for map of MlirTypeID to custom type caster. 
- llvm::DenseMap typeCasterMapCache; /// Set of dialect namespaces that we have attempted to import implementation /// modules for. - llvm::StringSet<> loadedDialectModulesCache; - /// Cache of operation name to external operation class object. This is - /// maintained on lookup as a shadow of operationClassMap in order for repeat - /// lookups of the classes to only incur the cost of one hashtable lookup. - llvm::StringMap operationClassMapCache; + llvm::StringSet<> loadedDialectModules; }; } // namespace python diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index f8e22f7bb..6c5cde862 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -10,12 +10,12 @@ #include "Globals.h" #include "PybindUtils.h" -#include -#include - #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Support.h" +#include +#include + namespace py = pybind11; using namespace mlir; using namespace mlir::python; @@ -36,12 +36,12 @@ PyGlobals::PyGlobals() { PyGlobals::~PyGlobals() { instance = nullptr; } -void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { - if (loadedDialectModulesCache.contains(dialectNamespace)) - return; +bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { + if (loadedDialectModules.contains(dialectNamespace)) + return true; // Since re-entrancy is possible, make a copy of the search prefixes. std::vector localSearchPrefixes = dialectSearchPrefixes; - py::object loaded; + py::object loaded = py::none(); for (std::string moduleName : localSearchPrefixes) { moduleName.push_back('.'); moduleName.append(dialectNamespace.data(), dialectNamespace.size()); @@ -57,15 +57,18 @@ void PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { break; } + if (loaded.is_none()) + return false; // Note: Iterator cannot be shared from prior to loading, since re-entrancy // may have occurred, which may do anything. 
- loadedDialectModulesCache.insert(dialectNamespace); + loadedDialectModules.insert(dialectNamespace); + return true; } void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, py::function pyFunc, bool replace) { py::object &found = attributeBuilderMap[attributeKind]; - if (found && !found.is_none() && !replace) { + if (found && !replace) { throw std::runtime_error((llvm::Twine("Attribute builder for '") + attributeKind + "' is already registered with func: " + @@ -79,13 +82,10 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster, bool replace) { pybind11::object &found = typeCasterMap[mlirTypeID]; - if (found && !found.is_none() && !replace) - throw std::runtime_error("Type caster is already registered"); + if (found && !replace) + throw std::runtime_error("Type caster is already registered with caster: " + + py::str(found).operator std::string()); found = std::move(typeCaster); - const auto foundIt = typeCasterMapCache.find(mlirTypeID); - if (foundIt != typeCasterMapCache.end() && !foundIt->second.is_none()) { - typeCasterMapCache[mlirTypeID] = found; - } } void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, @@ -108,114 +108,59 @@ void PyGlobals::registerOperationImpl(const std::string &operationName, .str()); } found = std::move(pyClass); - auto foundIt = operationClassMapCache.find(operationName); - if (foundIt != operationClassMapCache.end() && !foundIt->second.is_none()) { - operationClassMapCache[operationName] = found; - } } std::optional PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { - // Fast match against the class map first (common case). 
const auto foundIt = attributeBuilderMap.find(attributeKind); if (foundIt != attributeBuilderMap.end()) { - if (foundIt->second.is_none()) - return std::nullopt; - assert(foundIt->second && "py::function is defined"); + assert(foundIt->second && "attribute builder is defined"); return foundIt->second; } - - // Not found and loading did not yield a registration. Negative cache. - attributeBuilderMap[attributeKind] = py::none(); return std::nullopt; } std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { - { - // Fast match against the class map first (common case). - const auto foundIt = typeCasterMapCache.find(mlirTypeID); - if (foundIt != typeCasterMapCache.end()) { - if (foundIt->second.is_none()) - return std::nullopt; - assert(foundIt->second && "py::function is defined"); - return foundIt->second; - } - } - - // Not found. Load the dialect namespace. - loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); - - // Attempt to find from the canonical map and cache. - { - const auto foundIt = typeCasterMap.find(mlirTypeID); - if (foundIt != typeCasterMap.end()) { - if (foundIt->second.is_none()) - return std::nullopt; - assert(foundIt->second && "py::object is defined"); - // Positive cache. - typeCasterMapCache[mlirTypeID] = foundIt->second; - return foundIt->second; - } - // Negative cache. - typeCasterMap[mlirTypeID] = py::none(); + // Make sure dialect module is loaded. + if (!loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)))) return std::nullopt; + + const auto foundIt = typeCasterMap.find(mlirTypeID); + if (foundIt != typeCasterMap.end()) { + assert(foundIt->second && "type caster is defined"); + return foundIt->second; } + return std::nullopt; } std::optional PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { - loadDialectModule(dialectNamespace); - // Fast match against the class map first (common case). + // Make sure dialect module is loaded. 
+ if (!loadDialectModule(dialectNamespace)) + return std::nullopt; const auto foundIt = dialectClassMap.find(dialectNamespace); if (foundIt != dialectClassMap.end()) { - if (foundIt->second.is_none()) - return std::nullopt; - assert(foundIt->second && "py::object is defined"); + assert(foundIt->second && "dialect class is defined"); return foundIt->second; } - - // Not found and loading did not yield a registration. Negative cache. - dialectClassMap[dialectNamespace] = py::none(); + // Not found and loading did not yield a registration. return std::nullopt; } std::optional PyGlobals::lookupOperationClass(llvm::StringRef operationName) { - { - auto foundIt = operationClassMapCache.find(operationName); - if (foundIt != operationClassMapCache.end()) { - if (foundIt->second.is_none()) - return std::nullopt; - assert(foundIt->second && "py::object is defined"); - return foundIt->second; - } - } - - // Not found. Load the dialect namespace. + // Make sure dialect module is loaded. auto split = operationName.split('.'); llvm::StringRef dialectNamespace = split.first; - loadDialectModule(dialectNamespace); - - // Attempt to find from the canonical map and cache. - { - auto foundIt = operationClassMap.find(operationName); - if (foundIt != operationClassMap.end()) { - if (foundIt->second.is_none()) - return std::nullopt; - assert(foundIt->second && "py::object is defined"); - // Positive cache. - operationClassMapCache[operationName] = foundIt->second; - return foundIt->second; - } - // Negative cache. 
- operationClassMap[operationName] = py::none(); + if (!loadDialectModule(dialectNamespace)) return std::nullopt; - } -} -void PyGlobals::clearImportCache() { - loadedDialectModulesCache.clear(); - operationClassMapCache.clear(); - typeCasterMapCache.clear(); + auto foundIt = operationClassMap.find(operationName); + if (foundIt != operationClassMap.end()) { + assert(foundIt->second && "OpView is defined"); + return foundIt->second; + } + // Not found and loading did not yield a registration. + return std::nullopt; } diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index a936becf6..2ba3a3677 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -6,14 +6,14 @@ // //===----------------------------------------------------------------------===// -#include - #include "PybindUtils.h" #include "Globals.h" #include "IRModule.h" #include "Pass.h" +#include + namespace py = pybind11; using namespace mlir; using namespace py::literals; @@ -34,9 +34,14 @@ PYBIND11_MODULE(_mlir, m) { "append_dialect_search_prefix", [](PyGlobals &self, std::string moduleName) { self.getDialectSearchPrefixes().push_back(std::move(moduleName)); - self.clearImportCache(); }, "module_name"_a) + .def( + "_check_dialect_module_loaded", + [](PyGlobals &self, const std::string &dialectNamespace) { + return self.loadDialectModule(dialectNamespace); + }, + "dialect_namespace"_a) .def("_register_dialect_impl", &PyGlobals::registerDialectImpl, "dialect_namespace"_a, "dialect_class"_a, "Testing hook for directly registering a dialect") diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi index 93b98c4aa..3ed1872f1 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi @@ -7,6 +7,7 @@ class _Globals: def _register_dialect_impl(self, dialect_namespace: str, dialect_class: type) -> None: ... 
def _register_operation_impl(self, operation_name: str, operation_class: type) -> None: ...
     def append_dialect_search_prefix(self, module_name: str) -> None: ...
+    def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ...
 
 def register_dialect(dialect_class: type) -> object: ...
 def register_operation(dialect_class: type) -> object: ...

From a5753e5cbe35d62a008156e0e61ef6c0f30e2a5a Mon Sep 17 00:00:00 2001
From: Jacques Pienaar
Date: Fri, 3 Nov 2023 16:29:03 -0700
Subject: [PATCH 614/915] [mlir][c] Add mlirOperationPrintWithState

Enable passing in MlirAsmState optionally (allow for passing in null) to
allow using the more efficient print calling API. The existing print
behavior results in a new AsmState being implicitly created by walking
the parent op and renumbering values. This makes the cost more explicit
and avoidable (by reusing an AsmState).
---
 mlir/include/mlir-c/IR.h | 7 +++++++
 mlir/lib/CAPI/IR/IR.cpp  | 8 ++++++++
 2 files changed, 15 insertions(+)

diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h
index 5659230a0..413eaa6aa 100644
--- a/mlir/include/mlir-c/IR.h
+++ b/mlir/include/mlir-c/IR.h
@@ -667,6 +667,13 @@ MLIR_CAPI_EXPORTED void mlirOperationPrintWithFlags(MlirOperation op,
                                                     MlirStringCallback callback,
                                                     void *userData);
 
+/// Same as mlirOperationPrint but accepts AsmState controlling the printing
+/// behavior as well as caching computed names.
+MLIR_CAPI_EXPORTED void mlirOperationPrintWithState(MlirOperation op,
+                                                    MlirAsmState state,
+                                                    MlirStringCallback callback,
+                                                    void *userData);
+
 /// Same as mlirOperationPrint but writing the bytecode format.
MLIR_CAPI_EXPORTED void
mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback,
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp
index 0a5151751..d1ee1b774 100644
--- a/mlir/lib/CAPI/IR/IR.cpp
+++ b/mlir/lib/CAPI/IR/IR.cpp
@@ -678,6 +678,15 @@ void mlirOperationPrintWithFlags(MlirOperation op, MlirOpPrintingFlags flags,
   unwrap(op)->print(stream, *unwrap(flags));
 }
 
+void mlirOperationPrintWithState(MlirOperation op, MlirAsmState state,
+                                 MlirStringCallback callback, void *userData) {
+  detail::CallbackOstream stream(callback, userData);
+  if (state.ptr)
+    unwrap(op)->print(stream, *unwrap(state));
+  else
+    unwrap(op)->print(stream);
+}
+
 void mlirOperationWriteBytecode(MlirOperation op, MlirStringCallback callback,
                                 void *userData) {
   detail::CallbackOstream stream(callback, userData);

From b43f4246daa55a07b390c2fee3259a88f1dbf2a5 Mon Sep 17 00:00:00 2001
From: "Oleksandr \"Alex\" Zinenko"
Date: Mon, 6 Nov 2023 13:14:56 +0100
Subject: [PATCH 615/915] [mlir] support scalable vectors in python bindings
 (#71050)

The scalable dimension functionality was added to the vector type after
the bindings for it were defined, without the bindings being ever
updated. Fix that.
---
 mlir/include/mlir-c/BuiltinTypes.h   | 26 +++++++++++
 mlir/lib/Bindings/Python/IRTypes.cpp | 69 ++++++++++++++++++++++------
 mlir/lib/CAPI/IR/BuiltinTypes.cpp    | 25 ++++++++++
 3 files changed, 107 insertions(+), 13 deletions(-)

diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h
index a6d8e10ef..1fd5691f4 100644
--- a/mlir/include/mlir-c/BuiltinTypes.h
+++ b/mlir/include/mlir-c/BuiltinTypes.h
@@ -271,6 +271,32 @@ MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetChecked(MlirLocation loc,
                                                      const int64_t *shape,
                                                      MlirType elementType);
 
+/// Creates a scalable vector type with the shape identified by its rank and
+/// dimensions.
A subset of dimensions may be marked as scalable via the +/// corresponding flag list, which is expected to have as many entries as the +/// rank of the vector. The vector is created in the same context as the element +/// type. +MLIR_CAPI_EXPORTED MlirType mlirVectorTypeGetScalable(intptr_t rank, + const int64_t *shape, + const bool *scalable, + MlirType elementType); + +/// Same as "mlirVectorTypeGetScalable" but returns a nullptr wrapping MlirType +/// on illegal arguments, emitting appropriate diagnostics. +MLIR_CAPI_EXPORTED +MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank, + const int64_t *shape, + const bool *scalable, + MlirType elementType); + +/// Checks whether the given vector type is scalable, i.e., has at least one +/// scalable dimension. +MLIR_CAPI_EXPORTED bool mlirVectorTypeIsScalable(MlirType type); + +/// Checks whether the "dim"-th dimension of the given vector is scalable. +MLIR_CAPI_EXPORTED bool mlirVectorTypeIsDimScalable(MlirType type, + intptr_t dim); + //===----------------------------------------------------------------------===// // Ranked / Unranked Tensor type. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index a7ccfbea5..483db673f 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -462,19 +462,62 @@ class PyVectorType : public PyConcreteType { using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector shape, PyType &elementType, - DefaultingPyLocation loc) { - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirType t = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), - elementType); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyVectorType(elementType.getContext(), t); - }, - py::arg("shape"), py::arg("elementType"), py::arg("loc") = py::none(), - "Create a vector type"); + c.def_static("get", &PyVectorType::get, py::arg("shape"), + py::arg("elementType"), py::kw_only(), + py::arg("scalable") = py::none(), + py::arg("scalable_dims") = py::none(), + py::arg("loc") = py::none(), "Create a vector type") + .def_property_readonly( + "scalable", + [](MlirType self) { return mlirVectorTypeIsScalable(self); }) + .def_property_readonly("scalable_dims", [](MlirType self) { + std::vector scalableDims; + size_t rank = static_cast(mlirShapedTypeGetRank(self)); + scalableDims.reserve(rank); + for (size_t i = 0; i < rank; ++i) + scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i)); + return scalableDims; + }); + } + +private: + static PyVectorType get(std::vector shape, PyType &elementType, + std::optional scalable, + std::optional> scalableDims, + DefaultingPyLocation loc) { + if (scalable && scalableDims) { + throw py::value_error("'scalable' and 'scalable_dims' kwargs " + "are mutually exclusive."); + } + + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirType type; + if (scalable) { + if (scalable->size() != shape.size()) + throw 
py::value_error("Expected len(scalable) == len(shape)."); + + SmallVector scalableDimFlags = llvm::to_vector(llvm::map_range( + *scalable, [](const py::handle &h) { return h.cast(); })); + type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), + scalableDimFlags.data(), + elementType); + } else if (scalableDims) { + SmallVector scalableDimFlags(shape.size(), false); + for (int64_t dim : *scalableDims) { + if (static_cast(dim) >= scalableDimFlags.size() || dim < 0) + throw py::value_error("Scalable dimension index out of bounds."); + scalableDimFlags[dim] = true; + } + type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), + scalableDimFlags.data(), + elementType); + } else { + type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), + elementType); + } + if (mlirTypeIsNull(type)) + throw MLIRError("Invalid type", errors.take()); + return PyVectorType(elementType.getContext(), type); } }; diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 50266b4b5..6e645188d 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -281,6 +281,31 @@ MlirType mlirVectorTypeGetChecked(MlirLocation loc, intptr_t rank, unwrap(elementType))); } +MlirType mlirVectorTypeGetScalable(intptr_t rank, const int64_t *shape, + const bool *scalable, MlirType elementType) { + return wrap(VectorType::get( + llvm::ArrayRef(shape, static_cast(rank)), unwrap(elementType), + llvm::ArrayRef(scalable, static_cast(rank)))); +} + +MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank, + const int64_t *shape, + const bool *scalable, + MlirType elementType) { + return wrap(VectorType::getChecked( + unwrap(loc), llvm::ArrayRef(shape, static_cast(rank)), + unwrap(elementType), + llvm::ArrayRef(scalable, static_cast(rank)))); +} + +bool mlirVectorTypeIsScalable(MlirType type) { + return unwrap(type).cast().isScalable(); +} + +bool mlirVectorTypeIsDimScalable(MlirType type, 
intptr_t dim) { + return unwrap(type).cast().getScalableDims()[dim]; +} + //===----------------------------------------------------------------------===// // Ranked / Unranked tensor type. //===----------------------------------------------------------------------===// From 75ec7008e442f86f6b2f9966169b80828ae93a5b Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Tue, 7 Nov 2023 10:49:41 -0600 Subject: [PATCH 616/915] [mlir][python] value casting (#69644) This PR adds "value casting", i.e., a mechanism to wrap `ir.Value` in a proxy class that overloads dunders such as `__add__`, `__sub__`, and `__mul__` for fun and great profit. This is thematically similar to https://github.com/llvm/llvm-project/commit/ac4ebc4fc044f0caac4602362ac978d46b9026dc and https://github.com/llvm/llvm-project/commit/67a748b04e17a2cf4d9af967d8b27f74e1c71efc. The example in the test demonstrates the value of the feature (no pun intended): ```python @register_value_caster(F16Type.static_typeid) @register_value_caster(F32Type.static_typeid) @register_value_caster(F64Type.static_typeid) @register_value_caster(IntegerType.static_typeid) class ArithValue(Value): __add__ = partialmethod(_binary_op, op="add") __sub__ = partialmethod(_binary_op, op="sub") __mul__ = partialmethod(_binary_op, op="mul") a = arith.constant(value=FloatAttr.get(f16_t, 42.42)) b = a + a # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16) print(b) a = arith.constant(value=FloatAttr.get(f32_t, 42.42)) b = a - a # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32) print(b) a = arith.constant(value=FloatAttr.get(f64_t, 42.42)) b = a * a # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64) print(b) ``` **EDIT**: this now goes through the bindings and thus supports automatic casting of `OpResult` (including as an element of `OpResultList`), `BlockArgument` (including as an element of `BlockArgumentList`), as well as `Value`. 
--- mlir/include/mlir-c/Bindings/Python/Interop.h | 23 +++++++++++--- .../mlir/Bindings/Python/PybindAdaptors.h | 10 +++--- mlir/lib/Bindings/Python/Globals.h | 14 ++++++++- mlir/lib/Bindings/Python/IRCore.cpp | 31 ++++++++++++++++--- mlir/lib/Bindings/Python/IRModule.cpp | 21 +++++++++++++ mlir/lib/Bindings/Python/IRModule.h | 14 ++++++--- mlir/lib/Bindings/Python/MainModule.cpp | 30 +++++++++++++----- mlir/lib/Bindings/Python/PybindUtils.h | 15 +++++++-- mlir/python/mlir/dialects/_ods_common.py | 13 +++++++- mlir/python/mlir/ir.py | 2 +- 10 files changed, 142 insertions(+), 31 deletions(-) diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h index f79c10cb9..0a36e97c2 100644 --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -118,13 +118,28 @@ /** Attribute on main C extension module (_mlir) that corresponds to the * type caster registration binding. The signature of the function is: - * def register_type_caster(MlirTypeID mlirTypeID, py::function typeCaster, - * bool replace) - * where replace indicates the typeCaster should replace any existing registered - * type casters (such as those for upstream ConcreteTypes). + * def register_type_caster(MlirTypeID mlirTypeID, *, bool replace) + * which then takes a typeCaster (register_type_caster is meant to be used as a + * decorator from python), and where replace indicates the typeCaster should + * replace any existing registered type casters (such as those for upstream + * ConcreteTypes). The interface of the typeCaster is: def type_caster(ir.Type) + * -> SubClassTypeT where SubClassTypeT indicates the result should be a + * subclass (inherit from) ir.Type. */ #define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR "register_type_caster" +/** Attribute on main C extension module (_mlir) that corresponds to the + * value caster registration binding. 
The signature of the function is: + * def register_value_caster(MlirTypeID mlirTypeID, *, bool replace) + * which then takes a valueCaster (register_value_caster is meant to be used as + * a decorator, from python), and where replace indicates the valueCaster should + * replace any existing registered value casters. The interface of the + * valueCaster is: def value_caster(ir.Value) -> SubClassValueT where + * SubClassValueT indicates the result should be a subclass (inherit from) + * ir.Value. + */ +#define MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR "register_value_caster" + /// Gets a void* from a wrapped struct. Needed because const cast is different /// between C/C++. #ifdef __cplusplus diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 49680c8b7..5e0e56fc0 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -234,6 +234,7 @@ struct type_caster { return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr("Value") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() .release(); }; }; @@ -496,11 +497,10 @@ class mlir_type_subclass : public pure_subclass { if (getTypeIDFunction) { py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( - getTypeIDFunction(), - pybind11::cpp_function( - [thisClass = thisClass](const py::object &mlirType) { - return thisClass(mlirType); - })); + getTypeIDFunction())(pybind11::cpp_function( + [thisClass = thisClass](const py::object &mlirType) { + return thisClass(mlirType); + })); } } }; diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 976297257..a022067f5 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -66,6 +66,13 @@ class PyGlobals { void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster, bool 
replace = false); + /// Adds a user-friendly value caster. Raises an exception if the mapping + /// already exists and replace == false. This is intended to be called by + /// implementation code. + void registerValueCaster(MlirTypeID mlirTypeID, + pybind11::function valueCaster, + bool replace = false); + /// Adds a concrete implementation dialect class. /// Raises an exception if the mapping already exists. /// This is intended to be called by implementation code. @@ -86,6 +93,10 @@ class PyGlobals { std::optional lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect); + /// Returns the custom value caster for MlirTypeID mlirTypeID. + std::optional lookupValueCaster(MlirTypeID mlirTypeID, + MlirDialect dialect); + /// Looks up a registered dialect class by namespace. Note that this may /// trigger loading of the defining module and can arbitrarily re-enter. std::optional @@ -109,7 +120,8 @@ class PyGlobals { llvm::StringMap attributeBuilderMap; /// Map of MlirTypeID to custom type caster. llvm::DenseMap typeCasterMap; - + /// Map of MlirTypeID to custom value caster. + llvm::DenseMap valueCasterMap; /// Set of dialect namespaces that we have attempted to import implementation /// modules for. llvm::StringSet<> loadedDialectModules; diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 7cfea31db..0f2ca666c 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1899,13 +1899,28 @@ bool PyTypeID::operator==(const PyTypeID &other) const { } //------------------------------------------------------------------------------ -// PyValue and subclases. +// PyValue and subclasses. 
//------------------------------------------------------------------------------ pybind11::object PyValue::getCapsule() { return py::reinterpret_steal(mlirPythonValueToCapsule(get())); } +pybind11::object PyValue::maybeDownCast() { + MlirType type = mlirValueGetType(get()); + MlirTypeID mlirTypeID = mlirTypeGetTypeID(type); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + std::optional valueCaster = + PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type)); + // py::return_value_policy::move means use std::move to move the return value + // contents into a new instance that will be owned by Python. + py::object thisObj = py::cast(this, py::return_value_policy::move); + if (!valueCaster) + return thisObj; + return valueCaster.value()(thisObj); +} + PyValue PyValue::createFromCapsule(pybind11::object capsule) { MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); if (mlirValueIsNull(value)) @@ -2121,6 +2136,8 @@ class PyConcreteValue : public PyValue { return DerivedTy::isaFunction(otherValue); }, py::arg("other_value")); + cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](DerivedTy &self) { return self.maybeDownCast(); }); DerivedTy::bindDerived(cls); } @@ -2193,6 +2210,7 @@ class PyBlockArgumentList : public Sliceable { public: static constexpr const char *pyClassName = "BlockArgumentList"; + using SliceableT = Sliceable; PyBlockArgumentList(PyOperationRef operation, MlirBlock block, intptr_t startIndex = 0, intptr_t length = -1, @@ -2241,6 +2259,7 @@ class PyBlockArgumentList class PyOpOperandList : public Sliceable { public: static constexpr const char *pyClassName = "OpOperandList"; + using SliceableT = Sliceable; PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, intptr_t length = -1, intptr_t step = 1) @@ -2296,6 +2315,7 @@ class PyOpOperandList : public Sliceable { class PyOpResultList : public Sliceable { public: static constexpr const char *pyClassName = "OpResultList"; + using 
SliceableT = Sliceable; PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, intptr_t length = -1, intptr_t step = 1) @@ -2303,7 +2323,7 @@ class PyOpResultList : public Sliceable { length == -1 ? mlirOperationGetNumResults(operation->get()) : length, step), - operation(operation) {} + operation(std::move(operation)) {} static void bindDerived(ClassTy &c) { c.def_property_readonly("types", [](PyOpResultList &self) { @@ -2892,7 +2912,8 @@ void mlir::python::populateIRCore(py::module &m) { .str()); } return PyOpResult(operation.getRef(), - mlirOperationGetResult(operation, 0)); + mlirOperationGetResult(operation, 0)) + .maybeDownCast(); }, "Shortcut to get an op result if it has only one (throws an error " "otherwise).") @@ -3566,7 +3587,9 @@ void mlir::python::populateIRCore(py::module &m) { [](PyValue &self, PyValue &with) { mlirValueReplaceAllUsesOfWith(self.get(), with.get()); }, - kValueReplaceAllUsesWithDocstring); + kValueReplaceAllUsesWithDocstring) + .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](PyValue &self) { return self.maybeDownCast(); }); PyBlockArgument::bind(m); PyOpResult::bind(m); PyOpOperand::bind(m); diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index 6c5cde862..5538924d2 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -88,6 +88,16 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, found = std::move(typeCaster); } +void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, + pybind11::function valueCaster, + bool replace) { + pybind11::object &found = valueCasterMap[mlirTypeID]; + if (found && !replace) + throw std::runtime_error("Value caster is already registered: " + + py::repr(found).cast()); + found = std::move(valueCaster); +} + void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, py::object pyClass) { py::object &found = dialectClassMap[dialectNamespace]; @@ -134,6 +144,17 @@ std::optional 
PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, return std::nullopt; } +std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, + MlirDialect dialect) { + loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); + const auto foundIt = valueCasterMap.find(mlirTypeID); + if (foundIt != valueCasterMap.end()) { + assert(foundIt->second && "value caster is defined"); + return foundIt->second; + } + return std::nullopt; +} + std::optional PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { // Make sure dialect module is loaded. diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 01ee4975d..af55693f1 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -761,7 +761,7 @@ class PyRegion { /// Wrapper around an MlirAsmState. class PyAsmState { - public: +public: PyAsmState(MlirValue value, bool useLocalScope) { flags = mlirOpPrintingFlagsCreate(); // The OpPrintingFlags are not exposed Python side, create locally and @@ -780,16 +780,14 @@ class PyAsmState { state = mlirAsmStateCreateForOperation(operation.getOperation().get(), flags); } - ~PyAsmState() { - mlirOpPrintingFlagsDestroy(flags); - } + ~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); } // Delete copy constructors. PyAsmState(PyAsmState &other) = delete; PyAsmState(const PyAsmState &other) = delete; MlirAsmState get() { return state; } - private: +private: MlirAsmState state; MlirOpPrintingFlags flags; }; @@ -1112,6 +1110,10 @@ class PyConcreteAttribute : public BaseTy { /// bindings so such operation always exists). class PyValue { public: + // The virtual here is "load bearing" in that it enables RTTI + // for PyConcreteValue CRTP classes that support maybeDownCast. + // See PyValue::maybeDownCast. 
+ virtual ~PyValue() = default; PyValue(PyOperationRef parentOperation, MlirValue value) : parentOperation(std::move(parentOperation)), value(value) {} operator MlirValue() const { return value; } @@ -1124,6 +1126,8 @@ class PyValue { /// Gets a capsule wrapping the void* within the MlirValue. pybind11::object getCapsule(); + pybind11::object maybeDownCast(); + /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of /// the underlying MlirValue is still tied to the owning operation. static PyValue createFromCapsule(pybind11::object capsule); diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 2ba3a3677..17272472c 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -12,8 +12,6 @@ #include "IRModule.h" #include "Pass.h" -#include - namespace py = pybind11; using namespace mlir; using namespace py::literals; @@ -46,7 +44,8 @@ PYBIND11_MODULE(_mlir, m) { "dialect_namespace"_a, "dialect_class"_a, "Testing hook for directly registering a dialect") .def("_register_operation_impl", &PyGlobals::registerOperationImpl, - "operation_name"_a, "operation_class"_a, "replace"_a = false, + "operation_name"_a, "operation_class"_a, py::kw_only(), + "replace"_a = false, "Testing hook for directly registering an operation"); // Aside from making the globals accessible to python, having python manage @@ -82,17 +81,32 @@ PYBIND11_MODULE(_mlir, m) { return opClass; }); }, - "dialect_class"_a, "replace"_a = false, + "dialect_class"_a, py::kw_only(), "replace"_a = false, "Produce a class decorator for registering an Operation class as part of " "a dialect"); m.def( MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR, - [](MlirTypeID mlirTypeID, py::function typeCaster, bool replace) { - PyGlobals::get().registerTypeCaster(mlirTypeID, std::move(typeCaster), - replace); + [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function { + return py::cpp_function([mlirTypeID, + 
replace](py::object typeCaster) -> py::object { + PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace); + return typeCaster; + }); }, - "typeid"_a, "type_caster"_a, "replace"_a = false, + "typeid"_a, py::kw_only(), "replace"_a = false, "Register a type caster for casting MLIR types to custom user types."); + m.def( + MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR, + [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function { + return py::cpp_function( + [mlirTypeID, replace](py::object valueCaster) -> py::object { + PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster, + replace); + return valueCaster; + }); + }, + "typeid"_a, py::kw_only(), "replace"_a = false, + "Register a value caster for casting MLIR values to custom user values."); // Define and populate IR submodule. auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h index 2a8da20be..38462ac8b 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -10,6 +10,7 @@ #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H #include "mlir-c/Support.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/DataTypes.h" @@ -228,6 +229,11 @@ class Sliceable { return linearIndex; } + /// Trait to check if T provides a `maybeDownCast` method. + /// Note, you need the & to detect inherited members. + template + using has_maybe_downcast = decltype(&T::maybeDownCast); + /// Returns the element at the given slice index. Supports negative indices /// by taking elements in inverse order. Returns a nullptr object if out /// of bounds. 
@@ -239,8 +245,13 @@ class Sliceable { return {}; } - return pybind11::cast( - static_cast(this)->getRawElement(linearizeIndex(index))); + if constexpr (llvm::is_detected::value) + return static_cast(this) + ->getRawElement(linearizeIndex(index)) + .maybeDownCast(); + else + return pybind11::cast( + static_cast(this)->getRawElement(linearizeIndex(index))); } /// Returns a new instance of the pseudo-container restricted to the given diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 9cca7d659..60ce83c09 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -5,7 +5,12 @@ # Provide a convenient name for sub-packages to resolve the main C-extension # with a relative import. from .._mlir_libs import _mlir as _cext -from typing import Sequence as _Sequence, Union as _Union +from typing import ( + Sequence as _Sequence, + Type as _Type, + TypeVar as _TypeVar, + Union as _Union, +) __all__ = [ "equally_sized_accessor", @@ -123,3 +128,9 @@ def get_op_result_or_op_results( if len(op.results) > 0 else op ) + + +# This is the standard way to indicate subclass/inheritance relationship +# see the typing.Type doc string. +_U = _TypeVar("_U", bound=_cext.ir.Value) +SubClassValueT = _Type[_U] diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index cf4228c2a..18526ab8c 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -4,7 +4,7 @@ from ._mlir_libs._mlir.ir import * from ._mlir_libs._mlir.ir import _GlobalDebug -from ._mlir_libs._mlir import register_type_caster +from ._mlir_libs._mlir import register_type_caster, register_value_caster # Convenience decorator for registering user-friendly Attribute builders. 
From 5c63728961a87706a5378aa21b141895fca95d23 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Tue, 7 Nov 2023 19:52:43 -0600 Subject: [PATCH 617/915] [mlir][python] factor out pure python core sources (#71592) I'd like to be able to install just the Python core sources (without building/including the pybind sources). --- mlir/python/CMakeLists.txt | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 88e6e1360..971ad2dd2 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -7,14 +7,16 @@ include(AddMLIRPython) declare_mlir_python_sources(MLIRPythonSources) declare_mlir_python_sources(MLIRPythonSources.Dialects ADD_TO_PARENT MLIRPythonSources) +declare_mlir_python_sources(MLIRPythonSources.Core + ADD_TO_PARENT MLIRPythonSources) ################################################################################ # Pure python sources and generated code ################################################################################ -declare_mlir_python_sources(MLIRPythonSources.Core +declare_mlir_python_sources(MLIRPythonSources.Core.Python ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" - ADD_TO_PARENT MLIRPythonSources + ADD_TO_PARENT MLIRPythonSources.Core SOURCES _mlir_libs/__init__.py ir.py From 86d1eef1b695b0efef4f8bf876df79e50ead67c7 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 8 Nov 2023 09:49:57 +0100 Subject: [PATCH 618/915] [mlir][python]Add sugared buider for transform.named_sequence (#71597) --- .../mlir/dialects/transform/__init__.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index 166c5c5ca..23b278d37 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -165,6 +165,34 @@ def bodyExtraArgs(self) -> BlockArgumentList: return self.body.arguments[1:] 
+@_ods_cext.register_operation(_Dialect, replace=True) +class NamedSequenceOp(NamedSequenceOp): + def __init__( + self, + sym_name, + input_types: Sequence[Type], + result_types: Sequence[Type], + ): + function_type = FunctionType.get(input_types, result_types) + super().__init__( + sym_name=sym_name, + function_type=TypeAttr.get(function_type), + ) + self.regions[0].blocks.append(*input_types) + + @property + def body(self) -> Block: + return self.regions[0].blocks[0] + + @property + def bodyTarget(self) -> Value: + return self.body.arguments[0] + + @property + def bodyExtraArgs(self) -> BlockArgumentList: + return self.body.arguments[1:] + + @_ods_cext.register_operation(_Dialect, replace=True) class YieldOp(YieldOp): def __init__( From c601972503d40029d0af2e97504197e119001b21 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 8 Nov 2023 09:12:24 +0000 Subject: [PATCH 619/915] Revert "[mlir][python]Add sugared buider for transform.named_sequence (#71597)" This reverts commit 86d1eef1b695b0efef4f8bf876df79e50ead67c7. 
--- .../mlir/dialects/transform/__init__.py | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index 23b278d37..166c5c5ca 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -165,34 +165,6 @@ def bodyExtraArgs(self) -> BlockArgumentList: return self.body.arguments[1:] -@_ods_cext.register_operation(_Dialect, replace=True) -class NamedSequenceOp(NamedSequenceOp): - def __init__( - self, - sym_name, - input_types: Sequence[Type], - result_types: Sequence[Type], - ): - function_type = FunctionType.get(input_types, result_types) - super().__init__( - sym_name=sym_name, - function_type=TypeAttr.get(function_type), - ) - self.regions[0].blocks.append(*input_types) - - @property - def body(self) -> Block: - return self.regions[0].blocks[0] - - @property - def bodyTarget(self) -> Value: - return self.body.arguments[0] - - @property - def bodyExtraArgs(self) -> BlockArgumentList: - return self.body.arguments[1:] - - @_ods_cext.register_operation(_Dialect, replace=True) class YieldOp(YieldOp): def __init__( From 38d1003d1fefeb7ac23bcb553a3e0b80ca5dcb56 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 8 Nov 2023 09:49:57 +0100 Subject: [PATCH 620/915] [mlir][python] Reland - Add sugared builder for transform.named_sequence Address issues with #71597 post-revert and and reland --- .../mlir/dialects/transform/__init__.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index 166c5c5ca..23b278d37 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -165,6 +165,34 @@ def bodyExtraArgs(self) -> BlockArgumentList: return self.body.arguments[1:] +@_ods_cext.register_operation(_Dialect, replace=True) +class 
NamedSequenceOp(NamedSequenceOp): + def __init__( + self, + sym_name, + input_types: Sequence[Type], + result_types: Sequence[Type], + ): + function_type = FunctionType.get(input_types, result_types) + super().__init__( + sym_name=sym_name, + function_type=TypeAttr.get(function_type), + ) + self.regions[0].blocks.append(*input_types) + + @property + def body(self) -> Block: + return self.regions[0].blocks[0] + + @property + def bodyTarget(self) -> Value: + return self.body.arguments[0] + + @property + def bodyExtraArgs(self) -> BlockArgumentList: + return self.body.arguments[1:] + + @_ods_cext.register_operation(_Dialect, replace=True) class YieldOp(YieldOp): def __init__( From 480935b366089062573c928996942741f5fcb9f2 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 8 Nov 2023 13:17:25 +0000 Subject: [PATCH 621/915] [mlir][python] Add support for arg_attrs and other attrs to NamedSequenceOp --- mlir/python/mlir/dialects/transform/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index 23b278d37..1dca1a66b 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -172,11 +172,17 @@ def __init__( sym_name, input_types: Sequence[Type], result_types: Sequence[Type], + sym_visibility=None, + arg_attrs=None, + res_attrs=None ): function_type = FunctionType.get(input_types, result_types) super().__init__( sym_name=sym_name, function_type=TypeAttr.get(function_type), + sym_visibility=sym_visibility, + arg_attrs=arg_attrs, + res_attrs=res_attrs ) self.regions[0].blocks.append(*input_types) From c9a95d162ef805a0738da176457b930b74b8e6bd Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Wed, 8 Nov 2023 16:59:17 -0600 Subject: [PATCH 622/915] [mlir][cmake] export list of CAPI libs (#71722) --- mlir/lib/CAPI/CMakeLists.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff 
--git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt index 052eff327..707e78ac3 100644 --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -1,10 +1,10 @@ -# For upstream, we accumulate all libraries into the MLIR_CAPI_LIBRARIES +# For upstream, we accumulate all libraries into the MLIR_CAPI_LIBS # property via a custom wrapper function. This is then used to create an # aggregate below. -set_property(GLOBAL APPEND PROPERTY MLIR_CAPI_LIBRARIES) +set_property(GLOBAL APPEND PROPERTY MLIR_CAPI_LIBS) function(add_mlir_upstream_c_api_library name) add_mlir_public_c_api_library(${name} ${ARGN}) - set_property(GLOBAL APPEND PROPERTY MLIR_CAPI_LIBRARIES ${name}) + set_property(GLOBAL APPEND PROPERTY MLIR_CAPI_LIBS ${name}) endfunction() add_subdirectory(Debug) @@ -22,7 +22,7 @@ endif() # Build the optional CAPI dylib. if(MLIR_BUILD_MLIR_C_DYLIB) message(STATUS "Building MLIR-C dylib") - get_property(_capi_libraries GLOBAL PROPERTY MLIR_CAPI_LIBRARIES) + get_property(_capi_libraries GLOBAL PROPERTY MLIR_CAPI_LIBS) add_mlir_aggregate(MLIR-C SHARED EMBED_LIBS From b12b590c56406eed863bf3a9106a966e12469143 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sat, 11 Nov 2023 21:41:56 -0800 Subject: [PATCH 623/915] [mlir][python] Allow contexts to be created with a custom thread pool. (#72042) The existing initialization sequence always enables multi-threading at MLIRContext construction time, making it impractical to provide a customized thread pool. Here, this is changed to always create the context with threading disabled, process all site-specific init hooks (which can set thread pools) and ultimately enable multi-threading unless if site-configured to not do so. This should preserve the existing user-visible initialization behavior while also letting downstreams ensure that contexts are always created with a shared thread pool. This was tested with IREE, which has such a concept. 
Using site-specific thread tuning produced up to 2x single compilation job behavior and customization of batch compilation (i.e. as part of a build system) to utilize half the memory and run the entire test suite ~2x faster. Given this, I believe that the additional configurability can well pay for itself for implementations that use it. We may also want to present user-level Python APIs for controlling threading configuration in the future. --- mlir/lib/Bindings/Python/IRCore.cpp | 2 +- mlir/python/mlir/_mlir_libs/__init__.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 0f2ca666c..745aa64e6 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -597,7 +597,7 @@ py::object PyMlirContext::createFromCapsule(py::object capsule) { } PyMlirContext *PyMlirContext::createNewContextForInit() { - MlirContext context = mlirContextCreate(); + MlirContext context = mlirContextCreateWithThreading(false); return new PyMlirContext(context); } diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index d5fc447e4..6ce77b4cb 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -46,6 +46,13 @@ def get_include_dirs() -> Sequence[str]: # c. If the module has a 'context_init_hook', it will be added to a list # of callbacks that are invoked as the last step of Context # initialization (and passed the Context under construction). +# d. If the module has a 'disable_multithreading' attribute, it will be +# taken as a boolean. If it is True for any initializer, then the +# default behavior of enabling multithreading on the context +# will be suppressed. This complies with the original behavior of all +# contexts being created with multithreading enabled while allowing +# this behavior to be changed if needed (i.e. 
if a context_init_hook +# explicitly sets up multithreading). # # This facility allows downstreams to customize Context creation to their # needs. @@ -58,8 +65,10 @@ def _site_initialize(): logger = logging.getLogger(__name__) registry = ir.DialectRegistry() post_init_hooks = [] + disable_multithreading = False def process_initializer_module(module_name): + nonlocal disable_multithreading try: m = importlib.import_module(f".{module_name}", __name__) except ModuleNotFoundError: @@ -79,6 +88,10 @@ def process_initializer_module(module_name): if hasattr(m, "context_init_hook"): logger.debug("Adding context init hook from %r", m) post_init_hooks.append(m.context_init_hook) + if hasattr(m, "disable_multithreading"): + if bool(m.disable_multithreading): + logger.debug("Disabling multi-threading for context") + disable_multithreading = True return True # If _mlirRegisterEverything is built, then include it as an initializer @@ -100,6 +113,8 @@ def __init__(self, *args, **kwargs): self.append_dialect_registry(registry) for hook in post_init_hooks: hook(self) + if not disable_multithreading: + self.enable_multithreading(True) # TODO: There is some debate about whether we should eagerly load # all dialects. It is being done here in order to preserve existing # behavior. 
See: https://github.com/llvm/llvm-project/issues/56037 From f47dfa1a185af9b23c1236660082dc25c55103db Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Fri, 20 Oct 2023 01:12:39 -0700 Subject: [PATCH 624/915] [MLIR] Apply clang-tidy fixes for misc-include-cleaner (NFC) --- mlir/lib/Bindings/Python/AsyncPasses.cpp | 1 + mlir/lib/Bindings/Python/DialectPDL.cpp | 4 ++++ mlir/lib/Bindings/Python/DialectQuant.cpp | 5 +++++ mlir/lib/Bindings/Python/DialectSparseTensor.cpp | 6 ++++++ mlir/lib/Bindings/Python/DialectTransform.cpp | 5 +++++ mlir/lib/Bindings/Python/GPUPasses.cpp | 1 + mlir/lib/Bindings/Python/IRAffine.cpp | 13 +++++++++++++ 7 files changed, 35 insertions(+) diff --git a/mlir/lib/Bindings/Python/AsyncPasses.cpp b/mlir/lib/Bindings/Python/AsyncPasses.cpp index 2b83ed40d..b611a758d 100644 --- a/mlir/lib/Bindings/Python/AsyncPasses.cpp +++ b/mlir/lib/Bindings/Python/AsyncPasses.cpp @@ -8,6 +8,7 @@ #include "mlir-c/Dialect/Async.h" +#include #include // ----------------------------------------------------------------------------- diff --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp index 8d0b1014a..8d3f9a7ab 100644 --- a/mlir/lib/Bindings/Python/DialectPDL.cpp +++ b/mlir/lib/Bindings/Python/DialectPDL.cpp @@ -9,6 +9,10 @@ #include "mlir-c/Dialect/PDL.h" #include "mlir-c/IR.h" #include "mlir/Bindings/Python/PybindAdaptors.h" +#include +#include +#include +#include namespace py = pybind11; using namespace llvm; diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp index de042d1fb..af9cdc7bd 100644 --- a/mlir/lib/Bindings/Python/DialectQuant.cpp +++ b/mlir/lib/Bindings/Python/DialectQuant.cpp @@ -9,6 +9,11 @@ #include "mlir-c/Dialect/Quant.h" #include "mlir-c/IR.h" #include "mlir/Bindings/Python/PybindAdaptors.h" +#include +#include +#include +#include +#include namespace py = pybind11; using namespace llvm; diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp 
b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 9bde3a443..5b5d0136c 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -6,10 +6,16 @@ // //===----------------------------------------------------------------------===// +#include "mlir-c/AffineMap.h" #include "mlir-c/Dialect/SparseTensor.h" #include "mlir-c/IR.h" #include "mlir/Bindings/Python/PybindAdaptors.h" #include +#include +#include +#include +#include +#include namespace py = pybind11; using namespace llvm; diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp index cbbf8332b..c7764f4e7 100644 --- a/mlir/lib/Bindings/Python/DialectTransform.cpp +++ b/mlir/lib/Bindings/Python/DialectTransform.cpp @@ -10,6 +10,11 @@ #include "mlir-c/IR.h" #include "mlir-c/Support.h" #include "mlir/Bindings/Python/PybindAdaptors.h" +#include +#include +#include +#include +#include namespace py = pybind11; using namespace mlir; diff --git a/mlir/lib/Bindings/Python/GPUPasses.cpp b/mlir/lib/Bindings/Python/GPUPasses.cpp index cb623a11b..e276a3ce3 100644 --- a/mlir/lib/Bindings/Python/GPUPasses.cpp +++ b/mlir/lib/Bindings/Python/GPUPasses.cpp @@ -8,6 +8,7 @@ #include "mlir-c/Dialect/GPU.h" +#include #include // ----------------------------------------------------------------------------- diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index 75f86a49e..b138e131e 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -6,16 +6,29 @@ // //===----------------------------------------------------------------------===// +#include +#include +#include +#include +#include +#include +#include #include +#include #include "IRModule.h" #include "PybindUtils.h" +#include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/IntegerSet.h" +#include "mlir/Support/LLVM.h" 
#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" namespace py = pybind11; using namespace mlir; From 194de674f063257d4e67575d6f4081b34abd2c7b Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Mon, 13 Nov 2023 10:21:21 -0800 Subject: [PATCH 625/915] [mlir][py] Overload print with state. (#72064) Enables reusing the AsmState when printing from Python. Also moves the fileObject and binary to the end (pybind11::object was resulting in the overload not working unless `state=` was specified). --------- Co-authored-by: Maksim Levental --- mlir/lib/Bindings/Python/IRCore.cpp | 47 +++++++++++++++++++++++------ mlir/lib/Bindings/Python/IRModule.h | 9 ++++-- 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 745aa64e6..3ddb750bb 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -110,6 +110,15 @@ static const char kOperationPrintDocstring[] = invalid, behavior is undefined. )"; +static const char kOperationPrintStateDocstring[] = + R"(Prints the assembly form of the operation to a file like object. + +Args: + file: The file like object to write to. Defaults to sys.stdout. + binary: Whether to write bytes (True) or str (False). Defaults to False. + state: AsmState capturing the operation numbering and flags. +)"; + static const char kOperationGetAsmDocstring[] = R"(Gets the assembly form of the operation with all options available. 
@@ -1169,11 +1178,11 @@ void PyOperation::checkValid() const { } } -void PyOperationBase::print(py::object fileObject, bool binary, - std::optional largeElementsLimit, +void PyOperationBase::print(std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified) { + bool assumeVerified, py::object fileObject, + bool binary) { PyOperation &operation = getOperation(); operation.checkValid(); if (fileObject.is_none()) @@ -1198,6 +1207,17 @@ void PyOperationBase::print(py::object fileObject, bool binary, mlirOpPrintingFlagsDestroy(flags); } +void PyOperationBase::print(PyAsmState &state, py::object fileObject, + bool binary) { + PyOperation &operation = getOperation(); + operation.checkValid(); + if (fileObject.is_none()) + fileObject = py::module::import("sys").attr("stdout"); + PyFileAccumulator accum(fileObject, binary); + mlirOperationPrintWithState(operation, state.get(), accum.getCallback(), + accum.getUserData()); +} + void PyOperationBase::writeBytecode(const py::object &fileObject, std::optional bytecodeVersion) { PyOperation &operation = getOperation(); @@ -1230,13 +1250,14 @@ py::object PyOperationBase::getAsm(bool binary, } else { fileObject = py::module::import("io").attr("StringIO")(); } - print(fileObject, /*binary=*/binary, - /*largeElementsLimit=*/largeElementsLimit, + print(/*largeElementsLimit=*/largeElementsLimit, /*enableDebugInfo=*/enableDebugInfo, /*prettyDebugInfo=*/prettyDebugInfo, /*printGenericOpForm=*/printGenericOpForm, /*useLocalScope=*/useLocalScope, - /*assumeVerified=*/assumeVerified); + /*assumeVerified=*/assumeVerified, + /*fileObject=*/fileObject, + /*binary=*/binary); return fileObject.attr("getvalue")(); } @@ -2946,15 +2967,23 @@ void mlir::python::populateIRCore(py::module &m) { /*assumeVerified=*/false); }, "Returns the assembly form of the operation.") - .def("print", &PyOperationBase::print, + .def("print", + py::overload_cast( + 
&PyOperationBase::print), + py::arg("state"), py::arg("file") = py::none(), + py::arg("binary") = false, kOperationPrintStateDocstring) + .def("print", + py::overload_cast, bool, bool, bool, bool, + bool, py::object, bool>( + &PyOperationBase::print), // Careful: Lots of arguments must match up with print method. - py::arg("file") = py::none(), py::arg("binary") = false, py::arg("large_elements_limit") = py::none(), py::arg("enable_debug_info") = false, py::arg("pretty_debug_info") = false, py::arg("print_generic_op_form") = false, py::arg("use_local_scope") = false, - py::arg("assume_verified") = false, kOperationPrintDocstring) + py::arg("assume_verified") = false, py::arg("file") = py::none(), + py::arg("binary") = false, kOperationPrintDocstring) .def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"), py::arg("desired_version") = py::none(), kOperationPrintBytecodeDocstring) diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index af55693f1..d99b87d19 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -550,16 +550,19 @@ class PyModule : public BaseContextObject { pybind11::handle handle; }; +class PyAsmState; + /// Base class for PyOperation and PyOpView which exposes the primary, user /// visible methods for manipulating it. class PyOperationBase { public: virtual ~PyOperationBase() = default; /// Implements the bound 'print' method and helps with others. 
- void print(pybind11::object fileObject, bool binary, - std::optional largeElementsLimit, bool enableDebugInfo, + void print(std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified); + bool assumeVerified, py::object fileObject, bool binary); + void print(PyAsmState &state, py::object fileObject, bool binary); + pybind11::object getAsm(bool binary, std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, From 1d6a11baf324b2c47f085ffe63a6b6278445aa35 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 13 Nov 2023 20:25:41 -0600 Subject: [PATCH 626/915] [mlir][python] fix `scf.for_` convenience builder (#72170) --- mlir/python/mlir/dialects/scf.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py index 71c80cab7..20bbed9bc 100644 --- a/mlir/python/mlir/dialects/scf.py +++ b/mlir/python/mlir/dialects/scf.py @@ -120,11 +120,13 @@ def for_( params = [start, stop, step] for i, p in enumerate(params): if isinstance(p, int): - p = constant(p) + p = constant(IntegerAttr.get(IndexType.get(), p)) elif isinstance(p, float): raise ValueError(f"{p=} must be int.") params[i] = p + start, stop, step = params + for_op = ForOp(start, stop, step, iter_args, loc=loc, ip=ip) iv = for_op.induction_variable iter_args = tuple(for_op.inner_iter_args) From cfef3fec0b8d59212c45daf66616573bc014f273 Mon Sep 17 00:00:00 2001 From: "long.chen" Date: Tue, 14 Nov 2023 13:01:19 +0800 Subject: [PATCH 627/915] [mlir][affine][nfc] cleanup deprecated T.cast style functions (#71269) detail see the docment: https://mlir.llvm.org/deprecation/ Not all changes are made manually, most of them are made through a clang tool I wrote https://github.com/lipracer/cpp-refactor. 
--- mlir/lib/CAPI/IR/AffineExpr.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mlir/lib/CAPI/IR/AffineExpr.cpp b/mlir/lib/CAPI/IR/AffineExpr.cpp index 5b25ab533..6e3328b65 100644 --- a/mlir/lib/CAPI/IR/AffineExpr.cpp +++ b/mlir/lib/CAPI/IR/AffineExpr.cpp @@ -66,7 +66,7 @@ MlirAffineExpr mlirAffineExprCompose(MlirAffineExpr affineExpr, //===----------------------------------------------------------------------===// bool mlirAffineExprIsADim(MlirAffineExpr affineExpr) { - return unwrap(affineExpr).isa(); + return isa(unwrap(affineExpr)); } MlirAffineExpr mlirAffineDimExprGet(MlirContext ctx, intptr_t position) { @@ -74,7 +74,7 @@ MlirAffineExpr mlirAffineDimExprGet(MlirContext ctx, intptr_t position) { } intptr_t mlirAffineDimExprGetPosition(MlirAffineExpr affineExpr) { - return unwrap(affineExpr).cast().getPosition(); + return cast(unwrap(affineExpr)).getPosition(); } //===----------------------------------------------------------------------===// @@ -82,7 +82,7 @@ intptr_t mlirAffineDimExprGetPosition(MlirAffineExpr affineExpr) { //===----------------------------------------------------------------------===// bool mlirAffineExprIsASymbol(MlirAffineExpr affineExpr) { - return unwrap(affineExpr).isa(); + return isa(unwrap(affineExpr)); } MlirAffineExpr mlirAffineSymbolExprGet(MlirContext ctx, intptr_t position) { @@ -90,7 +90,7 @@ MlirAffineExpr mlirAffineSymbolExprGet(MlirContext ctx, intptr_t position) { } intptr_t mlirAffineSymbolExprGetPosition(MlirAffineExpr affineExpr) { - return unwrap(affineExpr).cast().getPosition(); + return cast(unwrap(affineExpr)).getPosition(); } //===----------------------------------------------------------------------===// @@ -98,7 +98,7 @@ intptr_t mlirAffineSymbolExprGetPosition(MlirAffineExpr affineExpr) { //===----------------------------------------------------------------------===// bool mlirAffineExprIsAConstant(MlirAffineExpr affineExpr) { - return unwrap(affineExpr).isa(); + return 
isa(unwrap(affineExpr)); } MlirAffineExpr mlirAffineConstantExprGet(MlirContext ctx, int64_t constant) { @@ -106,7 +106,7 @@ MlirAffineExpr mlirAffineConstantExprGet(MlirContext ctx, int64_t constant) { } int64_t mlirAffineConstantExprGetValue(MlirAffineExpr affineExpr) { - return unwrap(affineExpr).cast().getValue(); + return cast(unwrap(affineExpr)).getValue(); } //===----------------------------------------------------------------------===// @@ -181,13 +181,13 @@ MlirAffineExpr mlirAffineCeilDivExprGet(MlirAffineExpr lhs, //===----------------------------------------------------------------------===// bool mlirAffineExprIsABinary(MlirAffineExpr affineExpr) { - return unwrap(affineExpr).isa(); + return isa(unwrap(affineExpr)); } MlirAffineExpr mlirAffineBinaryOpExprGetLHS(MlirAffineExpr affineExpr) { - return wrap(unwrap(affineExpr).cast().getLHS()); + return wrap(cast(unwrap(affineExpr)).getLHS()); } MlirAffineExpr mlirAffineBinaryOpExprGetRHS(MlirAffineExpr affineExpr) { - return wrap(unwrap(affineExpr).cast().getRHS()); + return wrap(cast(unwrap(affineExpr)).getRHS()); } From 2c95edceb78e45badcfd6f685d8561072b79c45c Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 15 Nov 2023 16:23:01 +0000 Subject: [PATCH 628/915] [mlir][python] NFC - Expose LoopExtensionOps to SCFLoopTransformOps.td --- mlir/python/mlir/dialects/SCFLoopTransformOps.td | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/python/mlir/dialects/SCFLoopTransformOps.td b/mlir/python/mlir/dialects/SCFLoopTransformOps.td index 7b09fc14b..4a904d578 100644 --- a/mlir/python/mlir/dialects/SCFLoopTransformOps.td +++ b/mlir/python/mlir/dialects/SCFLoopTransformOps.td @@ -16,5 +16,6 @@ #define PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.td" +include "mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.td" #endif // PYTHON_BINDINGS_SCF_LOOP_TRANSFORM_OPS From 31b3aec1a5077b342ad7a11108608b18d32be5cf Mon Sep 17 00:00:00 2001 From: Maksim 
Levental Date: Thu, 16 Nov 2023 13:37:52 -0600 Subject: [PATCH 629/915] [mlir][python] reformat transform ext (#71136) --- .../mlir/dialects/transform/__init__.py | 44 +++++++++---------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index 1dca1a66b..7ae4fefba 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -52,28 +52,28 @@ def patterns(self) -> Block: @_ods_cext.register_operation(_Dialect, replace=True) class GetParentOp(GetParentOp): - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - *, - isolated_from_above: bool = False, - op_name: Optional[str] = None, - deduplicate: bool = False, - nth_parent: int = 1, - loc=None, - ip=None, - ): - super().__init__( - result_type, - _get_op_result_or_value(target), - isolated_from_above=isolated_from_above, - op_name=op_name, - deduplicate=deduplicate, - nth_parent=nth_parent, - loc=loc, - ip=ip, - ) + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + *, + isolated_from_above: bool = False, + op_name: Optional[str] = None, + deduplicate: bool = False, + nth_parent: int = 1, + loc=None, + ip=None, + ): + super().__init__( + result_type, + _get_op_result_or_value(target), + isolated_from_above=isolated_from_above, + op_name=op_name, + deduplicate=deduplicate, + nth_parent=nth_parent, + loc=loc, + ip=ip, + ) @_ods_cext.register_operation(_Dialect, replace=True) From 29cb7cc4a9288f98cf06a4f839bb649fe7b4c973 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Fri, 20 Oct 2023 01:33:05 -0700 Subject: [PATCH 630/915] Apply clang-tidy fixes for performance-unnecessary-value-param in IRAttributes.cpp (NFC) --- mlir/lib/Bindings/Python/IRAttributes.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp 
b/mlir/lib/Bindings/Python/IRAttributes.cpp index 94fa2527e..dda2003ba 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -1032,7 +1032,7 @@ class PyDenseResourceElementsAttribute using PyConcreteAttribute::PyConcreteAttribute; static PyDenseResourceElementsAttribute - getFromBuffer(py::buffer buffer, std::string name, PyType type, + getFromBuffer(py::buffer buffer, const std::string &name, const PyType &type, std::optional alignment, bool isMutable, DefaultingPyMlirContext contextWrapper) { if (!mlirTypeIsAShaped(type)) { From 7fde7fc53acd25b6e6c82a01d51aa6aa8fe5b8c8 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Fri, 20 Oct 2023 01:34:06 -0700 Subject: [PATCH 631/915] Apply clang-tidy fixes for llvm-else-after-return in IRCore.cpp (NFC) --- mlir/lib/Bindings/Python/IRCore.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 3ddb750bb..fb02b73a7 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -981,8 +981,7 @@ MlirDialect PyDialects::getDialectForKey(const std::string &key, std::string msg = (Twine("Dialect '") + key + "' not found").str(); if (attrError) throw py::attribute_error(msg); - else - throw py::index_error(msg); + throw py::index_error(msg); } return dialect; } From 9c7fbea00e7b14efb8e74c172c8e66f13e1e2ba6 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Fri, 20 Oct 2023 01:34:51 -0700 Subject: [PATCH 632/915] Apply clang-tidy fixes for misc-include-cleaner in IRCore.cpp (NFC) --- mlir/lib/Bindings/Python/IRCore.cpp | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index fb02b73a7..d75cd8a0c 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -17,12 +17,36 @@ #include "mlir-c/Diagnostics.h" 
#include "mlir-c/IR.h" #include "mlir-c/Support.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Support/LLVM.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" - +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" + +#include +#include +#include +#include +#include +#include +#include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include +#include namespace py = pybind11; using namespace py::literals; From 0b4a51e5fbf4aec329e7853e051f1839f8c8d203 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Fri, 20 Oct 2023 01:35:56 -0700 Subject: [PATCH 633/915] Apply clang-tidy fixes for readability-identifier-naming in IRCore.cpp (NFC) --- mlir/lib/Bindings/Python/IRCore.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index d75cd8a0c..4aa8df1bd 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3312,9 +3312,9 @@ void mlir::python::populateIRCore(py::module &m) { .def_property_readonly( "ref_operation", [](PyInsertionPoint &self) -> py::object { - auto ref_operation = self.getRefOperation(); - if (ref_operation) - return ref_operation->getObject(); + auto refOperation = self.getRefOperation(); + if (refOperation) + return refOperation->getObject(); return py::none(); }, "The reference operation before which new operations are " From d6ef8d8e1efcaff22d10eb23ca2fd0c272ffe523 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Fri, 20 Oct 2023 01:36:36 -0700 Subject: [PATCH 634/915] Apply clang-tidy fixes for misc-include-cleaner in IRInterfaces.cpp (NFC) --- mlir/lib/Bindings/Python/IRInterfaces.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp 
index dd4190016..136bb1b2b 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -6,13 +6,23 @@ // //===----------------------------------------------------------------------===// +#include #include +#include +#include +#include +#include +#include #include +#include #include "IRModule.h" #include "mlir-c/BuiltinAttributes.h" +#include "mlir-c/IR.h" #include "mlir-c/Interfaces.h" +#include "mlir-c/Support.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" namespace py = pybind11; From c5e44241785688ab0ee088d9870ba02cecb65e02 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Fri, 20 Oct 2023 01:37:09 -0700 Subject: [PATCH 635/915] Apply clang-tidy fixes for performance-unnecessary-value-param in IRInterfaces.cpp (NFC) --- mlir/lib/Bindings/Python/IRInterfaces.cpp | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index 136bb1b2b..c3aac0b09 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -281,8 +281,9 @@ class PyInferTypeOpInterface std::optional> regions, DefaultingPyMlirContext context, DefaultingPyLocation location) { - llvm::SmallVector mlirOperands = wrapOperands(operandList); - llvm::SmallVector mlirRegions = wrapRegions(regions); + llvm::SmallVector mlirOperands = + wrapOperands(std::move(operandList)); + llvm::SmallVector mlirRegions = wrapRegions(std::move(regions)); std::vector inferredTypes; PyMlirContext &pyContext = context.resolve(); @@ -319,10 +320,10 @@ class PyShapedTypeComponents { public: PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {} PyShapedTypeComponents(py::list shape, MlirType elementType) - : shape(shape), elementType(elementType), ranked(true) {} + : shape(std::move(shape)), elementType(elementType), ranked(true) {} PyShapedTypeComponents(py::list shape, MlirType 
elementType, MlirAttribute attribute) - : shape(shape), elementType(elementType), attribute(attribute), + : shape(std::move(shape)), elementType(elementType), attribute(attribute), ranked(true) {} PyShapedTypeComponents(PyShapedTypeComponents &) = delete; PyShapedTypeComponents(PyShapedTypeComponents &&other) @@ -347,14 +348,15 @@ class PyShapedTypeComponents { .def_static( "get", [](py::list shape, PyType &elementType) { - return PyShapedTypeComponents(shape, elementType); + return PyShapedTypeComponents(std::move(shape), elementType); }, py::arg("shape"), py::arg("element_type"), "Create a ranked shaped type components object.") .def_static( "get", [](py::list shape, PyType &elementType, PyAttribute &attribute) { - return PyShapedTypeComponents(shape, elementType, attribute); + return PyShapedTypeComponents(std::move(shape), elementType, + attribute); }, py::arg("shape"), py::arg("element_type"), py::arg("attribute"), "Create a ranked shaped type components object with attribute.") @@ -438,8 +440,9 @@ class PyInferShapedTypeOpInterface std::optional attributes, void *properties, std::optional> regions, DefaultingPyMlirContext context, DefaultingPyLocation location) { - llvm::SmallVector mlirOperands = wrapOperands(operandList); - llvm::SmallVector mlirRegions = wrapRegions(regions); + llvm::SmallVector mlirOperands = + wrapOperands(std::move(operandList)); + llvm::SmallVector mlirRegions = wrapRegions(std::move(regions)); std::vector inferredShapedTypeComponents; PyMlirContext &pyContext = context.resolve(); From 0dc32052c03b9ce1c19fb268a83312b8dc81cc77 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 20 Nov 2023 08:56:12 +0000 Subject: [PATCH 636/915] Revert "Apply clang-tidy fixes for misc-include-cleaner in IRCore.cpp (NFC)" This reverts commit 9c7fbea00e7b14efb8e74c172c8e66f13e1e2ba6. Changes make Python bindings unbuildable without additional cmake modifications (or modified `$PATH`). 
``` /llvm-project/mlir/lib/Bindings/Python/IRCore.cpp:33:10: fatal error: 'funcobject.h' file not found ``` This header is provided by cpython, and we are not looking for that in cmake. Moreover, the nature of this change is not very clear to me. Seems to replace one include with two dozens, presumably because the code is only using transitively included headers, but the value for readability is dubious. LLVM is also not strictly following IWYU. --- mlir/lib/Bindings/Python/IRCore.cpp | 28 ++-------------------------- 1 file changed, 2 insertions(+), 26 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 4aa8df1bd..5612f3b96 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -17,36 +17,12 @@ #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" -#include "mlir/Support/LLVM.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/Hashing.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/ADT/Twine.h" - -#include -#include -#include -#include -#include -#include -#include + #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include #include -#include namespace py = pybind11; using namespace py::literals; From 6061351617bd8371c564dda29894e637e43935d7 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 20 Nov 2023 19:54:55 -0600 Subject: [PATCH 637/915] [mlir][python] remove eager loading of dialect module (for type and value casting) (#72338) --- mlir/lib/Bindings/Python/IRModule.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index 5538924d2..6727860c0 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ 
-132,10 +132,8 @@ PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { - // Make sure dialect module is loaded. - if (!loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)))) - return std::nullopt; - + // Try to load dialect module. + (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); const auto foundIt = typeCasterMap.find(mlirTypeID); if (foundIt != typeCasterMap.end()) { assert(foundIt->second && "type caster is defined"); @@ -146,7 +144,8 @@ std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { - loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); + // Try to load dialect module. + (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); const auto foundIt = valueCasterMap.find(mlirTypeID); if (foundIt != valueCasterMap.end()) { assert(foundIt->second && "value caster is defined"); From ad0b2a5931fe87fbb475a9d62e086c93b3bee131 Mon Sep 17 00:00:00 2001 From: Edgar Date: Thu, 23 Nov 2023 12:39:54 +0100 Subject: [PATCH 638/915] [mlir] Add mlirTranslateModuleToLLVMIR to MLIR-C (#73117) Fixes #73008 --- mlir/include/mlir-c/Target/LLVMIR.h | 39 +++++++++++++++++++++++++++++ mlir/lib/CAPI/CMakeLists.txt | 2 +- mlir/lib/CAPI/Target/CMakeLists.txt | 12 +++++++++ mlir/lib/CAPI/Target/LLVMIR.cpp | 36 ++++++++++++++++++++++++++ 4 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 mlir/include/mlir-c/Target/LLVMIR.h create mode 100644 mlir/lib/CAPI/Target/CMakeLists.txt create mode 100644 mlir/lib/CAPI/Target/LLVMIR.cpp diff --git a/mlir/include/mlir-c/Target/LLVMIR.h b/mlir/include/mlir-c/Target/LLVMIR.h new file mode 100644 index 000000000..effa74b90 --- /dev/null +++ b/mlir/include/mlir-c/Target/LLVMIR.h @@ -0,0 +1,39 @@ +//===-- LLVMIR.h - C Interface for MLIR LLVMIR Target -------------*- C -*-===// +// +// Part 
of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface to target LLVMIR with MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_TARGET_LLVMIR_H +#define MLIR_C_TARGET_LLVMIR_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "llvm-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// Translate operation that satisfies LLVM dialect module requirements into an +/// LLVM IR module living in the given context. This translates operations from +/// any dilalect that has a registered implementation of +/// LLVMTranslationDialectInterface. +/// +/// \returns the generated LLVM IR Module from the translated MLIR module, it is +/// owned by the caller. 
+MLIR_CAPI_EXPORTED LLVMModuleRef +mlirTranslateModuleToLLVMIR(MlirOperation module, LLVMContextRef context); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_TARGET_LLVMIR_H diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt index 707e78ac3..6c4385084 100644 --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -14,6 +14,7 @@ add_subdirectory(Interfaces) add_subdirectory(IR) add_subdirectory(RegisterEverything) add_subdirectory(Transforms) +add_subdirectory(Target) if(MLIR_ENABLE_EXECUTION_ENGINE) add_subdirectory(ExecutionEngine) @@ -36,4 +37,3 @@ if(MLIR_BUILD_MLIR_C_DYLIB) endif() endif() endif() - diff --git a/mlir/lib/CAPI/Target/CMakeLists.txt b/mlir/lib/CAPI/Target/CMakeLists.txt new file mode 100644 index 000000000..ce86fd3de --- /dev/null +++ b/mlir/lib/CAPI/Target/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_upstream_c_api_library(MLIRCAPITarget + LLVMIR.cpp + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRToLLVMIRTranslationRegistration + MLIRCAPIIR + MLIRLLVMToLLVMIRTranslation + MLIRSupport +) diff --git a/mlir/lib/CAPI/Target/LLVMIR.cpp b/mlir/lib/CAPI/Target/LLVMIR.cpp new file mode 100644 index 000000000..dc798372b --- /dev/null +++ b/mlir/lib/CAPI/Target/LLVMIR.cpp @@ -0,0 +1,36 @@ +//===-- LLVMIR.h - C Interface for MLIR LLVMIR Target ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Target/LLVMIR.h" +#include "llvm-c/Support.h" + +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include + +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" +#include "mlir/CAPI/Wrap.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +using namespace mlir; + +LLVMModuleRef mlirTranslateModuleToLLVMIR(MlirOperation module, + LLVMContextRef context) { + Operation *moduleOp = unwrap(module); + + llvm::LLVMContext *ctx = llvm::unwrap(context); + + std::unique_ptr llvmModule = + mlir::translateModuleToLLVMIR(moduleOp, *ctx); + + LLVMModuleRef moduleRef = llvm::wrap(llvmModule.release()); + + return moduleRef; +} From 68e147be6311b0a66cd10d7181de6d9dce88c663 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 27 Nov 2023 15:58:00 -0600 Subject: [PATCH 639/915] [mlir][python] add type wrappers (#71218) --- mlir/lib/Bindings/Python/IRCore.cpp | 4 +- mlir/lib/Bindings/Python/IRTypes.cpp | 24 ++-- mlir/python/CMakeLists.txt | 1 + mlir/python/mlir/extras/__init__.py | 0 mlir/python/mlir/extras/types.py | 165 +++++++++++++++++++++++++++ 5 files changed, 176 insertions(+), 18 deletions(-) create mode 100644 mlir/python/mlir/extras/__init__.py create mode 100644 mlir/python/mlir/extras/types.py diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 5612f3b96..5412c3dec 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2558,8 +2558,8 @@ void mlir::python::populateIRCore(py::module &m) { [](py::object & /*class*/) { auto *context = PyThreadContextEntry::getDefaultContext(); if (!context) - throw py::value_error("No current Context"); - return context; + return py::none().cast(); + return py::cast(context); }, "Gets the Context bound to the current thread or raises ValueError") .def_property_readonly( 
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 483db673f..56e895d30 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -463,7 +463,7 @@ class PyVectorType : public PyConcreteType { static void bindDerived(ClassTy &c) { c.def_static("get", &PyVectorType::get, py::arg("shape"), - py::arg("elementType"), py::kw_only(), + py::arg("element_type"), py::kw_only(), py::arg("scalable") = py::none(), py::arg("scalable_dims") = py::none(), py::arg("loc") = py::none(), "Create a vector type") @@ -689,13 +689,9 @@ class PyTupleType : public PyConcreteType { static void bindDerived(ClassTy &c) { c.def_static( "get_tuple", - [](py::list elementList, DefaultingPyMlirContext context) { - intptr_t num = py::len(elementList); - // Mapping py::list to SmallVector. - SmallVector elements; - for (auto element : elementList) - elements.push_back(element.cast()); - MlirType t = mlirTupleTypeGet(context->get(), num, elements.data()); + [](std::vector elements, DefaultingPyMlirContext context) { + MlirType t = mlirTupleTypeGet(context->get(), elements.size(), + elements.data()); return PyTupleType(context->getRef(), t); }, py::arg("elements"), py::arg("context") = py::none(), @@ -727,13 +723,11 @@ class PyFunctionType : public PyConcreteType { static void bindDerived(ClassTy &c) { c.def_static( "get", - [](std::vector inputs, std::vector results, + [](std::vector inputs, std::vector results, DefaultingPyMlirContext context) { - SmallVector inputsRaw(inputs.begin(), inputs.end()); - SmallVector resultsRaw(results.begin(), results.end()); - MlirType t = mlirFunctionTypeGet(context->get(), inputsRaw.size(), - inputsRaw.data(), resultsRaw.size(), - resultsRaw.data()); + MlirType t = + mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(), + results.size(), results.data()); return PyFunctionType(context->getRef(), t); }, py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(), 
@@ -742,7 +736,6 @@ class PyFunctionType : public PyConcreteType { "inputs", [](PyFunctionType &self) { MlirType t = self; - auto contextRef = self.getContext(); py::list types; for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; ++i) { @@ -754,7 +747,6 @@ class PyFunctionType : public PyConcreteType { c.def_property_readonly( "results", [](PyFunctionType &self) { - auto contextRef = self.getContext(); py::list types; for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; ++i) { diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 971ad2dd2..55731943f 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -21,6 +21,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python _mlir_libs/__init__.py ir.py passmanager.py + extras/types.py dialects/_ods_common.py # The main _mlir module has submodules: include stubs from each. diff --git a/mlir/python/mlir/extras/__init__.py b/mlir/python/mlir/extras/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py new file mode 100644 index 000000000..db9e8229f --- /dev/null +++ b/mlir/python/mlir/extras/types.py @@ -0,0 +1,165 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from functools import partial +from typing import Optional, List + +from ..ir import ( + Attribute, + BF16Type, + ComplexType, + F16Type, + F32Type, + F64Type, + Float8E4M3B11FNUZType, + Float8E4M3FNType, + Float8E5M2Type, + FunctionType, + IndexType, + IntegerType, + MemRefType, + NoneType, + OpaqueType, + RankedTensorType, + StridedLayoutAttr, + StringAttr, + TupleType, + Type, + UnrankedMemRefType, + UnrankedTensorType, + VectorType, +) + +index = lambda: IndexType.get() + + +def i(width): + return IntegerType.get_signless(width) + + +def si(width): + return IntegerType.get_signed(width) + + +def ui(width): + return IntegerType.get_unsigned(width) + + +bool = lambda: i(1) +i8 = lambda: i(8) +i16 = lambda: i(16) +i32 = lambda: i(32) +i64 = lambda: i(64) + +si8 = lambda: si(8) +si16 = lambda: si(16) +si32 = lambda: si(32) +si64 = lambda: si(64) + +ui8 = lambda: ui(8) +ui16 = lambda: ui(16) +ui32 = lambda: ui(32) +ui64 = lambda: ui(64) + +f16 = lambda: F16Type.get() +f32 = lambda: F32Type.get() +f64 = lambda: F64Type.get() +bf16 = lambda: BF16Type.get() + +f8E5M2 = lambda: Float8E5M2Type.get() +f8E4M3 = lambda: Float8E4M3FNType.get() +f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get() + +none = lambda: NoneType.get() + + +def complex(type): + return ComplexType.get(type) + + +def opaque(dialect_namespace, type_data): + return OpaqueType.get(dialect_namespace, type_data) + + +def _shaped(*shape, element_type: Type = None, type_constructor=None): + if type_constructor is None: + raise ValueError("shaped is an abstract base class - cannot be constructed.") + if (element_type is None and shape and not isinstance(shape[-1], Type)) or ( + shape and isinstance(shape[-1], Type) and element_type is not None + ): + raise ValueError( + f"Either element_type must be provided explicitly XOR last arg to tensor type constructor must be the element type." 
+ ) + if element_type is not None: + type = element_type + sizes = shape + else: + type = shape[-1] + sizes = shape[:-1] + if sizes: + return type_constructor(sizes, type) + else: + return type_constructor(type) + + +def vector( + *shape, + element_type: Type = None, + scalable: Optional[List[bool]] = None, + scalable_dims: Optional[List[int]] = None, +): + return _shaped( + *shape, + element_type=element_type, + type_constructor=partial( + VectorType.get, scalable=scalable, scalable_dims=scalable_dims + ), + ) + + +def tensor(*shape, element_type: Type = None, encoding: Optional[str] = None): + if encoding is not None: + encoding = StringAttr.get(encoding) + if not shape or (len(shape) == 1 and isinstance(shape[-1], Type)): + if encoding is not None: + raise ValueError("UnrankedTensorType does not support encoding.") + return _shaped( + *shape, element_type=element_type, type_constructor=UnrankedTensorType.get + ) + return _shaped( + *shape, + element_type=element_type, + type_constructor=partial(RankedTensorType.get, encoding=encoding), + ) + + +def memref( + *shape, + element_type: Type = None, + memory_space: Optional[int] = None, + layout: Optional[StridedLayoutAttr] = None, +): + if memory_space is not None: + memory_space = Attribute.parse(str(memory_space)) + if not shape or (len(shape) == 1 and isinstance(shape[-1], Type)): + return _shaped( + *shape, + element_type=element_type, + type_constructor=partial(UnrankedMemRefType.get, memory_space=memory_space), + ) + return _shaped( + *shape, + element_type=element_type, + type_constructor=partial( + MemRefType.get, memory_space=memory_space, layout=layout + ), + ) + + +def tuple(*elements): + return TupleType.get_tuple(elements) + + +def function(*, inputs, results): + return FunctionType.get(inputs, results) From a5f548b7f4daf7495ddc3b7f1fdc6e71e4b908ac Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Mon, 27 Nov 2023 14:27:52 -0800 Subject: [PATCH 640/915] 
[mlir][sparse] rename DimLevelType to LevelType (#73561) The "Dim" prefix is a legacy left-over that no longer makes sense, since we have a very strict "Dimension" vs. "Level" definition for sparse tensor types and their storage. --- mlir/include/mlir-c/Dialect/SparseTensor.h | 36 +++++++------- .../Bindings/Python/DialectSparseTensor.cpp | 34 ++++++------- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 49 +++++++++---------- 3 files changed, 59 insertions(+), 60 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index 859a4f0dd..41d024db0 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -22,24 +22,24 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor); /// Dimension level types (and properties) that define sparse tensors. /// See the documentation in SparseTensorAttrDefs.td for their meaning. /// -/// These correspond to SparseTensorEncodingAttr::DimLevelType in the C++ API. +/// These correspond to SparseTensorEncodingAttr::LevelType in the C++ API. /// If updating, keep them in sync and update the static_assert in the impl /// file. 
-enum MlirSparseTensorDimLevelType { - MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE = 4, // 0b00001_00 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED = 8, // 0b00010_00 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU = 9, // 0b00010_01 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO = 10, // 0b00010_10 - MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO = 11, // 0b00010_11 - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON = 16, // 0b00100_00 - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU = 17, // 0b00100_01 - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO = 18, // 0b00100_10 - MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO = 19, // 0b00100_11 - MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED = 32, // 0b01000_00 - MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED_NU = 33, // 0b01000_01 - MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED_NO = 34, // 0b01000_10 - MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED_NU_NO = 35, // 0b01000_11 - MLIR_SPARSE_TENSOR_DIM_LEVEL_TWO_OUT_OF_FOUR = 64, // 0b10000_00 +enum MlirSparseTensorLevelType { + MLIR_SPARSE_TENSOR_LEVEL_DENSE = 4, // 0b00001_00 + MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 8, // 0b00010_00 + MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 9, // 0b00010_01 + MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 10, // 0b00010_10 + MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 11, // 0b00010_11 + MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 16, // 0b00100_00 + MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 17, // 0b00100_01 + MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 18, // 0b00100_10 + MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 19, // 0b00100_11 + MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 32, // 0b01000_00 + MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 33, // 0b01000_01 + MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 34, // 0b01000_10 + MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 35, // 0b01000_11 + MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR = 64, // 0b10000_00 }; //===----------------------------------------------------------------------===// @@ -53,7 +53,7 @@ 
mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr); /// Creates a `sparse_tensor.encoding` attribute with the given parameters. MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet( MlirContext ctx, intptr_t lvlRank, - enum MlirSparseTensorDimLevelType const *lvlTypes, MlirAffineMap dimToLvl, + enum MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl, MlirAffineMap lvlTodim, int posWidth, int crdWidth); /// Returns the level-rank of the `sparse_tensor.encoding` attribute. @@ -61,7 +61,7 @@ MLIR_CAPI_EXPORTED intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr); /// Returns a specified level-type of the `sparse_tensor.encoding` attribute. -MLIR_CAPI_EXPORTED enum MlirSparseTensorDimLevelType +MLIR_CAPI_EXPORTED enum MlirSparseTensorLevelType mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl); /// Returns the dimension-to-level mapping of the `sparse_tensor.encoding` diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 5b5d0136c..8706c5239 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -23,30 +23,30 @@ using namespace mlir; using namespace mlir::python::adaptors; static void populateDialectSparseTensorSubmodule(const py::module &m) { - py::enum_(m, "DimLevelType", py::module_local()) - .value("dense", MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) - .value("compressed24", MLIR_SPARSE_TENSOR_DIM_LEVEL_TWO_OUT_OF_FOUR) - .value("compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) - .value("compressed_nu", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU) - .value("compressed_no", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO) - .value("compressed_nu_no", MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO) - .value("singleton", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) - .value("singleton_nu", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU) - .value("singleton_no", 
MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO) - .value("singleton_nu_no", MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO) - .value("loose_compressed", MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED) + py::enum_(m, "LevelType", py::module_local()) + .value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE) + .value("compressed24", MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR) + .value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) + .value("compressed_nu", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) + .value("compressed_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) + .value("compressed_nu_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) + .value("singleton", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) + .value("singleton_nu", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) + .value("singleton_no", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) + .value("singleton_nu_no", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) + .value("loose_compressed", MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED) .value("loose_compressed_nu", - MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED_NU) + MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU) .value("loose_compressed_no", - MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED_NO) + MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO) .value("loose_compressed_nu_no", - MLIR_SPARSE_TENSOR_DIM_LEVEL_LOOSE_COMPRESSED_NU_NO); + MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO); mlir_attribute_subclass(m, "EncodingAttr", mlirAttributeIsASparseTensorEncodingAttr) .def_classmethod( "get", - [](py::object cls, std::vector lvlTypes, + [](py::object cls, std::vector lvlTypes, std::optional dimToLvl, std::optional lvlToDim, int posWidth, int crdWidth, MlirContext context) { @@ -64,7 +64,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { "lvl_types", [](MlirAttribute self) { const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self); - std::vector ret; + std::vector ret; ret.reserve(lvlRank); for (int l = 0; l < lvlRank; ++l) ret.push_back(mlirSparseTensorEncodingAttrGetLvlType(self, l)); diff 
--git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index c3ad95527..e4534ad13 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -20,26 +20,25 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor, mlir::sparse_tensor::SparseTensorDialect) // Ensure the C-API enums are int-castable to C++ equivalents. -static_assert( - static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_DENSE) == - static_cast(DimLevelType::Dense) && - static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED) == - static_cast(DimLevelType::Compressed) && - static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU) == - static_cast(DimLevelType::CompressedNu) && - static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NO) == - static_cast(DimLevelType::CompressedNo) && - static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_COMPRESSED_NU_NO) == - static_cast(DimLevelType::CompressedNuNo) && - static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON) == - static_cast(DimLevelType::Singleton) && - static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU) == - static_cast(DimLevelType::SingletonNu) && - static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NO) == - static_cast(DimLevelType::SingletonNo) && - static_cast(MLIR_SPARSE_TENSOR_DIM_LEVEL_SINGLETON_NU_NO) == - static_cast(DimLevelType::SingletonNuNo), - "MlirSparseTensorDimLevelType (C-API) and DimLevelType (C++) mismatch"); +static_assert(static_cast(MLIR_SPARSE_TENSOR_LEVEL_DENSE) == + static_cast(LevelType::Dense) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) == + static_cast(LevelType::Compressed) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) == + static_cast(LevelType::CompressedNu) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) == + static_cast(LevelType::CompressedNo) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) == + static_cast(LevelType::CompressedNuNo) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) == + 
static_cast(LevelType::Singleton) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) == + static_cast(LevelType::SingletonNu) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) == + static_cast(LevelType::SingletonNo) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) == + static_cast(LevelType::SingletonNuNo), + "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch"); bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { return isa(unwrap(attr)); @@ -47,13 +46,13 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { MlirAttribute mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank, - MlirSparseTensorDimLevelType const *lvlTypes, + MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl, MlirAffineMap lvlToDim, int posWidth, int crdWidth) { - SmallVector cppLvlTypes; + SmallVector cppLvlTypes; cppLvlTypes.reserve(lvlRank); for (intptr_t l = 0; l < lvlRank; ++l) - cppLvlTypes.push_back(static_cast(lvlTypes[l])); + cppLvlTypes.push_back(static_cast(lvlTypes[l])); return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes, unwrap(dimToLvl), unwrap(lvlToDim), posWidth, crdWidth)); @@ -71,9 +70,9 @@ intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr) { return cast(unwrap(attr)).getLvlRank(); } -MlirSparseTensorDimLevelType +MlirSparseTensorLevelType mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) { - return static_cast( + return static_cast( cast(unwrap(attr)).getLvlType(lvl)); } From 9356e1748cc76274dbc531c355efe28550f283d4 Mon Sep 17 00:00:00 2001 From: Vitaly Buka Date: Mon, 27 Nov 2023 15:04:53 -0800 Subject: [PATCH 641/915] Revert "[mlir] Add mlirTranslateModuleToLLVMIR to MLIR-C (#73117)" Breaks https://lab.llvm.org/buildbot/#/builders/5/builds/38700 This reverts commit ad0b2a5931fe87fbb475a9d62e086c93b3bee131. 
--- mlir/include/mlir-c/Target/LLVMIR.h | 39 ----------------------------- mlir/lib/CAPI/CMakeLists.txt | 2 +- mlir/lib/CAPI/Target/CMakeLists.txt | 12 --------- mlir/lib/CAPI/Target/LLVMIR.cpp | 36 -------------------------- 4 files changed, 1 insertion(+), 88 deletions(-) delete mode 100644 mlir/include/mlir-c/Target/LLVMIR.h delete mode 100644 mlir/lib/CAPI/Target/CMakeLists.txt delete mode 100644 mlir/lib/CAPI/Target/LLVMIR.cpp diff --git a/mlir/include/mlir-c/Target/LLVMIR.h b/mlir/include/mlir-c/Target/LLVMIR.h deleted file mode 100644 index effa74b90..000000000 --- a/mlir/include/mlir-c/Target/LLVMIR.h +++ /dev/null @@ -1,39 +0,0 @@ -//===-- LLVMIR.h - C Interface for MLIR LLVMIR Target -------------*- C -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM -// Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This header declares the C interface to target LLVMIR with MLIR. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_C_TARGET_LLVMIR_H -#define MLIR_C_TARGET_LLVMIR_H - -#include "mlir-c/IR.h" -#include "mlir-c/Support.h" -#include "llvm-c/Support.h" - -#ifdef __cplusplus -extern "C" { -#endif - -/// Translate operation that satisfies LLVM dialect module requirements into an -/// LLVM IR module living in the given context. This translates operations from -/// any dilalect that has a registered implementation of -/// LLVMTranslationDialectInterface. -/// -/// \returns the generated LLVM IR Module from the translated MLIR module, it is -/// owned by the caller. 
-MLIR_CAPI_EXPORTED LLVMModuleRef -mlirTranslateModuleToLLVMIR(MlirOperation module, LLVMContextRef context); - -#ifdef __cplusplus -} -#endif - -#endif // MLIR_C_TARGET_LLVMIR_H diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt index 6c4385084..707e78ac3 100644 --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -14,7 +14,6 @@ add_subdirectory(Interfaces) add_subdirectory(IR) add_subdirectory(RegisterEverything) add_subdirectory(Transforms) -add_subdirectory(Target) if(MLIR_ENABLE_EXECUTION_ENGINE) add_subdirectory(ExecutionEngine) @@ -37,3 +36,4 @@ if(MLIR_BUILD_MLIR_C_DYLIB) endif() endif() endif() + diff --git a/mlir/lib/CAPI/Target/CMakeLists.txt b/mlir/lib/CAPI/Target/CMakeLists.txt deleted file mode 100644 index ce86fd3de..000000000 --- a/mlir/lib/CAPI/Target/CMakeLists.txt +++ /dev/null @@ -1,12 +0,0 @@ -add_mlir_upstream_c_api_library(MLIRCAPITarget - LLVMIR.cpp - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRToLLVMIRTranslationRegistration - MLIRCAPIIR - MLIRLLVMToLLVMIRTranslation - MLIRSupport -) diff --git a/mlir/lib/CAPI/Target/LLVMIR.cpp b/mlir/lib/CAPI/Target/LLVMIR.cpp deleted file mode 100644 index dc798372b..000000000 --- a/mlir/lib/CAPI/Target/LLVMIR.cpp +++ /dev/null @@ -1,36 +0,0 @@ -//===-- LLVMIR.h - C Interface for MLIR LLVMIR Target ---------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM -// Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir-c/Target/LLVMIR.h" -#include "llvm-c/Support.h" - -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" -#include - -#include "mlir/CAPI/IR.h" -#include "mlir/CAPI/Support.h" -#include "mlir/CAPI/Wrap.h" -#include "mlir/Target/LLVMIR/ModuleTranslation.h" - -using namespace mlir; - -LLVMModuleRef mlirTranslateModuleToLLVMIR(MlirOperation module, - LLVMContextRef context) { - Operation *moduleOp = unwrap(module); - - llvm::LLVMContext *ctx = llvm::unwrap(context); - - std::unique_ptr llvmModule = - mlir::translateModuleToLLVMIR(moduleOp, *ctx); - - LLVMModuleRef moduleRef = llvm::wrap(llvmModule.release()); - - return moduleRef; -} From 4cdf6d7dcb57624f1714ef32369bfc43bb0ad622 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 27 Nov 2023 19:26:05 -0600 Subject: [PATCH 642/915] [mlir][python] enable registering dialects with the default `Context` (#72488) --- mlir/python/mlir/_mlir_libs/__init__.py | 20 +++++++++++++++++--- mlir/python/mlir/dialects/python_test.py | 4 ++-- mlir/python/mlir/ir.py | 1 + 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index 6ce77b4cb..32f46d24c 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -56,6 +56,21 @@ def get_include_dirs() -> Sequence[str]: # # This facility allows downstreams to customize Context creation to their # needs. 
+ +_dialect_registry = None + + +def get_dialect_registry(): + global _dialect_registry + + if _dialect_registry is None: + from ._mlir import ir + + _dialect_registry = ir.DialectRegistry() + + return _dialect_registry + + def _site_initialize(): import importlib import itertools @@ -63,7 +78,6 @@ def _site_initialize(): from ._mlir import ir logger = logging.getLogger(__name__) - registry = ir.DialectRegistry() post_init_hooks = [] disable_multithreading = False @@ -84,7 +98,7 @@ def process_initializer_module(module_name): logger.debug("Initializing MLIR with module: %s", module_name) if hasattr(m, "register_dialects"): logger.debug("Registering dialects from initializer %r", m) - m.register_dialects(registry) + m.register_dialects(get_dialect_registry()) if hasattr(m, "context_init_hook"): logger.debug("Adding context init hook from %r", m) post_init_hooks.append(m.context_init_hook) @@ -110,7 +124,7 @@ def process_initializer_module(module_name): class Context(ir._BaseContext): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.append_dialect_registry(registry) + self.append_dialect_registry(get_dialect_registry()) for hook in post_init_hooks: hook(self) if not disable_multithreading: diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py index 6579e02d8..b5baa80bc 100644 --- a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -11,7 +11,7 @@ ) -def register_python_test_dialect(context, load=True): +def register_python_test_dialect(registry): from .._mlir_libs import _mlirPythonTest - _mlirPythonTest.register_python_test_dialect(context, load) + _mlirPythonTest.register_dialect(registry) diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 18526ab8c..6d21da3b4 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -5,6 +5,7 @@ from ._mlir_libs._mlir.ir import * from ._mlir_libs._mlir.ir import _GlobalDebug from 
._mlir_libs._mlir import register_type_caster, register_value_caster +from ._mlir_libs import get_dialect_registry # Convenience decorator for registering user-friendly Attribute builders. From 243b9b5b525cb0bfa55ec72fdbe0aef9eda44df8 Mon Sep 17 00:00:00 2001 From: Edgar Date: Wed, 29 Nov 2023 02:37:11 +0100 Subject: [PATCH 643/915] [mlir] Re-Add mlirTranslateModuleToLLVMIR to MLIR-C (#73627) The test was checking something unrelated to what it controlled so it failed after that part changed, i removed that. See https://github.com/llvm/llvm-project/pull/73117 --- mlir/include/mlir-c/Target/LLVMIR.h | 39 +++++++++++++++++++++++++++++ mlir/lib/CAPI/CMakeLists.txt | 2 +- mlir/lib/CAPI/Target/CMakeLists.txt | 12 +++++++++ mlir/lib/CAPI/Target/LLVMIR.cpp | 36 ++++++++++++++++++++++++++ 4 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 mlir/include/mlir-c/Target/LLVMIR.h create mode 100644 mlir/lib/CAPI/Target/CMakeLists.txt create mode 100644 mlir/lib/CAPI/Target/LLVMIR.cpp diff --git a/mlir/include/mlir-c/Target/LLVMIR.h b/mlir/include/mlir-c/Target/LLVMIR.h new file mode 100644 index 000000000..effa74b90 --- /dev/null +++ b/mlir/include/mlir-c/Target/LLVMIR.h @@ -0,0 +1,39 @@ +//===-- LLVMIR.h - C Interface for MLIR LLVMIR Target -------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface to target LLVMIR with MLIR. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_TARGET_LLVMIR_H +#define MLIR_C_TARGET_LLVMIR_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "llvm-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// Translate operation that satisfies LLVM dialect module requirements into an +/// LLVM IR module living in the given context. This translates operations from +/// any dilalect that has a registered implementation of +/// LLVMTranslationDialectInterface. +/// +/// \returns the generated LLVM IR Module from the translated MLIR module, it is +/// owned by the caller. +MLIR_CAPI_EXPORTED LLVMModuleRef +mlirTranslateModuleToLLVMIR(MlirOperation module, LLVMContextRef context); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_TARGET_LLVMIR_H diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt index 707e78ac3..6c4385084 100644 --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -14,6 +14,7 @@ add_subdirectory(Interfaces) add_subdirectory(IR) add_subdirectory(RegisterEverything) add_subdirectory(Transforms) +add_subdirectory(Target) if(MLIR_ENABLE_EXECUTION_ENGINE) add_subdirectory(ExecutionEngine) @@ -36,4 +37,3 @@ if(MLIR_BUILD_MLIR_C_DYLIB) endif() endif() endif() - diff --git a/mlir/lib/CAPI/Target/CMakeLists.txt b/mlir/lib/CAPI/Target/CMakeLists.txt new file mode 100644 index 000000000..ce86fd3de --- /dev/null +++ b/mlir/lib/CAPI/Target/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_upstream_c_api_library(MLIRCAPITarget + LLVMIR.cpp + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRToLLVMIRTranslationRegistration + MLIRCAPIIR + MLIRLLVMToLLVMIRTranslation + MLIRSupport +) diff --git a/mlir/lib/CAPI/Target/LLVMIR.cpp b/mlir/lib/CAPI/Target/LLVMIR.cpp new file mode 100644 index 000000000..dc798372b --- /dev/null +++ b/mlir/lib/CAPI/Target/LLVMIR.cpp @@ -0,0 +1,36 @@ +//===-- LLVMIR.h - C Interface for MLIR LLVMIR Target 
---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Target/LLVMIR.h" +#include "llvm-c/Support.h" + +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include + +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" +#include "mlir/CAPI/Wrap.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +using namespace mlir; + +LLVMModuleRef mlirTranslateModuleToLLVMIR(MlirOperation module, + LLVMContextRef context) { + Operation *moduleOp = unwrap(module); + + llvm::LLVMContext *ctx = llvm::unwrap(context); + + std::unique_ptr llvmModule = + mlir::translateModuleToLLVMIR(moduleOp, *ctx); + + LLVMModuleRef moduleRef = llvm::wrap(llvmModule.release()); + + return moduleRef; +} From b67f5856dedcb08140d9e0a67250f92767cddc42 Mon Sep 17 00:00:00 2001 From: Vitaly Buka Date: Tue, 28 Nov 2023 21:05:12 -0800 Subject: [PATCH 644/915] Revert "[mlir] Re-Add mlirTranslateModuleToLLVMIR to MLIR-C (#73627)" (#73749) Still breaks https://lab.llvm.org/buildbot/#/builders/5/builds/38743/steps/9/logs/stdio There is some info on how to reproduce https://github.com/google/sanitizers/wiki/SanitizerBotReproduceBuild This reverts commit 243b9b5b525cb0bfa55ec72fdbe0aef9eda44df8. 
--- mlir/include/mlir-c/Target/LLVMIR.h | 39 ----------------------------- mlir/lib/CAPI/CMakeLists.txt | 2 +- mlir/lib/CAPI/Target/CMakeLists.txt | 12 --------- mlir/lib/CAPI/Target/LLVMIR.cpp | 36 -------------------------- 4 files changed, 1 insertion(+), 88 deletions(-) delete mode 100644 mlir/include/mlir-c/Target/LLVMIR.h delete mode 100644 mlir/lib/CAPI/Target/CMakeLists.txt delete mode 100644 mlir/lib/CAPI/Target/LLVMIR.cpp diff --git a/mlir/include/mlir-c/Target/LLVMIR.h b/mlir/include/mlir-c/Target/LLVMIR.h deleted file mode 100644 index effa74b90..000000000 --- a/mlir/include/mlir-c/Target/LLVMIR.h +++ /dev/null @@ -1,39 +0,0 @@ -//===-- LLVMIR.h - C Interface for MLIR LLVMIR Target -------------*- C -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM -// Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This header declares the C interface to target LLVMIR with MLIR. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_C_TARGET_LLVMIR_H -#define MLIR_C_TARGET_LLVMIR_H - -#include "mlir-c/IR.h" -#include "mlir-c/Support.h" -#include "llvm-c/Support.h" - -#ifdef __cplusplus -extern "C" { -#endif - -/// Translate operation that satisfies LLVM dialect module requirements into an -/// LLVM IR module living in the given context. This translates operations from -/// any dilalect that has a registered implementation of -/// LLVMTranslationDialectInterface. -/// -/// \returns the generated LLVM IR Module from the translated MLIR module, it is -/// owned by the caller. 
-MLIR_CAPI_EXPORTED LLVMModuleRef -mlirTranslateModuleToLLVMIR(MlirOperation module, LLVMContextRef context); - -#ifdef __cplusplus -} -#endif - -#endif // MLIR_C_TARGET_LLVMIR_H diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt index 6c4385084..707e78ac3 100644 --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -14,7 +14,6 @@ add_subdirectory(Interfaces) add_subdirectory(IR) add_subdirectory(RegisterEverything) add_subdirectory(Transforms) -add_subdirectory(Target) if(MLIR_ENABLE_EXECUTION_ENGINE) add_subdirectory(ExecutionEngine) @@ -37,3 +36,4 @@ if(MLIR_BUILD_MLIR_C_DYLIB) endif() endif() endif() + diff --git a/mlir/lib/CAPI/Target/CMakeLists.txt b/mlir/lib/CAPI/Target/CMakeLists.txt deleted file mode 100644 index ce86fd3de..000000000 --- a/mlir/lib/CAPI/Target/CMakeLists.txt +++ /dev/null @@ -1,12 +0,0 @@ -add_mlir_upstream_c_api_library(MLIRCAPITarget - LLVMIR.cpp - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRToLLVMIRTranslationRegistration - MLIRCAPIIR - MLIRLLVMToLLVMIRTranslation - MLIRSupport -) diff --git a/mlir/lib/CAPI/Target/LLVMIR.cpp b/mlir/lib/CAPI/Target/LLVMIR.cpp deleted file mode 100644 index dc798372b..000000000 --- a/mlir/lib/CAPI/Target/LLVMIR.cpp +++ /dev/null @@ -1,36 +0,0 @@ -//===-- LLVMIR.h - C Interface for MLIR LLVMIR Target ---------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM -// Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir-c/Target/LLVMIR.h" -#include "llvm-c/Support.h" - -#include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Module.h" -#include - -#include "mlir/CAPI/IR.h" -#include "mlir/CAPI/Support.h" -#include "mlir/CAPI/Wrap.h" -#include "mlir/Target/LLVMIR/ModuleTranslation.h" - -using namespace mlir; - -LLVMModuleRef mlirTranslateModuleToLLVMIR(MlirOperation module, - LLVMContextRef context) { - Operation *moduleOp = unwrap(module); - - llvm::LLVMContext *ctx = llvm::unwrap(context); - - std::unique_ptr llvmModule = - mlir::translateModuleToLLVMIR(moduleOp, *ctx); - - LLVMModuleRef moduleRef = llvm::wrap(llvmModule.release()); - - return moduleRef; -} From 4b277f886e9e6bfe612836214d75a5d420769ddb Mon Sep 17 00:00:00 2001 From: Vitaly Buka Date: Wed, 29 Nov 2023 10:59:51 -0800 Subject: [PATCH 645/915] Reapply "[mlir] Add mlirTranslateModuleToLLVMIR to MLIR-C (#73627)" (#73749) (#73751) Co-authored-by: Edgar --- mlir/include/mlir-c/Target/LLVMIR.h | 39 +++++++++++++++++++++++++++++ mlir/lib/CAPI/CMakeLists.txt | 2 +- mlir/lib/CAPI/Target/CMakeLists.txt | 12 +++++++++ mlir/lib/CAPI/Target/LLVMIR.cpp | 36 ++++++++++++++++++++++++++ 4 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 mlir/include/mlir-c/Target/LLVMIR.h create mode 100644 mlir/lib/CAPI/Target/CMakeLists.txt create mode 100644 mlir/lib/CAPI/Target/LLVMIR.cpp diff --git a/mlir/include/mlir-c/Target/LLVMIR.h b/mlir/include/mlir-c/Target/LLVMIR.h new file mode 100644 index 000000000..effa74b90 --- /dev/null +++ b/mlir/include/mlir-c/Target/LLVMIR.h @@ -0,0 +1,39 @@ +//===-- LLVMIR.h - C Interface for MLIR LLVMIR Target -------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface to target LLVMIR with MLIR. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_TARGET_LLVMIR_H +#define MLIR_C_TARGET_LLVMIR_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "llvm-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// Translate operation that satisfies LLVM dialect module requirements into an +/// LLVM IR module living in the given context. This translates operations from +/// any dilalect that has a registered implementation of +/// LLVMTranslationDialectInterface. +/// +/// \returns the generated LLVM IR Module from the translated MLIR module, it is +/// owned by the caller. +MLIR_CAPI_EXPORTED LLVMModuleRef +mlirTranslateModuleToLLVMIR(MlirOperation module, LLVMContextRef context); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_TARGET_LLVMIR_H diff --git a/mlir/lib/CAPI/CMakeLists.txt b/mlir/lib/CAPI/CMakeLists.txt index 707e78ac3..6c4385084 100644 --- a/mlir/lib/CAPI/CMakeLists.txt +++ b/mlir/lib/CAPI/CMakeLists.txt @@ -14,6 +14,7 @@ add_subdirectory(Interfaces) add_subdirectory(IR) add_subdirectory(RegisterEverything) add_subdirectory(Transforms) +add_subdirectory(Target) if(MLIR_ENABLE_EXECUTION_ENGINE) add_subdirectory(ExecutionEngine) @@ -36,4 +37,3 @@ if(MLIR_BUILD_MLIR_C_DYLIB) endif() endif() endif() - diff --git a/mlir/lib/CAPI/Target/CMakeLists.txt b/mlir/lib/CAPI/Target/CMakeLists.txt new file mode 100644 index 000000000..ce86fd3de --- /dev/null +++ b/mlir/lib/CAPI/Target/CMakeLists.txt @@ -0,0 +1,12 @@ +add_mlir_upstream_c_api_library(MLIRCAPITarget + LLVMIR.cpp + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRToLLVMIRTranslationRegistration + MLIRCAPIIR + MLIRLLVMToLLVMIRTranslation + MLIRSupport +) diff --git a/mlir/lib/CAPI/Target/LLVMIR.cpp 
b/mlir/lib/CAPI/Target/LLVMIR.cpp new file mode 100644 index 000000000..dc798372b --- /dev/null +++ b/mlir/lib/CAPI/Target/LLVMIR.cpp @@ -0,0 +1,36 @@ +//===-- LLVMIR.h - C Interface for MLIR LLVMIR Target ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Target/LLVMIR.h" +#include "llvm-c/Support.h" + +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include + +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" +#include "mlir/CAPI/Wrap.h" +#include "mlir/Target/LLVMIR/ModuleTranslation.h" + +using namespace mlir; + +LLVMModuleRef mlirTranslateModuleToLLVMIR(MlirOperation module, + LLVMContextRef context) { + Operation *moduleOp = unwrap(module); + + llvm::LLVMContext *ctx = llvm::unwrap(context); + + std::unique_ptr llvmModule = + mlir::translateModuleToLLVMIR(moduleOp, *ctx); + + LLVMModuleRef moduleRef = llvm::wrap(llvmModule.release()); + + return moduleRef; +} From 843c6118a6a0dc058785d38ebf2ab43470adde41 Mon Sep 17 00:00:00 2001 From: Felix Schneider Date: Fri, 1 Dec 2023 22:16:00 +0100 Subject: [PATCH 646/915] [mlir][linalg] Fix weight dimension ordering in 2D grouped conv (#73855) The `conv_2d_ngchw_fgchw` Op implements 2d grouped convolution with dimensions ordered as given in the name. However, the current implementation orders weights as `gfchw` instead of `fgchw`. This was already pointed out in an old phabricator revision which never landed: https://reviews.llvm.org/D150064 This patch 1) Adds a new op `conv_2d_ngchw_gfchw` 2) Fixes the dimension ordering of the old op `conv_2d_ngchw_fgchw` 3) Adds tests with non-dynamic dimensions so that it's easier to understand. 
--- .../linalg/opdsl/ops/core_named_ops.py | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 62b7da2ae..5b05364f6 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -780,7 +780,7 @@ def conv_2d_ngchw_fgchw( T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW ), K=TensorDef(T2, S.FG, S.G, S.C, S.KH, S.KW), - O=TensorDef(U, S.N, S.FG, S.G, S.OH, S.OW, output=True), + O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True), strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), ): @@ -790,6 +790,32 @@ def conv_2d_ngchw_fgchw( * Input: NGCHW. * Kernel: FGCHW. + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.g, D.fg, D.oh, D.ow] += TypeFn.cast_signed( + U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] + ) * TypeFn.cast_signed(U, K[D.fg, D.g, D.c, D.kh, D.kw]) + + +@linalg_structured_op +def conv_2d_ngchw_gfchw( + I=TensorDef( + T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW + ), + K=TensorDef(T2, S.G, S.FG, S.C, S.KH, S.KW), + O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D grouped convolution. + + Layout: + * Input: NGCHW. + * Kernel: GFCHW. + Numeric casting is performed on the operands to the inner multiply, promoting them to the same data type as the accumulator/output. 
""" From 097f73b9b13a046a5d79f309340b73e226b60a2e Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Sun, 3 Dec 2023 16:14:09 -0600 Subject: [PATCH 647/915] [mlir][AMDGPU] fix AMDGPU C API registration (#74255) --- mlir/lib/CAPI/Dialect/AMDGPU.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/Dialect/AMDGPU.cpp b/mlir/lib/CAPI/Dialect/AMDGPU.cpp index 28efe6025..d877ca2df 100644 --- a/mlir/lib/CAPI/Dialect/AMDGPU.cpp +++ b/mlir/lib/CAPI/Dialect/AMDGPU.cpp @@ -10,5 +10,5 @@ #include "mlir/CAPI/Registration.h" #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" -MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(AMDGPU, ml_program, +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(AMDGPU, amdgpu, mlir::amdgpu::AMDGPUDialect) From 43ccedccda4c087638ae4362da4d590c967cde4c Mon Sep 17 00:00:00 2001 From: Amy Wang Date: Tue, 5 Dec 2023 02:33:56 -0500 Subject: [PATCH 648/915] [mlir][python] python binding wrapper for the affine.AffineForOp (#74408) This PR creates the wrapper class AffineForOp and adds a testcase for it. A testcase for the AffineLoadOp is also added. 
--- mlir/python/mlir/dialects/affine.py | 138 ++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py index 80d3873e1..26e827009 100644 --- a/mlir/python/mlir/dialects/affine.py +++ b/mlir/python/mlir/dialects/affine.py @@ -3,3 +3,141 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._affine_ops_gen import * +from ._affine_ops_gen import _Dialect, AffineForOp +from .arith import constant + +try: + from ..ir import * + from ._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + _cext as _ods_cext, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Optional, Sequence, Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class AffineForOp(AffineForOp): + """Specialization for the Affine for op class""" + + def __init__( + self, + lower_bound, + upper_bound, + step, + iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None, + *, + lower_bound_operands=[], + upper_bound_operands=[], + loc=None, + ip=None, + ): + """Creates an Affine `for` operation. + + - `lower_bound` is the affine map to use as lower bound of the loop. + - `upper_bound` is the affine map to use as upper bound of the loop. + - `step` is the value to use as loop step. + - `iter_args` is a list of additional loop-carried arguments or an operation + producing them as results. 
+ - `lower_bound_operands` is the list of arguments to substitute the dimensions, + then symbols in the `lower_bound` affine map, in an increasing order + - `upper_bound_operands` is the list of arguments to substitute the dimensions, + then symbols in the `upper_bound` affine map, in an increasing order + """ + + if iter_args is None: + iter_args = [] + iter_args = _get_op_results_or_values(iter_args) + if len(lower_bound_operands) != lower_bound.n_inputs: + raise ValueError( + f"Wrong number of lower bound operands passed to AffineForOp. " + + "Expected {lower_bound.n_symbols}, got {len(lower_bound_operands)}." + ) + + if len(upper_bound_operands) != upper_bound.n_inputs: + raise ValueError( + f"Wrong number of upper bound operands passed to AffineForOp. " + + "Expected {upper_bound.n_symbols}, got {len(upper_bound_operands)}." + ) + + results = [arg.type for arg in iter_args] + super().__init__( + results_=results, + lowerBoundOperands=_get_op_results_or_values(lower_bound_operands), + upperBoundOperands=_get_op_results_or_values(upper_bound_operands), + inits=list(iter_args), + lowerBoundMap=AffineMapAttr.get(lower_bound), + upperBoundMap=AffineMapAttr.get(upper_bound), + step=IntegerAttr.get(IndexType.get(), step), + loc=loc, + ip=ip, + ) + self.regions[0].blocks.append(IndexType.get(), *results) + + @property + def body(self): + """Returns the body (block) of the loop.""" + return self.regions[0].blocks[0] + + @property + def induction_variable(self): + """Returns the induction variable of the loop.""" + return self.body.arguments[0] + + @property + def inner_iter_args(self): + """Returns the loop-carried arguments usable within the loop. + + To obtain the loop-carried operands, use `iter_args`. 
+ """ + return self.body.arguments[1:] + + +def for_( + start, + stop=None, + step=None, + iter_args: Optional[Sequence[Value]] = None, + *, + loc=None, + ip=None, +): + if step is None: + step = 1 + if stop is None: + stop = start + start = 0 + params = [start, stop] + for i, p in enumerate(params): + if isinstance(p, int): + p = constant(IntegerAttr.get(IndexType.get(), p)) + elif isinstance(p, float): + raise ValueError(f"{p=} must be int.") + params[i] = p + + start, stop = params + s0 = AffineSymbolExpr.get(0) + lbmap = AffineMap.get(0, 1, [s0]) + ubmap = AffineMap.get(0, 1, [s0]) + for_op = AffineForOp( + lbmap, + ubmap, + step, + iter_args=iter_args, + lower_bound_operands=[start], + upper_bound_operands=[stop], + loc=loc, + ip=ip, + ) + iv = for_op.induction_variable + iter_args = tuple(for_op.inner_iter_args) + with InsertionPoint(for_op.body): + if len(iter_args) > 1: + yield iv, iter_args + elif len(iter_args) == 1: + yield iv, iter_args[0] + else: + yield iv From 20a0b65e42db1673bccd2346f4913392dfd66723 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 6 Dec 2023 13:42:11 -0500 Subject: [PATCH 649/915] [mlir:python] Fail immediately if importing an initializer module raises ImportError (#74595) --- mlir/python/mlir/_mlir_libs/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index 32f46d24c..98dbbc6ad 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -94,6 +94,7 @@ def process_initializer_module(module_name): "encountered otherwise and the MLIR Python API may not function." 
) logger.warning(message, exc_info=True) + return False logger.debug("Initializing MLIR with module: %s", module_name) if hasattr(m, "register_dialects"): From 2b330c6db0d4a9bca8a10f2e1da9549ce4263986 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Thu, 7 Dec 2023 03:22:45 -0800 Subject: [PATCH 650/915] [mlir] Fix missing cmake dependency causing non-deterministic build failure (NFC) Fixes #74611 --- mlir/python/CMakeLists.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 55731943f..585918afc 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -335,7 +335,8 @@ declare_mlir_dialect_python_bindings( TD_FILE dialects/OpenMPOps.td SOURCES dialects/openmp.py - DIALECT_NAME omp) + DIALECT_NAME omp + DEPENDS omp_common_td) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects From 5ff4e6ba0b59e3ad1a8b05231166d494eff7bb8d Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Thu, 7 Dec 2023 10:55:55 -0600 Subject: [PATCH 651/915] [mlir][python] fix up affine for (#74495) --- mlir/python/mlir/dialects/_ods_common.py | 4 + mlir/python/mlir/dialects/affine.py | 105 +++++++++++++---------- 2 files changed, 64 insertions(+), 45 deletions(-) diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 60ce83c09..1685124fb 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -134,3 +134,7 @@ def get_op_result_or_op_results( # see the typing.Type doc string. 
_U = _TypeVar("_U", bound=_cext.ir.Value) SubClassValueT = _Type[_U] + +ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value +ResultValueT = _Union[ResultValueTypeTuple] +VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]] diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py index 26e827009..913cea611 100644 --- a/mlir/python/mlir/dialects/affine.py +++ b/mlir/python/mlir/dialects/affine.py @@ -3,8 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._affine_ops_gen import * -from ._affine_ops_gen import _Dialect, AffineForOp -from .arith import constant +from ._affine_ops_gen import _Dialect try: from ..ir import * @@ -12,6 +11,9 @@ get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, _cext as _ods_cext, + ResultValueTypeTuple as _ResultValueTypeTuple, + ResultValueT as _ResultValueT, + VariadicResultValueT as _VariadicResultValueT, ) except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e @@ -21,17 +23,17 @@ @_ods_cext.register_operation(_Dialect, replace=True) class AffineForOp(AffineForOp): - """Specialization for the Affine for op class""" + """Specialization for the Affine for op class.""" def __init__( self, - lower_bound, - upper_bound, - step, - iter_args: Optional[Union[Operation, OpView, Sequence[Value]]] = None, + lower_bound: Union[int, _ResultValueT, AffineMap], + upper_bound: Optional[Union[int, _ResultValueT, AffineMap]], + step: Optional[Union[int, Attribute]] = None, + iter_args: Optional[_ResultValueT] = None, *, - lower_bound_operands=[], - upper_bound_operands=[], + lower_bound_operands: Optional[_VariadicResultValueT] = None, + upper_bound_operands: Optional[_VariadicResultValueT] = None, loc=None, ip=None, ): @@ -43,25 +45,57 @@ def __init__( - `iter_args` is a list of additional loop-carried arguments or an operation producing them as results. 
- `lower_bound_operands` is the list of arguments to substitute the dimensions, - then symbols in the `lower_bound` affine map, in an increasing order + then symbols in the `lower_bound` affine map, in an increasing order. - `upper_bound_operands` is the list of arguments to substitute the dimensions, - then symbols in the `upper_bound` affine map, in an increasing order + then symbols in the `upper_bound` affine map, in an increasing order. """ + if lower_bound_operands is None: + lower_bound_operands = [] + if upper_bound_operands is None: + upper_bound_operands = [] + + if step is None: + step = 1 + + bounds_operands = [lower_bound_operands, upper_bound_operands] + bounds = [lower_bound, upper_bound] + bounds_names = ["lower", "upper"] + for i, name in enumerate(bounds_names): + if isinstance(bounds[i], int): + bounds[i] = AffineMap.get_constant(bounds[i]) + elif isinstance(bounds[i], _ResultValueTypeTuple): + if len(bounds_operands[i]): + raise ValueError( + f"Either a concrete {name} bound or an AffineMap in combination " + f"with {name} bound operands, but not both, is supported." + ) + if ( + isinstance(bounds[i], (OpView, Operation)) + and len(bounds[i].results) > 1 + ): + raise ValueError( + f"Only a single concrete value is supported for {name} bound." + ) + + bounds_operands[i].append(_get_op_result_or_value(bounds[i])) + bounds[i] = AffineMap.get_identity(1) + + if not isinstance(bounds[i], AffineMap): + raise ValueError( + f"{name} bound must be int | ResultValueT | AffineMap." + ) + if len(bounds_operands[i]) != bounds[i].n_inputs: + raise ValueError( + f"Wrong number of {name} bound operands passed to AffineForOp; " + + f"Expected {bounds[i].n_inputs}, got {len(bounds_operands[i])}." + ) + + lower_bound, upper_bound = bounds + if iter_args is None: iter_args = [] iter_args = _get_op_results_or_values(iter_args) - if len(lower_bound_operands) != lower_bound.n_inputs: - raise ValueError( - f"Wrong number of lower bound operands passed to AffineForOp. 
" - + "Expected {lower_bound.n_symbols}, got {len(lower_bound_operands)}." - ) - - if len(upper_bound_operands) != upper_bound.n_inputs: - raise ValueError( - f"Wrong number of upper bound operands passed to AffineForOp. " - + "Expected {upper_bound.n_symbols}, got {len(upper_bound_operands)}." - ) results = [arg.type for arg in iter_args] super().__init__( @@ -71,7 +105,7 @@ def __init__( inits=list(iter_args), lowerBoundMap=AffineMapAttr.get(lower_bound), upperBoundMap=AffineMapAttr.get(upper_bound), - step=IntegerAttr.get(IndexType.get(), step), + step=step, loc=loc, ip=ip, ) @@ -98,37 +132,18 @@ def inner_iter_args(self): def for_( start, - stop=None, + stop, step=None, iter_args: Optional[Sequence[Value]] = None, *, loc=None, ip=None, ): - if step is None: - step = 1 - if stop is None: - stop = start - start = 0 - params = [start, stop] - for i, p in enumerate(params): - if isinstance(p, int): - p = constant(IntegerAttr.get(IndexType.get(), p)) - elif isinstance(p, float): - raise ValueError(f"{p=} must be int.") - params[i] = p - - start, stop = params - s0 = AffineSymbolExpr.get(0) - lbmap = AffineMap.get(0, 1, [s0]) - ubmap = AffineMap.get(0, 1, [s0]) for_op = AffineForOp( - lbmap, - ubmap, + start, + stop, step, iter_args=iter_args, - lower_bound_operands=[start], - upper_bound_operands=[stop], loc=loc, ip=ip, ) From db40f7cde370d51b9fddb503f018c93643af946a Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Mon, 11 Dec 2023 09:43:08 +0000 Subject: [PATCH 652/915] [mlir][Python] Apply ClangTidy findings. 
move constructors should be marked noexcept --- mlir/lib/Bindings/Python/ExecutionEngineModule.cpp | 2 +- mlir/lib/Bindings/Python/IRInterfaces.cpp | 2 +- mlir/lib/Bindings/Python/IRModule.h | 6 ++++-- mlir/lib/Bindings/Python/Pass.cpp | 3 ++- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp index 3f8342596..b3df30583 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -20,7 +20,7 @@ class PyExecutionEngine { public: PyExecutionEngine(MlirExecutionEngine executionEngine) : executionEngine(executionEngine) {} - PyExecutionEngine(PyExecutionEngine &&other) + PyExecutionEngine(PyExecutionEngine &&other) noexcept : executionEngine(other.executionEngine) { other.executionEngine.ptr = nullptr; } diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index c3aac0b09..54cfa5606 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -326,7 +326,7 @@ class PyShapedTypeComponents { : shape(std::move(shape)), elementType(elementType), attribute(attribute), ranked(true) {} PyShapedTypeComponents(PyShapedTypeComponents &) = delete; - PyShapedTypeComponents(PyShapedTypeComponents &&other) + PyShapedTypeComponents(PyShapedTypeComponents &&other) noexcept : shape(other.shape), elementType(other.elementType), attribute(other.attribute), ranked(other.ranked) {} diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index d99b87d19..79b7e0c96 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -4,6 +4,7 @@ // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception //===----------------------------------------------------------------------===// #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H @@ -53,7 +54,7 @@ class PyObjectRef { "cannot construct PyObjectRef with null referrent"); assert(this->object && "cannot construct PyObjectRef with null object"); } - PyObjectRef(PyObjectRef &&other) + PyObjectRef(PyObjectRef &&other) noexcept : referrent(other.referrent), object(std::move(other.object)) { other.referrent = nullptr; assert(!other.object); @@ -484,7 +485,8 @@ class PyDialectRegistry { mlirDialectRegistryDestroy(registry); } PyDialectRegistry(PyDialectRegistry &) = delete; - PyDialectRegistry(PyDialectRegistry &&other) : registry(other.registry) { + PyDialectRegistry(PyDialectRegistry &&other) noexcept + : registry(other.registry) { other.registry = {nullptr}; } diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 588a8e254..a68421b61 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -23,7 +23,8 @@ namespace { class PyPassManager { public: PyPassManager(MlirPassManager passManager) : passManager(passManager) {} - PyPassManager(PyPassManager &&other) : passManager(other.passManager) { + PyPassManager(PyPassManager &&other) noexcept + : passManager(other.passManager) { other.passManager.ptr = nullptr; } ~PyPassManager() { From 73695957e4a54be56cf153b670d5afec3199540e Mon Sep 17 00:00:00 2001 From: Shenghang Tsai Date: Mon, 11 Dec 2023 19:32:21 +0800 Subject: [PATCH 653/915] [mlir][CAPI] Add mlirOpOperandGetValue (#75032) --- mlir/include/mlir-c/IR.h | 3 +++ mlir/lib/CAPI/IR/IR.cpp | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 413eaa6aa..82da511f8 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -940,6 +940,9 @@ MLIR_CAPI_EXPORTED void 
mlirValueReplaceAllUsesOfWith(MlirValue of, /// Returns whether the op operand is null. MLIR_CAPI_EXPORTED bool mlirOpOperandIsNull(MlirOpOperand opOperand); +/// Returns the value of an op operand. +MLIR_CAPI_EXPORTED MlirValue mlirOpOperandGetValue(MlirOpOperand opOperand); + /// Returns the owner operation of an op operand. MLIR_CAPI_EXPORTED MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index d1ee1b774..ac9889df1 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -986,6 +986,10 @@ MlirOperation mlirOpOperandGetOwner(MlirOpOperand opOperand) { return wrap(unwrap(opOperand)->getOwner()); } +MlirValue mlirOpOperandGetValue(MlirOpOperand opOperand) { + return wrap(unwrap(opOperand)->get()); +} + unsigned mlirOpOperandGetOperandNumber(MlirOpOperand opOperand) { return unwrap(opOperand)->getOperandNumber(); } From 4dadbdced1d709ba873a0695dcc7dce8f03882ac Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 11 Dec 2023 18:35:02 -0600 Subject: [PATCH 654/915] [mlir][python] update type stubs (#75099) --- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 2530 +++++++++++++++++----- 1 file changed, 2046 insertions(+), 484 deletions(-) diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 2609117dd..fa591e5f1 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -1,15 +1,62 @@ # Originally imported via: -# stubgen {...} -m mlir._mlir_libs._mlir.ir +# pybind11-stubgen --print-invalid-expressions-as-is mlir._mlir_libs._mlir.ir +# but with the following diff (in order to remove pipes from types, +# which we won't support until bumping minimum python to 3.10) +# +# --------------------- diff begins ------------------------------------ +# +# diff --git a/pybind11_stubgen/printer.py b/pybind11_stubgen/printer.py +# index 1f755aa..4924927 100644 +# --- a/pybind11_stubgen/printer.py +# 
+++ b/pybind11_stubgen/printer.py +# @@ -283,14 +283,6 @@ class Printer: +# return split[0] + "..." +# +# def print_type(self, type_: ResolvedType) -> str: +# - if ( +# - str(type_.name) == "typing.Optional" +# - and type_.parameters is not None +# - and len(type_.parameters) == 1 +# - ): +# - return f"{self.print_annotation(type_.parameters[0])} | None" +# - if str(type_.name) == "typing.Union" and type_.parameters is not None: +# - return " | ".join(self.print_annotation(p) for p in type_.parameters) +# if type_.parameters: +# param_str = ( +# "[" +# +# --------------------- diff ends ------------------------------------ +# # Local modifications: -# * Rewrite references to 'mlir.ir.' to local types -# * Add __all__ with the following incantation: -# egrep '^class ' ir.pyi | awk -F ' |:|\\(' '{print " \"" $2 "\","}' -# * Local edits to signatures and types that MyPy did not auto detect (or -# detected incorrectly). - +# * Rewrite references to 'mlir.ir' to local types. +# * Drop `typing.` everywhere (top-level import instead). +# * List -> List, dict -> Dict, Tuple -> Tuple. +# * copy-paste Buffer type from typing_extensions. +# * Shuffle _OperationBase, AffineExpr, Attribute, Type, Value to the top. +# * Patch raw C++ types (like "PyAsmState") with a regex like `Py(.*)`. +# * _BaseContext -> Context, MlirType -> Type, MlirTypeID -> TypeID, MlirAttribute -> Attribute. +# * Local edits to signatures and types that pybind11-stubgen did not auto detect (or detected incorrectly). +# * Add MLIRError, _GlobalDebug, _OperationBase to __all__ by hand. +# * Fill in `Any`s where possible. +# * black formatting. 
+ +from __future__ import annotations + +import abc +import collections +import io from typing import ( - Any, Callable, ClassVar, Dict, List, Optional, Sequence, Tuple, - Type as _Type, TypeVar + Any, + Callable, + ClassVar, + Dict, + List, + Optional, + Sequence, + Tuple, + Type as _Type, + TypeVar, + Union, ) from typing import overload @@ -30,6 +77,8 @@ __all__ = [ "AffineSymbolExpr", "ArrayAttr", "ArrayAttributeIterator", + "AsmState", + "AttrBuilder", "Attribute", "BF16Type", "Block", @@ -40,29 +89,44 @@ __all__ = [ "BoolAttr", "ComplexType", "Context", + "DenseBoolArrayAttr", + "DenseBoolArrayIterator", "DenseElementsAttr", + "DenseF32ArrayAttr", + "DenseF32ArrayIterator", + "DenseF64ArrayAttr", + "DenseF64ArrayIterator", "DenseFPElementsAttr", + "DenseI16ArrayAttr", + "DenseI16ArrayIterator", + "DenseI32ArrayAttr", + "DenseI32ArrayIterator", + "DenseI64ArrayAttr", + "DenseI64ArrayIterator", + "DenseI8ArrayAttr", + "DenseI8ArrayIterator", "DenseIntElementsAttr", "DenseResourceElementsAttr", - "Dialect", - "DialectDescriptor", - "Dialects", "Diagnostic", "DiagnosticHandler", "DiagnosticInfo", "DiagnosticSeverity", + "Dialect", + "DialectDescriptor", + "DialectRegistry", + "Dialects", "DictAttr", - "Float8E4M3FNType", - "Float8E5M2Type", - "Float8E4M3FNUZType", - "Float8E4M3B11FNUZType", - "Float8E5M2FNUZType", "F16Type", - "FloatTF32Type", "F32Type", "F64Type", "FlatSymbolRefAttr", + "Float8E4M3B11FNUZType", + "Float8E4M3FNType", + "Float8E4M3FNUZType", + "Float8E5M2FNUZType", + "Float8E5M2Type", "FloatAttr", + "FloatTF32Type", "FunctionType", "IndexType", "InferShapedTypeOpInterface", @@ -76,15 +140,18 @@ __all__ = [ "Location", "MemRefType", "Module", - "MLIRError", "NamedAttribute", "NoneType", - "OpaqueType", "OpAttributeMap", + "OpOperand", + "OpOperandIterator", "OpOperandList", "OpResult", "OpResultList", + "OpSuccessors", "OpView", + "OpaqueAttr", + "OpaqueType", "Operation", "OperationIterator", "OperationList", @@ -94,11 +161,14 @@ __all__ = [ 
"RegionSequence", "ShapedType", "ShapedTypeComponents", + "StridedLayoutAttr", "StringAttr", + "SymbolRefAttr", "SymbolTable", "TupleType", "Type", "TypeAttr", + "TypeID", "UnitAttr", "UnrankedMemRefType", "UnrankedTensorType", @@ -108,222 +178,561 @@ __all__ = [ "_OperationBase", ] -# Base classes: declared first to simplify declarations below. +if hasattr(collections.abc, "Buffer"): + Buffer = collections.abc.Buffer +else: + class Buffer(abc.ABC): + pass + class _OperationBase: - def detach_from_parent(self) -> OpView: ... - def get_asm(self, binary: bool = False, large_elements_limit: Optional[int] = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, use_local_scope: bool = False, assume_verified: bool = False) -> object: ... - def move_after(self, other: _OperationBase) -> None: ... - def move_before(self, other: _OperationBase) -> None: ... - def print(self, file: Optional[Any] = None, binary: bool = False, large_elements_limit: Optional[int] = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, use_local_scope: bool = False, assume_verified: bool = False) -> None: ... - def verify(self) -> bool: ... @overload def __eq__(self, arg0: _OperationBase) -> bool: ... @overload - def __eq__(self, arg0: object) -> bool: ... + def __eq__(self, arg0: _OperationBase) -> bool: ... def __hash__(self) -> int: ... + def __str__(self) -> str: + """ + Returns the assembly form of the operation. + """ + def clone(self, ip: InsertionPoint = None) -> OpView: ... + def detach_from_parent(self) -> OpView: + """ + Detaches the operation from its parent block. + """ + def erase(self) -> None: ... 
+ def get_asm( + self, + binary: bool = False, + large_elements_limit: Optional[int] = None, + enable_debug_info: bool = False, + pretty_debug_info: bool = False, + print_generic_op_form: bool = False, + use_local_scope: bool = False, + assume_verified: bool = False, + ) -> Union[io.BytesIO, io.StringIO]: + """ + Gets the assembly form of the operation with all options available. + + Args: + binary: Whether to return a bytes (True) or str (False) object. Defaults to + False. + ... others ...: See the print() method for common keyword arguments for + configuring the printout. + Returns: + Either a bytes or str object, depending on the setting of the 'binary' + argument. + """ + def move_after(self, other: _OperationBase) -> None: + """ + Puts self immediately after the other operation in its parent block. + """ + def move_before(self, other: _OperationBase) -> None: + """ + Puts self immediately before the other operation in its parent block. + """ + @overload + def print( + self, + state: AsmState, + file: Optional[Any] = None, + binary: bool = False, + ) -> None: + """ + Prints the assembly form of the operation to a file like object. + + Args: + file: The file like object to write to. Defaults to sys.stdout. + binary: Whether to write bytes (True) or str (False). Defaults to False. + state: AsmState capturing the operation numbering and flags. + """ + @overload + def print( + self, + large_elements_limit: Optional[int] = None, + enable_debug_info: bool = False, + pretty_debug_info: bool = False, + print_generic_op_form: bool = False, + use_local_scope: bool = False, + assume_verified: bool = False, + file: Optional[Any] = None, + binary: bool = False, + ) -> None: + """ + Prints the assembly form of the operation to a file like object. + + Args: + file: The file like object to write to. Defaults to sys.stdout. + binary: Whether to write bytes (True) or str (False). Defaults to False. 
+ large_elements_limit: Whether to elide elements attributes above this + number of elements. Defaults to None (no limit). + enable_debug_info: Whether to print debug/location information. Defaults + to False. + pretty_debug_info: Whether to format debug information for easier reading + by a human (warning: the result is unparseable). + print_generic_op_form: Whether to print the generic assembly forms of all + ops. Defaults to False. + use_local_Scope: Whether to print in a way that is more optimized for + multi-threaded access but may not be consistent with how the overall + module prints. + assume_verified: By default, if not printing generic form, the verifier + will be run and if it fails, generic form will be printed with a comment + about failed verification. While a reasonable default for interactive use, + for systematic use, it is often better for the caller to verify explicitly + and report failures in a more robust fashion. Set this to True if doing this + in order to avoid running a redundant verification. If the IR is actually + invalid, behavior is undefined. + """ + def verify(self) -> bool: + """ + Verify the operation. Raises MLIRError if verification fails, and returns true otherwise. + """ + def write_bytecode(self, file: Any, desired_version: Optional[int] = None) -> None: + """ + Write the bytecode form of the operation to a file like object. + + Args: + file: The file like object to write to. + desired_version: The version of bytecode to emit. + Returns: + The bytecode writer status. + """ @property def _CAPIPtr(self) -> object: ... @property def attributes(self) -> OpAttributeMap: ... @property - def context(self) -> Context: ... + def context(self) -> Context: + """ + Context that owns the Operation + """ @property - def location(self) -> Location: ... + def location(self) -> Location: + """ + Returns the source location the operation was defined or derived from. + """ @property def name(self) -> str: ... 
@property def operands(self) -> OpOperandList: ... @property - @property def parent(self) -> Optional[_OperationBase]: ... + @property def regions(self) -> RegionSequence: ... @property - def result(self) -> OpResult: ... + def result(self) -> OpResult: + """ + Shortcut to get an op result if it has only one (throws an error otherwise). + """ @property - def results(self) -> OpResultList: ... + def results(self) -> OpResultList: + """ + Returns the List of Operation results. + """ _TOperation = TypeVar("_TOperation", bound=_OperationBase) -# TODO: Auto-generated. Audit and fix. class AffineExpr: - def __init__(self, *args, **kwargs) -> None: ... + @staticmethod + @overload + def get_add(arg0: AffineExpr, arg1: AffineExpr) -> AffineAddExpr: + """ + Gets an affine expression containing a sum of two expressions. + """ + @staticmethod + @overload + def get_add(arg0: int, arg1: AffineExpr) -> AffineAddExpr: + """ + Gets an affine expression containing a sum of a constant and another expression. + """ + @staticmethod + @overload + def get_add(arg0: AffineExpr, arg1: int) -> AffineAddExpr: + """ + Gets an affine expression containing a sum of an expression and a constant. + """ + @staticmethod + @overload + def get_ceil_div(arg0: AffineExpr, arg1: AffineExpr) -> AffineCeilDivExpr: + """ + Gets an affine expression containing the rounded-up result of dividing one expression by another. + """ + @staticmethod + @overload + def get_ceil_div(arg0: int, arg1: AffineExpr) -> AffineCeilDivExpr: + """ + Gets a semi-affine expression containing the rounded-up result of dividing a constant by an expression. + """ + @staticmethod + @overload + def get_ceil_div(arg0: AffineExpr, arg1: int) -> AffineCeilDivExpr: + """ + Gets an affine expression containing the rounded-up result of dividing an expression by a constant. 
+ """ + @staticmethod + def get_constant( + value: int, context: Optional[Context] = None + ) -> AffineConstantExpr: + """ + Gets a constant affine expression with the given value. + """ + @staticmethod + def get_dim(position: int, context: Optional[Context] = None) -> AffineDimExpr: + """ + Gets an affine expression of a dimension at the given position. + """ + @staticmethod + @overload + def get_floor_div(arg0: AffineExpr, arg1: AffineExpr) -> AffineFloorDivExpr: + """ + Gets an affine expression containing the rounded-down result of dividing one expression by another. + """ + @staticmethod + @overload + def get_floor_div(arg0: int, arg1: AffineExpr) -> AffineFloorDivExpr: + """ + Gets a semi-affine expression containing the rounded-down result of dividing a constant by an expression. + """ + @staticmethod + @overload + def get_floor_div(arg0: AffineExpr, arg1: int) -> AffineFloorDivExpr: + """ + Gets an affine expression containing the rounded-down result of dividing an expression by a constant. + """ + @staticmethod + @overload + def get_mod(arg0: AffineExpr, arg1: AffineExpr) -> AffineModExpr: + """ + Gets an affine expression containing the modulo of dividing one expression by another. + """ + @staticmethod + @overload + def get_mod(arg0: int, arg1: AffineExpr) -> AffineModExpr: + """ + Gets a semi-affine expression containing the modulo of dividing a constant by an expression. + """ + @staticmethod + @overload + def get_mod(arg0: AffineExpr, arg1: int) -> AffineModExpr: + """ + Gets an affine expression containing the module of dividingan expression by a constant. + """ + @staticmethod + @overload + def get_mul(arg0: AffineExpr, arg1: AffineExpr) -> AffineMulExpr: + """ + Gets an affine expression containing a product of two expressions. + """ + @staticmethod + @overload + def get_mul(arg0: int, arg1: AffineExpr) -> AffineMulExpr: + """ + Gets an affine expression containing a product of a constant and another expression. 
+ """ + @staticmethod + @overload + def get_mul(arg0: AffineExpr, arg1: int) -> AffineMulExpr: + """ + Gets an affine expression containing a product of an expression and a constant. + """ + @staticmethod + def get_symbol( + position: int, context: Optional[Context] = None + ) -> AffineSymbolExpr: + """ + Gets an affine expression of a symbol at the given position. + """ def _CAPICreate(self) -> AffineExpr: ... - def compose(self, arg0) -> AffineExpr: ... - def dump(self) -> None: ... - def get_add(self, *args, **kwargs) -> Any: ... - def get_ceil_div(self, *args, **kwargs) -> Any: ... - def get_constant(self, *args, **kwargs) -> Any: ... - def get_dim(self, *args, **kwargs) -> Any: ... - def get_floor_div(self, *args, **kwargs) -> Any: ... - def get_mod(self, *args, **kwargs) -> Any: ... - def get_mul(self, *args, **kwargs) -> Any: ... - def get_symbol(self, *args, **kwargs) -> Any: ... - def __add__(self, other) -> Any: ... + @overload + def __add__(self, arg0: AffineExpr) -> AffineAddExpr: ... + @overload + def __add__(self, arg0: int) -> AffineAddExpr: ... @overload def __eq__(self, arg0: AffineExpr) -> bool: ... @overload - def __eq__(self, arg0: object) -> bool: ... + def __eq__(self, arg0: Any) -> bool: ... def __hash__(self) -> int: ... - def __mod__(self, other) -> Any: ... - def __mul__(self, other) -> Any: ... - def __radd__(self, other) -> Any: ... - def __rmod__(self, other) -> Any: ... - def __rmul__(self, other) -> Any: ... - def __rsub__(self, other) -> Any: ... - def __sub__(self, other) -> Any: ... + @overload + def __mod__(self, arg0: AffineExpr) -> AffineModExpr: ... + @overload + def __mod__(self, arg0: int) -> AffineModExpr: ... + @overload + def __mul__(self, arg0: AffineExpr) -> AffineMulExpr: ... + @overload + def __mul__(self, arg0: int) -> AffineMulExpr: ... + def __radd__(self, arg0: int) -> AffineAddExpr: ... + def __repr__(self) -> str: ... + def __rmod__(self, arg0: int) -> AffineModExpr: ... 
+ def __rmul__(self, arg0: int) -> AffineMulExpr: ... + def __rsub__(self, arg0: int) -> AffineAddExpr: ... + def __str__(self) -> str: ... + @overload + def __sub__(self, arg0: AffineExpr) -> AffineAddExpr: ... + @overload + def __sub__(self, arg0: int) -> AffineAddExpr: ... + def compose(self, arg0: AffineMap) -> AffineExpr: ... + def dump(self) -> None: + """ + Dumps a debug representation of the object to stderr. + """ @property def _CAPIPtr(self) -> object: ... @property def context(self) -> Context: ... class Attribute: - def __init__(self, cast_from_type: Attribute) -> None: ... - def _CAPICreate(self) -> Attribute: ... - def dump(self) -> None: ... - def get_named(self, *args, **kwargs) -> Any: ... @staticmethod - def parse(asm: str, context: Optional[Context] = None) -> Any: ... + def parse(asm: str, context: Optional[Context] = None) -> Attribute: + """ + Parses an attribute from an assembly form. Raises an MLIRError on failure. + """ + def _CAPICreate(self) -> Attribute: ... @overload def __eq__(self, arg0: Attribute) -> bool: ... @overload def __eq__(self, arg0: object) -> bool: ... def __hash__(self) -> int: ... + def __init__(self, cast_from_type: Attribute) -> None: + """ + Casts the passed attribute to the generic Attribute + """ + def __repr__(self) -> str: ... + def __str__(self) -> str: + """ + Returns the assembly form of the Attribute. + """ + def dump(self) -> None: + """ + Dumps a debug representation of the object to stderr. + """ + def get_named(self, arg0: str) -> NamedAttribute: + """ + Binds a name to the attribute + """ + def maybe_downcast(self) -> Any: ... @property def _CAPIPtr(self) -> object: ... @property - def context(self) -> Context: ... + def context(self) -> Context: + """ + Context that owns the Attribute + """ @property def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... class Type: - def __init__(self, cast_from_type: Type) -> None: ... - def _CAPICreate(self) -> Type: ... 
- def dump(self) -> None: ... @staticmethod - def parse(asm: str, context: Optional[Context] = None) -> Type: ... + def parse(asm: str, context: Optional[Context] = None) -> Type: + """ + Parses the assembly form of a type. + + Returns a Type object or raises an MLIRError if the type cannot be parsed. + + See also: https://mlir.llvm.org/docs/LangRef/#type-system + """ + def _CAPICreate(self) -> Type: ... @overload def __eq__(self, arg0: Type) -> bool: ... @overload def __eq__(self, arg0: object) -> bool: ... def __hash__(self) -> int: ... + def __init__(self, cast_from_type: Type) -> None: + """ + Casts the passed type to the generic Type + """ + def __repr__(self) -> str: ... + def __str__(self) -> str: + """ + Returns the assembly form of the type. + """ + def dump(self) -> None: + """ + Dumps a debug representation of the object to stderr. + """ + def maybe_downcast(self) -> Any: ... @property def _CAPIPtr(self) -> object: ... @property - def context(self) -> Context: ... + def context(self) -> Context: + """ + Context that owns the Type + """ + @property + def typeid(self) -> TypeID: ... class Value: def _CAPICreate(self) -> Value: ... - def dump(self) -> None: ... @overload def __eq__(self, arg0: Value) -> bool: ... @overload def __eq__(self, arg0: object) -> bool: ... def __hash__(self) -> int: ... + def __init__(self, value: Value) -> None: ... + def __str__(self) -> str: + """ + Returns the string form of the value. + + If the value is a block argument, this is the assembly form of its type and the + position in the argument List. If the value is an operation result, this is + equivalent to printing the operation that produced it. + """ + def dump(self) -> None: + """ + Dumps a debug representation of the object to stderr. + """ + @overload + def get_name(self, use_local_scope: bool = False) -> str: ... + @overload + def get_name(self, state: AsmState) -> str: + """ + Returns the string form of value as an operand (i.e., the ValueID). 
+ """ + def maybe_downcast(self) -> Any: ... + def replace_all_uses_with(self, arg0: Value) -> None: + """ + Replace all uses of value with the new value, updating anything in + the IR that uses 'self' to use the other value instead. + """ + def set_type(self, type: Type) -> None: ... @property def _CAPIPtr(self) -> object: ... @property - def context(self) -> Context: ... + def context(self) -> Context: + """ + Context in which the value lives. + """ @property def owner(self) -> _OperationBase: ... @property def type(self) -> Type: ... + @property + def uses(self) -> OpOperandIterator: ... - -# Classes with no particular order sensitivity in alpha order. -# TODO: Auto-generated. Audit and fix. class AffineAddExpr(AffineBinaryExpr): - def __init__(self, expr: AffineExpr) -> None: ... @staticmethod - def get(*args, **kwargs) -> AffineAddExpr: ... + def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineAddExpr: ... @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... -# TODO: Auto-generated. Audit and fix. class AffineBinaryExpr(AffineExpr): - def __init__(self, expr: AffineExpr) -> None: ... @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... @property def lhs(self) -> AffineExpr: ... @property def rhs(self) -> AffineExpr: ... -# TODO: Auto-generated. Audit and fix. class AffineCeilDivExpr(AffineBinaryExpr): - def __init__(self, expr: AffineExpr) -> None: ... @staticmethod - def get(*args, **kwargs) -> AffineCeilDivExpr: ... + def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineCeilDivExpr: ... @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... -# TODO: Auto-generated. Audit and fix. 
class AffineConstantExpr(AffineExpr): - def __init__(self, expr: AffineExpr) -> None: ... @staticmethod - def get(*args, **kwargs) -> AffineConstantExpr: ... + def get(value: int, context: Optional[Context] = None) -> AffineConstantExpr: ... @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... @property def value(self) -> int: ... -# TODO: Auto-generated. Audit and fix. class AffineDimExpr(AffineExpr): - def __init__(self, expr: AffineExpr) -> None: ... @staticmethod - def get(*args, **kwargs) -> AffineDimExpr: ... + def get(position: int, context: Optional[Context] = None) -> AffineDimExpr: ... @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... @property def position(self) -> int: ... -# TODO: Auto-generated. Audit and fix. class AffineExprList: - def __init__(self, *args, **kwargs) -> None: ... def __add__(self, arg0: AffineExprList) -> List[AffineExpr]: ... - @overload - def __getitem__(self, arg0: int) -> AffineExpr: ... - @overload - def __getitem__(self, arg0: slice) -> AffineExprList: ... - def __len__(self) -> int: ... -# TODO: Auto-generated. Audit and fix. class AffineFloorDivExpr(AffineBinaryExpr): - def __init__(self, expr: AffineExpr) -> None: ... - def get(*args, **kwargs) -> AffineFloorDivExpr: ... @staticmethod - def isinstance(arg: Any) -> bool: ... + def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineFloorDivExpr: ... + @staticmethod + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... -# TODO: Auto-generated. Audit and fix. class AffineMap: - def __init__(self, *args, **kwargs) -> None: ... - def _CAPICreate(self) -> AffineMap: ... - @staticmethod - def compress_unused_symbols(*args, **kwargs) -> Any: ... - def dump(self) -> None: ... - @staticmethod - def get(*args, **kwargs) -> AffineMap: ... 
@staticmethod - def get_constant(*args, **kwargs) -> AffineMap: ... - @staticmethod - def get_empty(*args, **kwargs) -> AffineMap: ... - @staticmethod - def get_identity(*args, **kwargs) -> AffineMap: ... - @staticmethod - def get_minor_identity(*args, **kwargs) -> AffineMap: ... - def get_minor_submap(self, n_results: int) -> AffineMap: ... - def get_major_submap(self, n_results: int) -> AffineMap: ... - def get_permutation(self, *args, **kwargs) -> Any: ... - def get_submap(self, result_positions: List[int]) -> AffineMap: ... - def replace(self, expr: AffineExpr, replacement: AffineExpr, n_result_dims: int, n_result_syms: int) -> AffineMap: ... + def compress_unused_symbols( + arg0: List, arg1: Optional[Context] + ) -> List[AffineMap]: ... + @staticmethod + def get( + dim_count: int, + symbol_count: int, + exprs: List, + context: Optional[Context] = None, + ) -> AffineMap: + """ + Gets a map with the given expressions as results. + """ + @staticmethod + def get_constant(value: int, context: Optional[Context] = None) -> AffineMap: + """ + Gets an affine map with a single constant result + """ + @staticmethod + def get_empty(context: Optional[Context] = None) -> AffineMap: + """ + Gets an empty affine map. + """ + @staticmethod + def get_identity(n_dims: int, context: Optional[Context] = None) -> AffineMap: + """ + Gets an identity map with the given number of dimensions. + """ + @staticmethod + def get_minor_identity( + n_dims: int, n_results: int, context: Optional[Context] = None + ) -> AffineMap: + """ + Gets a minor identity map with the given number of dimensions and results. + """ + @staticmethod + def get_permutation( + permutation: List[int], context: Optional[Context] = None + ) -> AffineMap: + """ + Gets an affine map that permutes its inputs. + """ + def _CAPICreate(self) -> AffineMap: ... @overload def __eq__(self, arg0: AffineMap) -> bool: ... @overload def __eq__(self, arg0: object) -> bool: ... def __hash__(self) -> int: ... 
+ def __repr__(self) -> str: ... + def __str__(self) -> str: ... + def dump(self) -> None: + """ + Dumps a debug representation of the object to stderr. + """ + def get_major_submap(self, n_results: int) -> AffineMap: ... + def get_minor_submap(self, n_results: int) -> AffineMap: ... + def get_submap(self, result_positions: List[int]) -> AffineMap: ... + def replace( + self, + expr: AffineExpr, + replacement: AffineExpr, + n_result_dims: int, + n_result_syms: int, + ) -> AffineMap: ... @property def _CAPIPtr(self) -> object: ... @property - def context(self) -> Context: ... + def context(self) -> Context: + """ + Context that owns the Affine Map + """ @property def is_permutation(self) -> bool: ... @property @@ -335,460 +744,1119 @@ class AffineMap: @property def n_symbols(self) -> int: ... @property - def results(self) -> Any: ... + def results(self) -> "AffineMapExprList": ... -# TODO: Auto-generated. Audit and fix. class AffineMapAttr(Attribute): - def __init__(self, cast_from_attr: Attribute) -> None: ... + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> AffineMapAttr: ... + def get(affine_map: AffineMap) -> AffineMapAttr: + """ + Gets an attribute wrapping an AffineMap. + """ @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __repr__(self) -> str: ... @property def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... -# TODO: Auto-generated. Audit and fix. class AffineModExpr(AffineBinaryExpr): - def __init__(self, expr: AffineExpr) -> None: ... @staticmethod - def get(*args, **kwargs) -> AffineModExpr: ... + def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineModExpr: ... @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... -# TODO: Auto-generated. Audit and fix. 
class AffineMulExpr(AffineBinaryExpr): - def __init__(self, expr: AffineExpr) -> None: ... @staticmethod - def get(*args, **kwargs) -> AffineMulExpr: ... + def get(arg0: AffineExpr, arg1: AffineExpr) -> AffineMulExpr: ... @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... -# TODO: Auto-generated. Audit and fix. class AffineSymbolExpr(AffineExpr): - def __init__(self, expr: AffineExpr) -> None: ... @staticmethod - def get(*args, **kwargs) -> AffineSymbolExpr: ... + def get(position: int, context: Optional[Context] = None) -> AffineSymbolExpr: ... @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: AffineExpr) -> bool: ... + def __init__(self, expr: AffineExpr) -> None: ... @property def position(self) -> int: ... -# TODO: Auto-generated. Audit and fix. class ArrayAttr(Attribute): - def __init__(self, cast_from_attr: Attribute) -> None: ... + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> ArrayAttr: ... + def get(attributes: List, context: Optional[Context] = None) -> ArrayAttr: + """ + Gets a uniqued Array attribute + """ @staticmethod - def isinstance(arg: Any) -> bool: ... - def __add__(self, arg0: list) -> ArrayAttr: ... + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: List) -> ArrayAttr: ... def __getitem__(self, arg0: int) -> Attribute: ... - def __iter__(self) -> Any: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> ArrayAttributeIterator: ... def __len__(self) -> int: ... + def __repr__(self) -> str: ... @property def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... -# TODO: Auto-generated. Audit and fix. class ArrayAttributeIterator: - def __init__(self, *args, **kwargs) -> None: ... def __iter__(self) -> ArrayAttributeIterator: ... def __next__(self) -> Attribute: ... -# TODO: Auto-generated. 
Audit and fix. +class AsmState: + @overload + def __init__(self, value: Value, use_local_scope: bool = False) -> None: ... + @overload + def __init__(self, op: _OperationBase, use_local_scope: bool = False) -> None: ... + +class AttrBuilder: + @staticmethod + def contains(arg0: str) -> bool: ... + @staticmethod + def get(arg0: str) -> Callable: ... + @staticmethod + def insert( + attribute_kind: str, attr_builder: Callable, replace: bool = False + ) -> None: + """ + Register an attribute builder for building MLIR attributes from python values. + """ + class BF16Type(Type): - def __init__(self, cast_from_type: Type) -> None: ... + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> BF16Type: ... + def get(context: Optional[Context] = None) -> BF16Type: + """ + Create a bf16 type. + """ @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + def __repr__(self) -> str: ... + @property + def typeid(self) -> TypeID: ... class Block: - __hash__: ClassVar[None] = ... # type: ignore - def append(self, operation: _OperationBase) -> None: ... - def create_after(self, *args: Type) -> Block: ... @staticmethod - def create_at_start(parent: Region, arg_types: List[Type]) -> Block: ... - def create_before(self, *args: Type) -> Block: ... + def create_at_start( + parent: Region, + arg_types: List[Type], + arg_locs: Optional[Sequence] = None, + ) -> Block: + """ + Creates and returns a new Block at the beginning of the given region (with given argument types and locations). + """ @overload def __eq__(self, arg0: Block) -> bool: ... @overload - def __eq__(self, arg0: object) -> bool: ... - def __iter__(self) -> Any: ... + def __eq__(self, arg0: Any) -> bool: ... + def __hash__(self) -> int: ... + def __iter__(self) -> OperationIterator: + """ + Iterates over operations in the block. 
+ """ + def __str__(self) -> str: + """ + Returns the assembly form of the block. + """ + def append(self, operation: _OperationBase) -> None: + """ + Appends an operation to this block. If the operation is currently in another block, it will be moved. + """ + def append_to(self, arg0: Region) -> None: + """ + Append this block to a region, transferring ownership if necessary + """ + def create_after(self, *args, arg_locs: Optional[Sequence] = None) -> Block: + """ + Creates and returns a new Block after this block (with given argument types and locations). + """ + def create_before(self, *args, arg_locs: Optional[Sequence] = None) -> Block: + """ + Creates and returns a new Block before this block (with given argument types and locations). + """ + @property + def _CAPIPtr(self) -> object: ... @property - def arguments(self) -> BlockArgumentList: ... + def arguments(self) -> BlockArgumentList: + """ + Returns a List of block arguments. + """ @property - def operations(self) -> OperationList: ... + def operations(self) -> OperationList: + """ + Returns a forward-optimized sequence of operations. + """ @property - def owner(self) -> OpView: ... + def owner(self) -> OpView: + """ + Returns the owning operation of this block. + """ @property - def region(self) -> Region: ... + def region(self) -> Region: + """ + Returns the owning region of this block. + """ class BlockArgument(Value): @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other_value: Value) -> bool: ... + def __init__(self, value: Value) -> None: ... + def maybe_downcast(self) -> Any: ... def set_type(self, type: Type) -> None: ... @property def arg_number(self) -> int: ... @property - def owner(self) -> Block: ... # type: ignore[override] + def owner(self) -> Block: ... class BlockArgumentList: def __add__(self, arg0: BlockArgumentList) -> List[BlockArgument]: ... - @overload - def __getitem__(self, arg0: int) -> BlockArgument: ... 
- @overload - def __getitem__(self, arg0: slice) -> BlockArgumentList: ... - def __len__(self) -> int: ... @property def types(self) -> List[Type]: ... class BlockIterator: - def __init__(self, *args, **kwargs) -> None: ... def __iter__(self) -> BlockIterator: ... def __next__(self) -> Block: ... class BlockList: - def append(self, *args) -> Block: ... def __getitem__(self, arg0: int) -> Block: ... def __iter__(self) -> BlockIterator: ... def __len__(self) -> int: ... + def append(self, *args, arg_locs: Optional[Sequence] = None) -> Block: + """ + Appends a new block, with argument types as positional args. + + Returns: + The created block. + """ -# TODO: Auto-generated. Audit and fix. class BoolAttr(Attribute): - def __init__(self, cast_from_attr: Attribute) -> None: ... @staticmethod - def get(*args, **kwargs) -> BoolAttr: ... + def get(value: bool, context: Optional[Context] = None) -> BoolAttr: + """ + Gets an uniqued bool attribute + """ @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: Attribute) -> bool: ... + def __bool__(self: Attribute) -> bool: + """ + Converts the value of the bool attribute to a Python bool + """ + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __repr__(self) -> str: ... + @property + def static_typeid(self) -> TypeID: ... @property def type(self) -> Type: ... @property - def value(self) -> bool: ... + def typeid(self) -> TypeID: ... + @property + def value(self) -> bool: + """ + Returns the value of the bool attribute + """ -# TODO: Auto-generated. Audit and fix. class ComplexType(Type): - def __init__(self, cast_from_type: Type) -> None: ... + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> ComplexType: ... + def get(arg0: Type) -> ComplexType: + """ + Create a complex type + """ @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... 
+ def __repr__(self) -> str: ... + @property + def element_type(self) -> Type: + """ + Returns element type. + """ @property - def element_type(self) -> Type: ... + def typeid(self) -> TypeID: ... class Context: current: ClassVar[Context] = ... # read-only allow_unregistered_dialects: bool - def __init__(self) -> None: ... - def _CAPICreate(self) -> object: ... - def _get_context_again(self) -> Context: ... @staticmethod def _get_live_count() -> int: ... + def _CAPICreate(self) -> object: ... + def __enter__(self) -> Any: ... + def __exit__(self, arg0: Any, arg1: Any, arg2: Any) -> None: ... + def __init__(self) -> None: ... + def _clear_live_operations(self) -> int: ... + def _get_context_again(self) -> Context: ... def _get_live_module_count(self) -> int: ... def _get_live_operation_count(self) -> int: ... - def attach_diagnostic_handler(self, callback: Callable[[Diagnostic], bool]) -> DiagnosticHandler: ... + def append_dialect_registry(self, registry: DialectRegistry) -> None: ... + def attach_diagnostic_handler( + self, callback: Callable[[Diagnostic], bool] + ) -> DiagnosticHandler: + """ + Attaches a diagnostic handler that will receive callbacks + """ def enable_multithreading(self, enable: bool) -> None: ... - def get_dialect_descriptor(self, dialect_name: str) -> DialectDescriptor: ... + def get_dialect_descriptor(self, dialect_name: str) -> DialectDescriptor: + """ + Gets or loads a dialect by name, returning its descriptor object + """ def is_registered_operation(self, operation_name: str) -> bool: ... - def __enter__(self) -> Context: ... - def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... + def load_all_available_dialects(self) -> None: ... @property def _CAPIPtr(self) -> object: ... @property - def d(self) -> Dialects: ... + def d(self) -> Dialects: + """ + Alias for 'dialect' + """ @property - def dialects(self) -> Dialects: ... - def append_dialect_registry(self, registry: "DialectRegistry") -> None: ... 
- def load_all_available_dialects(self) -> None: ... - -class DialectRegistry: - def __init__(self) -> None: ... + def dialects(self) -> Dialects: + """ + Gets a container for accessing dialects by name + """ -# TODO: Auto-generated. Audit and fix. -class DenseElementsAttr(Attribute): - def __init__(self, cast_from_attr: Attribute) -> None: ... +class DenseBoolArrayAttr(Attribute): @staticmethod - def get(*args, **kwargs) -> DenseElementsAttr: ... + def get( + values: List[bool], context: Optional[Context] = None + ) -> DenseBoolArrayAttr: + """ + Gets a uniqued dense array attribute + """ @staticmethod - def get_splat(*args, **kwargs) -> Any: ... + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: List) -> DenseBoolArrayAttr: ... + def __getitem__(self, arg0: int) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> DenseBoolArrayIterator: ... + def __len__(self) -> int: ... + def __repr__(self) -> str: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseBoolArrayIterator: + def __iter__(self) -> DenseBoolArrayIterator: ... + def __next__(self) -> bool: ... + +class DenseElementsAttr(Attribute): @staticmethod - def isinstance(arg: Any) -> bool: ... + def get( + array: Buffer, + signless: bool = True, + type: Optional[Type] = None, + shape: Optional[List[int]] = None, + context: Optional[Context] = None, + ) -> DenseElementsAttr: + """ + Gets a DenseElementsAttr from a Python buffer or array. + + When `type` is not provided, then some limited type inferencing is done based + on the buffer format. Support presently exists for 8/16/32/64 signed and + unsigned integers and float16/float32/float64. DenseElementsAttrs of these + types can also be converted back to a corresponding buffer. 
+ + For conversions outside of these types, a `type=` must be explicitly provided + and the buffer contents must be bit-castable to the MLIR internal + representation: + + * Integer types (except for i1): the buffer must be byte aligned to the + next byte boundary. + * Floating point types: Must be bit-castable to the given floating point + size. + * i1 (bool): Bit packed into 8bit words where the bit pattern matches a + row major ordering. An arbitrary Numpy `bool_` array can be bit packed to + this specification with: `np.packbits(ary, axis=None, bitorder='little')`. + + If a single element buffer is passed (or for i1, a single byte with value 0 + or 255), then a splat will be created. + + Args: + array: The array or buffer to convert. + signless: If inferring an appropriate MLIR type, use signless types for + integers (defaults True). + type: Skips inference of the MLIR element type and uses this instead. The + storage size must be consistent with the actual contents of the buffer. + shape: Overrides the shape of the buffer when constructing the MLIR + shaped type. This is needed when the physical and logical shape differ (as + for i1). + context: Explicit context, if not from context manager. + + Returns: + DenseElementsAttr on success. + + Raises: + ValueError: If the type of the buffer or array cannot be matched to an MLIR + type or if the buffer does not meet expectations. + """ + @staticmethod + def get_splat(shaped_type: Type, element_attr: Attribute) -> DenseElementsAttr: + """ + Gets a DenseElementsAttr where all values are the same + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... def __len__(self) -> int: ... + def __repr__(self) -> str: ... + def get_splat_value(self) -> Attribute: ... @property def is_splat(self) -> bool: ... @property + def static_typeid(self) -> TypeID: ... + @property def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... 
-# TODO: Auto-generated. Audit and fix. -class DenseFPElementsAttr(DenseElementsAttr): - def __init__(self, cast_from_attr: Attribute) -> None: ... +class DenseF32ArrayAttr(Attribute): @staticmethod - def get(*args, **kwargs) -> DenseFPElementsAttr: ... + def get( + values: List[float], context: Optional[Context] = None + ) -> DenseF32ArrayAttr: + """ + Gets a uniqued dense array attribute + """ @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: List) -> DenseF32ArrayAttr: ... def __getitem__(self, arg0: int) -> float: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> DenseF32ArrayIterator: ... + def __len__(self) -> int: ... + def __repr__(self) -> str: ... + @property + def static_typeid(self) -> TypeID: ... @property def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... -# TODO: Auto-generated. Audit and fix. -class DenseIntElementsAttr(DenseElementsAttr): - def __init__(self, cast_from_attr: Attribute) -> None: ... +class DenseF32ArrayIterator: + def __iter__(self) -> DenseF32ArrayIterator: ... + def __next__(self) -> float: ... + +class DenseF64ArrayAttr(Attribute): @staticmethod - def get(*args, **kwargs) -> DenseIntElementsAttr: ... + def get( + values: List[float], context: Optional[Context] = None + ) -> DenseF64ArrayAttr: + """ + Gets a uniqued dense array attribute + """ @staticmethod - def isinstance(arg: Any) -> bool: ... - def __getitem__(self, arg0: int) -> int: ... + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: List) -> DenseF64ArrayAttr: ... + def __getitem__(self, arg0: int) -> float: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> DenseF64ArrayIterator: ... + def __len__(self) -> int: ... + def __repr__(self) -> str: ... + @property + def static_typeid(self) -> TypeID: ... @property def type(self) -> Type: ... 
+ @property + def typeid(self) -> TypeID: ... -class DenseResourceElementsAttr(Attribute): - @staticmethod - def get_from_buffer(array: Any, name: str, type: Type, alignment: Optional[int] = None, is_mutable: bool = False, context: Optional[Context] = None) -> None: ... +class DenseF64ArrayIterator: + def __iter__(self) -> DenseF64ArrayIterator: ... + def __next__(self) -> float: ... -class Dialect: - def __init__(self, descriptor: DialectDescriptor) -> None: ... +class DenseFPElementsAttr(DenseElementsAttr): + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __getitem__(self, arg0: int) -> float: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __repr__(self) -> str: ... @property - def descriptor(self) -> DialectDescriptor: ... + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... -class DialectDescriptor: +class DenseI16ArrayAttr(Attribute): + @staticmethod + def get(values: List[int], context: Optional[Context] = None) -> DenseI16ArrayAttr: + """ + Gets a uniqued dense array attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: List) -> DenseI16ArrayAttr: ... + def __getitem__(self, arg0: int) -> int: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> DenseI16ArrayIterator: ... + def __len__(self) -> int: ... + def __repr__(self) -> str: ... @property - def namespace(self) -> str: ... + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... -class Dialects: - def __init__(self, *args, **kwargs) -> None: ... - def __getattr__(self, arg0: str) -> Dialect: ... - def __getitem__(self, arg0: str) -> Dialect: ... +class DenseI16ArrayIterator: + def __iter__(self) -> DenseI16ArrayIterator: ... + def __next__(self) -> int: ... 
-class Diagnostic: - @property - def severity(self) -> DiagnosticSeverity: ... +class DenseI32ArrayAttr(Attribute): + @staticmethod + def get(values: List[int], context: Optional[Context] = None) -> DenseI32ArrayAttr: + """ + Gets a uniqued dense array attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: List) -> DenseI32ArrayAttr: ... + def __getitem__(self, arg0: int) -> int: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> DenseI32ArrayIterator: ... + def __len__(self) -> int: ... + def __repr__(self) -> str: ... @property - def location(self) -> Location: ... + def static_typeid(self) -> TypeID: ... @property - def message(self) -> str: ... + def type(self) -> Type: ... @property - def notes(self) -> Tuple[Diagnostic]: ... + def typeid(self) -> TypeID: ... + +class DenseI32ArrayIterator: + def __iter__(self) -> DenseI32ArrayIterator: ... + def __next__(self) -> int: ... + +class DenseI64ArrayAttr(Attribute): + @staticmethod + def get(values: List[int], context: Optional[Context] = None) -> DenseI64ArrayAttr: + """ + Gets a uniqued dense array attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: List) -> DenseI64ArrayAttr: ... + def __getitem__(self, arg0: int) -> int: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> DenseI16ArrayIterator: ... + def __len__(self) -> int: ... + def __repr__(self) -> str: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseI64ArrayIterator: + def __iter__(self) -> DenseI64ArrayIterator: ... + def __next__(self) -> int: ... 
+ +class DenseI8ArrayAttr(Attribute): + @staticmethod + def get(values: List[int], context: Optional[Context] = None) -> DenseI8ArrayAttr: + """ + Gets a uniqued dense array attribute + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __add__(self, arg0: List) -> DenseI8ArrayAttr: ... + def __getitem__(self, arg0: int) -> int: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __iter__( + self, + ) -> DenseI8ArrayIterator: ... + def __len__(self) -> int: ... + def __repr__(self) -> str: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseI8ArrayIterator: + def __iter__(self) -> DenseI8ArrayIterator: ... + def __next__(self) -> int: ... + +class DenseIntElementsAttr(DenseElementsAttr): + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __getitem__(self, arg0: int) -> int: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __repr__(self) -> str: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class DenseResourceElementsAttr(Attribute): + @staticmethod + def get_from_buffer( + array: Buffer, + name: str, + type: Type, + alignment: Optional[int] = None, + is_mutable: bool = False, + context: Optional[Context] = None, + ) -> DenseResourceElementsAttr: + """ + Gets a DenseResourceElementsAttr from a Python buffer or array. + + This function does minimal validation or massaging of the data, and it is + up to the caller to ensure that the buffer meets the characteristics + implied by the shape. + + The backing buffer and any user objects will be retained for the lifetime + of the resource blob. This is typically bounded to the context but the + resource can have a shorter lifespan depending on how it is used in + subsequent processing. + + Args: + buffer: The array or buffer to convert. 
+ name: Name to provide to the resource (may be changed upon collision). + type: The explicit ShapedType to construct the attribute with. + context: Explicit context, if not from context manager. + + Returns: + DenseResourceElementsAttr on success. + + Raises: + ValueError: If the type of the buffer or array cannot be matched to an MLIR + type or if the buffer does not meet expectations. + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __repr__(self) -> str: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class Diagnostic: + def __str__(self) -> str: ... + @property + def location(self) -> Location: ... + @property + def message(self) -> str: ... + @property + def notes(self) -> Tuple[Diagnostic]: ... + @property + def severity(self) -> DiagnosticSeverity: ... class DiagnosticHandler: + def __enter__(self) -> DiagnosticHandler: ... + def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... def detach(self) -> None: ... @property def attached(self) -> bool: ... @property def had_error(self) -> bool: ... - def __enter__(self) -> DiagnosticHandler: ... - def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... class DiagnosticInfo: - def __init__(self, diag: Diagnostic) -> None: ... - @property - def severity(self) -> "DiagnosticSeverity": ... + def __init__(self, arg0: Diagnostic) -> None: ... + def __str__(self) -> str: ... @property - def location(self) -> "Location": ... + def location(self) -> Location: ... @property def message(self) -> str: ... @property - def notes(self) -> Sequence["DiagnosticInfo"]: ... + def notes(self) -> List[DiagnosticInfo]: ... + @property + def severity(self) -> DiagnosticSeverity: ... 
class DiagnosticSeverity: - ERROR: DiagnosticSeverity - WARNING: DiagnosticSeverity - NOTE: DiagnosticSeverity - REMARK: DiagnosticSeverity + """ + Members: + + ERROR + + WARNING + + NOTE + + REMARK + """ + + ERROR: ClassVar[DiagnosticSeverity] # value = + NOTE: ClassVar[DiagnosticSeverity] # value = + REMARK: ClassVar[DiagnosticSeverity] # value = + WARNING: ClassVar[DiagnosticSeverity] # value = + __members__: ClassVar[ + Dict[str, DiagnosticSeverity] + ] # value = {'ERROR': , 'WARNING': , 'NOTE': , 'REMARK': } + def __eq__(self, other: Any) -> bool: ... + def __getstate__(self) -> int: ... + def __hash__(self) -> int: ... + def __index__(self) -> int: ... + def __init__(self, value: int) -> None: ... + def __int__(self) -> int: ... + def __ne__(self, other: Any) -> bool: ... + def __repr__(self) -> str: ... + def __setstate__(self, state: int) -> None: ... + def __str__(self) -> str: ... + @property + def name(self) -> str: ... + @property + def value(self) -> int: ... + +class Dialect: + def __init__(self, descriptor: DialectDescriptor) -> None: ... + def __repr__(self) -> Any: ... + @property + def descriptor(self) -> DialectDescriptor: ... + +class DialectDescriptor: + def __repr__(self) -> str: ... + @property + def namespace(self) -> str: ... + +class DialectRegistry: + def _CAPICreate(self) -> DialectRegistry: ... + def __init__(self) -> None: ... + @property + def _CAPIPtr(self) -> object: ... + +class Dialects: + def __getattr__(self, arg0: str) -> Any: ... + def __getitem__(self, arg0: str) -> Any: ... -# TODO: Auto-generated. Audit and fix. class DictAttr(Attribute): - def __init__(self, cast_from_attr: Attribute) -> None: ... + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> DictAttr: ... + def get(value: Dict = {}, context: Optional[Context] = None) -> DictAttr: + """ + Gets an uniqued Dict attribute + """ @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: Attribute) -> bool: ... 
def __contains__(self, arg0: str) -> bool: ... @overload def __getitem__(self, arg0: str) -> Attribute: ... @overload def __getitem__(self, arg0: int) -> NamedAttribute: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... def __len__(self) -> int: ... + def __repr__(self) -> str: ... @property def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... -class Float8E4M3FNType(Type): - def __init__(self, cast_from_type: Type) -> None: ... +class F16Type(Type): + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> Float8E4M3FNType: ... + def get(context: Optional[Context] = None) -> F16Type: + """ + Create a f16 type. + """ @staticmethod - def isinstance(arg: Any) -> bool: ... - -class Float8E5M2Type(Type): + def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... + def __repr__(self) -> str: ... + @property + def typeid(self) -> TypeID: ... + +class F32Type(Type): + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> Float8E5M2Type: ... + def get(context: Optional[Context] = None) -> F32Type: + """ + Create a f32 type. + """ @staticmethod - def isinstance(arg: Any) -> bool: ... - -class Float8E4M3FNUZType(Type): + def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... + def __repr__(self) -> str: ... + @property + def typeid(self) -> TypeID: ... + +class F64Type(Type): + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> Float8E4M3FNUZType: ... + def get(context: Optional[Context] = None) -> F64Type: + """ + Create a f64 type. + """ @staticmethod - def isinstance(arg: Any) -> bool: ... - -class Float8E4M3B11FNUZType(Type): + def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... + def __repr__(self) -> str: ... + @property + def typeid(self) -> TypeID: ... 
+ +class FlatSymbolRefAttr(Attribute): @staticmethod - def get(*args, **kwargs) -> Float8E4M3B11FNUZType: ... + def get(value: str, context: Optional[Context] = None) -> FlatSymbolRefAttr: + """ + Gets a uniqued FlatSymbolRef attribute + """ @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __repr__(self) -> str: ... + @property + def static_typeid(self) -> TypeID: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + @property + def value(self) -> str: + """ + Returns the value of the FlatSymbolRef attribute as a string + """ -class Float8E5M2FNUZType(Type): - def __init__(self, cast_from_type: Type) -> None: ... +class Float8E4M3B11FNUZType(Type): + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> Float8E5M2FNUZType: ... + def get(context: Optional[Context] = None) -> Float8E4M3B11FNUZType: + """ + Create a float8_e4m3b11fnuz type. + """ @staticmethod - def isinstance(arg: Any) -> bool: ... - -# TODO: Auto-generated. Audit and fix. -class F16Type(Type): + def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... + def __repr__(self) -> str: ... + @property + def typeid(self) -> TypeID: ... + +class Float8E4M3FNType(Type): + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> F16Type: ... + def get(context: Optional[Context] = None) -> Float8E4M3FNType: + """ + Create a float8_e4m3fn type. + """ @staticmethod - def isinstance(arg: Any) -> bool: ... - -# TODO: Auto-generated. Audit and fix. -class FloatTF32Type(Type): + def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... + def __repr__(self) -> str: ... + @property + def typeid(self) -> TypeID: ... 
+ +class Float8E4M3FNUZType(Type): + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> FloatTF32Type: ... + def get(context: Optional[Context] = None) -> Float8E4M3FNUZType: + """ + Create a float8_e4m3fnuz type. + """ @staticmethod - def isinstance(arg: Any) -> bool: ... - -# TODO: Auto-generated. Audit and fix. -class F32Type(Type): + def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... + def __repr__(self) -> str: ... + @property + def typeid(self) -> TypeID: ... + +class Float8E5M2FNUZType(Type): + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> F32Type: ... + def get(context: Optional[Context] = None) -> Float8E5M2FNUZType: + """ + Create a float8_e5m2fnuz type. + """ @staticmethod - def isinstance(arg: Any) -> bool: ... - -# TODO: Auto-generated. Audit and fix. -class F64Type(Type): + def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... + def __repr__(self) -> str: ... + @property + def typeid(self) -> TypeID: ... + +class Float8E5M2Type(Type): + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> F64Type: ... + def get(context: Optional[Context] = None) -> Float8E5M2Type: + """ + Create a float8_e5m2 type. + """ @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + def __repr__(self) -> str: ... + @property + def typeid(self) -> TypeID: ... -# TODO: Auto-generated. Audit and fix. 
-class FlatSymbolRefAttr(Attribute): +class FloatAttr(Attribute): + static_typeid: ClassVar[TypeID] # value = + @staticmethod + def get(type: Type, value: float, loc: Optional[Location] = None) -> FloatAttr: + """ + Gets an uniqued float point attribute associated to a type + """ + @staticmethod + def get_f32(value: float, context: Optional[Context] = None) -> FloatAttr: + """ + Gets an uniqued float point attribute associated to a f32 type + """ + @staticmethod + def get_f64(value: float, context: Optional[Context] = None) -> FloatAttr: + """ + Gets an uniqued float point attribute associated to a f64 type + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __float__(self: Attribute) -> float: + """ + Converts the value of the float attribute to a Python float + """ def __init__(self, cast_from_attr: Attribute) -> None: ... - @staticmethod - def get(*args, **kwargs) -> FlatSymbolRefAttr: ... - @staticmethod - def isinstance(arg: Any) -> bool: ... + def __repr__(self) -> str: ... @property def type(self) -> Type: ... @property - def value(self) -> str: ... + def typeid(self) -> TypeID: ... + @property + def value(self) -> float: + """ + Returns the value of the float attribute + """ -# TODO: Auto-generated. Audit and fix. -class FloatAttr(Attribute): - def __init__(self, cast_from_attr: Attribute) -> None: ... - @staticmethod - def get(*args, **kwargs) -> FloatAttr: ... - @staticmethod - def get_f32(*args, **kwargs) -> FloatAttr: ... +class FloatTF32Type(Type): + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get_f64(*args, **kwargs) -> FloatAttr: ... + def get(context: Optional[Context] = None) -> FloatTF32Type: + """ + Create a tf32 type. + """ @staticmethod - def isinstance(arg: Any) -> bool: ... - @property - def type(self) -> Type: ... + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + def __repr__(self) -> str: ... @property - def value(self) -> float: ... 
+ def typeid(self) -> TypeID: ... -# TODO: Auto-generated. Audit and fix. class FunctionType(Type): - def __init__(self, cast_from_type: Type) -> None: ... + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> FunctionType: ... + def get( + inputs: List[Type], results: List[Type], context: Optional[Context] = None + ) -> FunctionType: + """ + Gets a FunctionType from a List of input and result types + """ @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + def __repr__(self) -> str: ... + @property + def inputs(self) -> List: + """ + Returns the List of input types in the FunctionType. + """ @property - def inputs(self) -> list: ... + def results(self) -> List: + """ + Returns the List of result types in the FunctionType. + """ @property - def results(self) -> list: ... + def typeid(self) -> TypeID: ... -# TODO: Auto-generated. Audit and fix. class IndexType(Type): - def __init__(self, cast_from_type: Type) -> None: ... + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> IndexType: ... + def get(context: Optional[Context] = None) -> IndexType: + """ + Create a index type. + """ @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + def __repr__(self) -> str: ... + @property + def typeid(self) -> TypeID: ... class InferShapedTypeOpInterface: - def __init__(self, object: object, context: Optional[Context] = None) -> None: ... - def inferReturnTypeComponents(self, operands: Optional[List] = None, attributes: Optional[Attribute] = None, properties = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[ShapedTypeComponents]: ... - @property - def operation(self) -> Operation: ... - @property - def opview(self) -> OpView: ... 
+ def __init__(self, object: object, context: Optional[Context] = None) -> None: + """ + Creates an interface from a given operation/opview object or from a + subclass of OpView. Raises ValueError if the operation does not implement the + interface. + """ + def inferReturnTypeComponents( + self, + operands: Optional[List] = None, + attributes: Optional[Attribute] = None, + properties=None, + regions: Optional[List[Region]] = None, + context: Optional[Context] = None, + loc: Optional[Location] = None, + ) -> List[ShapedTypeComponents]: + """ + Given the arguments required to build an operation, attempts to infer + its return shaped type components. Raises ValueError on failure. + """ + @property + def operation(self) -> Operation: + """ + Returns an Operation for which the interface was constructed. + """ + @property + def opview(self) -> OpView: + """ + Returns an OpView subclass _instance_ for which the interface was + constructed + """ class InferTypeOpInterface: - def __init__(self, object: object, context: Optional[Context] = None) -> None: ... - def inferReturnTypes(self, operands: Optional[List] = None, attributes: Optional[Attribute] = None, properties = None, regions: Optional[List[Region]] = None, context: Optional[Context] = None, loc: Optional[Location] = None) -> List[Type]: ... - @property - def operation(self) -> Operation: ... - @property - def opview(self) -> OpView: ... + def __init__(self, object: object, context: Optional[Context] = None) -> None: + """ + Creates an interface from a given operation/opview object or from a + subclass of OpView. Raises ValueError if the operation does not implement the + interface. 
+ """ + def inferReturnTypes( + self, + operands: Optional[List] = None, + attributes: Optional[Attribute] = None, + properties=None, + regions: Optional[List[Region]] = None, + context: Optional[Context] = None, + loc: Optional[Location] = None, + ) -> List[Type]: + """ + Given the arguments required to build an operation, attempts to infer + its return types. Raises ValueError on failure. + """ + @property + def operation(self) -> Operation: + """ + Returns an Operation for which the interface was constructed. + """ + @property + def opview(self) -> OpView: + """ + Returns an OpView subclass _instance_ for which the interface was + constructed + """ class InsertionPoint: current: ClassVar[InsertionPoint] = ... # read-only + @staticmethod + def at_block_begin(block: Block) -> InsertionPoint: + """ + Inserts at the beginning of the block. + """ + @staticmethod + def at_block_terminator(block: Block) -> InsertionPoint: + """ + Inserts before the block terminator. + """ + def __enter__(self) -> Any: ... + def __exit__(self, arg0: Any, arg1: Any, arg2: Any) -> None: ... @overload - def __init__(self, block: Block) -> None: ... + def __init__(self, block: Block) -> None: + """ + Inserts after the last operation but still inside the block. + """ @overload - def __init__(self, beforeOperation: _OperationBase) -> None: ... - @staticmethod - def at_block_begin(block: Block) -> InsertionPoint: ... - @staticmethod - def at_block_terminator(block: Block) -> InsertionPoint: ... - def insert(self, operation: _OperationBase) -> None: ... - def __enter__(self) -> InsertionPoint: ... - def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... - @property - def block(self) -> Block: ... - @property - def ref_operation(self) -> Optional[_OperationBase]: ... + def __init__(self, beforeOperation: _OperationBase) -> None: + """ + Inserts before a referenced operation. + """ + def insert(self, operation: _OperationBase) -> None: + """ + Inserts an operation. 
+ """ + @property + def block(self) -> Block: + """ + Returns the block that this InsertionPoint points to. + """ + @property + def ref_operation(self) -> Optional[_OperationBase]: + """ + The reference operation before which new operations are inserted, or None if the insertion point is at the end of the block + """ -# TODO: Auto-generated. Audit and fix. class IntegerAttr(Attribute): - def __init__(self, cast_from_attr: Attribute) -> None: ... + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> IntegerAttr: ... + def get(type: Type, value: int) -> IntegerAttr: + """ + Gets an uniqued integer attribute associated to a type + """ @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __int__(self) -> int: + """ + Converts the value of the integer attribute to a Python int + """ + def __repr__(self) -> str: ... @property def type(self) -> Type: ... @property - def value(self) -> int: ... + def typeid(self) -> TypeID: ... + @property + def value(self) -> int: + """ + Returns the value of the integer attribute + """ -# TODO: Auto-generated. Audit and fix. class IntegerSet: - def __init__(self, *args, **kwargs) -> None: ... - def _CAPICreate(self) -> IntegerSet: ... - def dump(self) -> None: ... @staticmethod - def get(*args, **kwargs) -> IntegerSet: ... - @staticmethod - def get_empty(*args, **kwargs) -> IntegerSet: ... - def get_replaced(self, dim_exprs: list, symbol_exprs: list, num_result_dims: int, num_result_symbols: int) -> IntegerSet: ... + def get( + num_dims: int, + num_symbols: int, + exprs: List, + eq_flags: List[bool], + context: Optional[Context] = None, + ) -> IntegerSet: ... + @staticmethod + def get_empty( + num_dims: int, num_symbols: int, context: Optional[Context] = None + ) -> IntegerSet: ... + def _CAPICreate(self) -> IntegerSet: ... @overload def __eq__(self, arg0: IntegerSet) -> bool: ... 
@overload def __eq__(self, arg0: object) -> bool: ... def __hash__(self) -> int: ... + def __repr__(self) -> str: ... + def __str__(self) -> str: ... + def dump(self) -> None: + """ + Dumps a debug representation of the object to stderr. + """ + def get_replaced( + self, + dim_exprs: List, + symbol_exprs: List, + num_result_dims: int, + num_result_symbols: int, + ) -> IntegerSet: ... @property def _CAPIPtr(self) -> object: ... @property - def constraints(self) -> Any: ... + def constraints(self) -> IntegerSetConstraintList: ... @property def context(self) -> Context: ... @property @@ -804,7 +1872,6 @@ class IntegerSet: @property def n_symbols(self) -> int: ... -# TODO: Auto-generated. Audit and fix. class IntegerSetConstraint: def __init__(self, *args, **kwargs) -> None: ... @property @@ -812,7 +1879,6 @@ class IntegerSetConstraint: @property def is_eq(self) -> bool: ... -# TODO: Auto-generated. Audit and fix. class IntegerSetConstraintList: def __init__(self, *args, **kwargs) -> None: ... def __add__(self, arg0: IntegerSetConstraintList) -> List[IntegerSetConstraint]: ... @@ -822,108 +1888,232 @@ class IntegerSetConstraintList: def __getitem__(self, arg0: slice) -> IntegerSetConstraintList: ... def __len__(self) -> int: ... -# TODO: Auto-generated. Audit and fix. class IntegerType(Type): - def __init__(self, cast_from_type: Type) -> None: ... + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get_signed(*args, **kwargs) -> IntegerType: ... + def get_signed(width: int, context: Optional[Context] = None) -> IntegerType: + """ + Create a signed integer type + """ @staticmethod - def get_signless(*args, **kwargs) -> IntegerType: ... + def get_signless(width: int, context: Optional[Context] = None) -> IntegerType: + """ + Create a signless integer type + """ @staticmethod - def get_unsigned(*args, **kwargs) -> IntegerType: ... 
+ def get_unsigned(width: int, context: Optional[Context] = None) -> IntegerType: + """ + Create an unsigned integer type + """ @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + def __repr__(self) -> str: ... + @property + def is_signed(self) -> bool: + """ + Returns whether this is a signed integer + """ @property - def is_signed(self) -> bool: ... + def is_signless(self) -> bool: + """ + Returns whether this is a signless integer + """ @property - def is_signless(self) -> bool: ... + def is_unsigned(self) -> bool: + """ + Returns whether this is an unsigned integer + """ @property - def is_unsigned(self) -> bool: ... + def typeid(self) -> TypeID: ... @property - def width(self) -> int: ... + def width(self) -> int: + """ + Returns the width of the integer type + """ class Location: current: ClassVar[Location] = ... # read-only - __hash__: ClassVar[None] = ... # type: ignore + __hash__: ClassVar[None] = None + @staticmethod + def callsite( + callee: Location, frames: Sequence[Location], context: Optional[Context] = None + ) -> Location: + """ + Gets a Location representing a caller and callsite + """ + @staticmethod + def file( + filename: str, line: int, col: int, context: Optional[Context] = None + ) -> Location: + """ + Gets a Location representing a file, line and column + """ + @staticmethod + def from_attr(attribute: Attribute, context: Optional[Context] = None) -> Location: + """ + Gets a Location from a LocationAttr + """ + @staticmethod + def fused( + locations: Sequence[Location], + metadata: Optional[Attribute] = None, + context: Optional[Context] = None, + ) -> Location: + """ + Gets a Location representing a fused location with optional metadata + """ + @staticmethod + def name( + name: str, + childLoc: Optional[Location] = None, + context: Optional[Context] = None, + ) -> Location: + """ + Gets a Location representing a named location with 
optional child location + """ + @staticmethod + def unknown(context: Optional[Context] = None) -> Location: + """ + Gets a Location representing an unknown location + """ def _CAPICreate(self) -> Location: ... - @staticmethod - def callsite(callee: Location, frames: Sequence[Location], context: Optional[Context] = None) -> Location: ... - @staticmethod - def file(filename: str, line: int, col: int, context: Optional[Context] = None) -> Location: ... - @staticmethod - def fused(locations: Sequence[Location], metadata: Optional[Attribute] = None, context: Optional[Context] = None) -> Location: ... - @staticmethod - def name(name: str, childLoc: Optional[Location] = None, context: Optional[Context] = None) -> Location: ... - @staticmethod - def unknown(context: Optional[Context] = None) -> Any: ... def __enter__(self) -> Location: ... @overload def __eq__(self, arg0: Location) -> bool: ... @overload - def __eq__(self, arg0: object) -> bool: ... + def __eq__(self, arg0: Location) -> bool: ... def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... + def __repr__(self) -> str: ... + def emit_error(self, message: str) -> None: + """ + Emits an error at this location + """ @property def _CAPIPtr(self) -> object: ... @property - def context(self) -> Context: ... + def attr(self) -> Attribute: + """ + Get the underlying LocationAttr + """ + @property + def context(self) -> Context: + """ + Context that owns the Location + """ -# TODO: Auto-generated. Audit and fix. class MemRefType(ShapedType): - def __init__(self, cast_from_type: Type) -> None: ... + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> MemRefType: ... + def get( + shape: List[int], + element_type: Type, + layout: Attribute = None, + memory_space: Attribute = None, + loc: Optional[Location] = None, + ) -> MemRefType: + """ + Create a memref type + """ @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: Type) -> bool: ... 
+ def __init__(self, cast_from_type: Type) -> None: ... + def __repr__(self) -> str: ... + @property + def affine_map(self) -> AffineMap: + """ + The layout of the MemRef type as an affine map. + """ @property - def affine_map(self) -> AffineMap: ... + def layout(self) -> Attribute: + """ + The layout of the MemRef type. + """ @property - def layout(self) -> Attribute: ... + def memory_space(self) -> Optional[Attribute]: + """ + Returns the memory space of the given MemRef type. + """ @property - def memory_space(self) -> Attribute: ... + def typeid(self) -> TypeID: ... class Module: - def _CAPICreate(self) -> object: ... @staticmethod - def create(loc: Optional[Location] = None) -> Module: ... - def dump(self) -> None: ... - @staticmethod - def parse(asm: str, context: Optional[Context] = None) -> Module: ... + def create(loc: Optional[Location] = None) -> Any: + """ + Creates an empty module + """ + @staticmethod + def parse(asm: str, context: Optional[Context] = None) -> Any: + """ + Parses a module's assembly format from a string. + + Returns a new MlirModule or raises an MLIRError if the parsing fails. + + See also: https://mlir.llvm.org/docs/LangRef/ + """ + def _CAPICreate(self) -> Any: ... + def __str__(self) -> Any: + """ + Gets the assembly form of the operation with default options. + + If more advanced control over the assembly formatting or I/O options is needed, + use the dedicated print or get_asm method, which supports keyword arguments to + customize behavior. + """ + def dump(self) -> None: + """ + Dumps a debug representation of the object to stderr. + """ @property def _CAPIPtr(self) -> object: ... @property - def body(self) -> Block: ... + def body(self) -> Block: + """ + Return the block for this module + """ @property - def context(self) -> Context: ... + def context(self) -> Context: + """ + Context that created the Module + """ @property - def operation(self) -> Operation: ... 
+ def operation(self) -> Any: + """ + Accesses the module as an operation + """ class MLIRError(Exception): - def __init__(self, message: str, error_diagnostics: List[DiagnosticInfo]) -> None: ... + def __init__( + self, message: str, error_diagnostics: List[DiagnosticInfo] + ) -> None: ... class NamedAttribute: + def __repr__(self) -> str: ... @property - def attr(self) -> Attribute: ... + def attr(self) -> Attribute: + """ + The underlying generic attribute of the NamedAttribute binding + """ @property - def name(self) -> str: ... + def name(self) -> str: + """ + The name of the NamedAttribute binding + """ -# TODO: Auto-generated. Audit and fix. class NoneType(Type): - def __init__(self, cast_from_type: Type) -> None: ... + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> NoneType: ... + def get(context: Optional[Context] = None) -> NoneType: + """ + Create a none type. + """ @staticmethod - def isinstance(arg: Any) -> bool: ... - -class OpaqueType(Type): + def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - @staticmethod - def get(*args, **kwargs) -> OpaqueType: ... - @staticmethod - def isinstance(arg: Any) -> bool: ... + def __repr__(self) -> str: ... @property - def dialect_namespace(self) -> str: ... - @property - def data(self) -> str: ... + def typeid(self) -> TypeID: ... class OpAttributeMap: def __contains__(self, arg0: str) -> bool: ... @@ -935,6 +2125,16 @@ class OpAttributeMap: def __len__(self) -> int: ... def __setitem__(self, arg0: str, arg1: Attribute) -> None: ... +class OpOperand: + @property + def operand_number(self) -> int: ... + @property + def owner(self) -> _OperationBase: ... + +class OpOperandIterator: + def __iter__(self) -> OpOperandIterator: ... + def __next__(self) -> OpOperand: ... + class OpOperandList: def __add__(self, arg0: OpOperandList) -> List[Value]: ... 
@overload @@ -945,6 +2145,8 @@ class OpOperandList: def __setitem__(self, arg0: int, arg1: Value) -> None: ... class OpResult(Value): + @staticmethod + def isinstance(other_value: Value) -> bool: ... def __init__(self, value: Value) -> None: ... @staticmethod def isinstance(arg: Any) -> bool: ... @@ -961,8 +2163,14 @@ class OpResultList: def __getitem__(self, arg0: slice) -> OpResultList: ... def __len__(self) -> int: ... @property + def owner(self) -> _OperationBase: ... + @property def types(self) -> List[Type]: ... +class OpSuccessors: + def __add__(self, arg0: OpSuccessors) -> List[Block]: ... + def __setitem__(self, arg0: int, arg1: Block) -> None: ... + class OpView(_OperationBase): _ODS_OPERAND_SEGMENTS: ClassVar[None] = ... _ODS_REGIONS: ClassVar[tuple] = ... @@ -977,60 +2185,193 @@ class OpView(_OperationBase): successors: Optional[Sequence[Block]] = None, regions: Optional[int] = None, loc: Optional[Location] = None, - ip: Optional[InsertionPoint] = None) -> _TOperation: ... + ip: Optional[InsertionPoint] = None, + ) -> _TOperation: + """ + Builds a specific, generated OpView based on class level attributes. + """ + @classmethod + def parse( + cls: _Type[_TOperation], + source: str, + *, + source_name: str = "", + context: Optional[Context] = None, + ) -> _TOperation: + """ + Parses a specific, generated OpView based on class level attributes + """ + def __init__(self, operation: _OperationBase) -> None: ... + def __str__(self) -> str: ... @property - def operation(self) -> Operation: ... + def operation(self) -> _OperationBase: ... + @property + def opview(self) -> OpView: ... + @property + def successors(self) -> OpSuccessors: + """ + Returns the List of Operation successors. + """ + +class OpaqueAttr(Attribute): + static_typeid: ClassVar[TypeID] # value = + @staticmethod + def get( + dialect_namespace: str, + buffer: Buffer, + type: Type, + context: Optional[Context] = None, + ) -> OpaqueAttr: + """ + Gets an Opaque attribute. 
+ """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __repr__(self) -> str: ... + @property + def data(self) -> bytes: + """ + Returns the data for the Opaqued attributes as `bytes` + """ + @property + def dialect_namespace(self) -> str: + """ + Returns the dialect namespace for the Opaque attribute as a string + """ + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + +class OpaqueType(Type): + static_typeid: ClassVar[TypeID] # value = + @staticmethod + def get( + dialect_namespace: str, buffer: str, context: Optional[Context] = None + ) -> OpaqueType: + """ + Create an unregistered (opaque) dialect type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + def __repr__(self) -> str: ... @property - def opview(self) -> "OpView": ... + def data(self) -> str: + """ + Returns the data for the Opaque type as a string. + """ + @property + def dialect_namespace(self) -> str: + """ + Returns the dialect namespace for the Opaque type as a string. + """ + @property + def typeid(self) -> TypeID: ... class Operation(_OperationBase): def _CAPICreate(self) -> object: ... @staticmethod - def create(name: str, results: Optional[Sequence[Type]] = None, + def create( + name: str, + results: Optional[Sequence[Type]] = None, operands: Optional[Sequence[Value]] = None, attributes: Optional[Dict[str, Attribute]] = None, successors: Optional[Sequence[Block]] = None, regions: int = 0, loc: Optional[Location] = None, - ip: Optional[InsertionPoint] = None) -> _OperationBase: ... - def erase(self) -> None: ... + ip: Optional[InsertionPoint] = None, + infer_type: bool = False, + ) -> Operation: + """ + Creates a new operation. + + Args: + name: Operation name (e.g. "dialect.operation"). + results: Sequence of Type representing op result types. + attributes: Dict of str:Attribute. 
+ successors: List of Block for the operation's successors. + regions: Number of regions to create. + location: A Location object (defaults to resolve from context manager). + ip: An InsertionPoint (defaults to resolve from context manager or set to + False to disable insertion, even with an insertion point set in the + context manager). + infer_type: Whether to infer result types. + Returns: + A new "detached" Operation object. Detached operations can be added + to blocks, which causes them to become "attached." + """ + @staticmethod + def parse( + source: str, *, source_name: str = "", context: Optional[Context] = None + ) -> Any: + """ + Parses an operation. Supports both text assembly format and binary bytecode format. + """ + def _CAPICreate(self) -> object: ... @property def _CAPIPtr(self) -> object: ... @property - def operation(self) -> "Operation": ... + def operation(self) -> Operation: ... @property def opview(self) -> OpView: ... + @property + def successors(self) -> OpSuccessors: + """ + Returns the List of Operation successors. + """ class OperationIterator: def __iter__(self) -> OperationIterator: ... def __next__(self) -> OpView: ... class OperationList: - def __getitem__(self, arg0: int) -> OpView: ... + def __getitem__(self, arg0: int) -> Any: ... def __iter__(self) -> OperationIterator: ... def __len__(self) -> int: ... -# TODO: Auto-generated. Audit and fix. class RankedTensorType(ShapedType): - def __init__(self, cast_from_type: Type) -> None: ... + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> RankedTensorType: ... + def get( + shape: List[int], + element_type: Type, + encoding: Optional[Attribute] = None, + loc: Optional[Location] = None, + ) -> RankedTensorType: + """ + Create a ranked tensor type + """ @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + def __repr__(self) -> str: ... 
@property def encoding(self) -> Optional[Attribute]: ... + @property + def typeid(self) -> TypeID: ... class Region: - __hash__: ClassVar[None] = ... # type: ignore + __hash__: ClassVar[None] = None @overload def __eq__(self, arg0: Region) -> bool: ... @overload def __eq__(self, arg0: object) -> bool: ... - def __iter__(self) -> BlockIterator: ... + def __iter__(self) -> BlockIterator: + """ + Iterates over blocks in the region. + """ @property - def blocks(self) -> BlockList: ... + def blocks(self) -> BlockList: + """ + Returns a forward-optimized sequence of blocks. + """ @property - def owner(self) -> OpView: ... + def owner(self) -> OpView: + """ + Returns the operation owning this region. + """ class RegionIterator: def __iter__(self) -> RegionIterator: ... @@ -1038,132 +2379,353 @@ class RegionIterator: class RegionSequence: def __getitem__(self, arg0: int) -> Region: ... + def __iter__(self) -> RegionIterator: ... def __len__(self) -> int: ... -# TODO: Auto-generated. Audit and fix. class ShapedType(Type): + @staticmethod + def get_dynamic_size() -> int: + """ + Returns the value used to indicate dynamic dimensions in shaped types. + """ + @staticmethod + def get_dynamic_stride_or_offset() -> int: + """ + Returns the value used to indicate dynamic strides or offsets in shaped types. + """ + @staticmethod + def is_dynamic_size(dim_size: int) -> bool: + """ + Returns whether the given dimension size indicates a dynamic dimension. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def get_dim_size(self, dim: int) -> int: ... - def is_dynamic_dim(self, dim: int) -> bool: ... - def is_dynamic_size(self, *args, **kwargs) -> Any: ... - def is_dynamic_stride_or_offset(self, dim_size: int) -> bool: ... + def __repr__(self) -> str: ... + def get_dim_size(self, dim: int) -> int: + """ + Returns the dim-th dimension of the given ranked shaped type. 
+ """ + def is_dynamic_dim(self, dim: int) -> bool: + """ + Returns whether the dim-th dimension of the given shaped type is dynamic. + """ + def is_dynamic_stride_or_offset(self, dim_size: int) -> bool: + """ + Returns whether the given value is used as a placeholder for dynamic strides and offsets in shaped types. + """ + @property + def element_type(self) -> Type: + """ + Returns the element type of the shaped type. + """ + @property + def has_rank(self) -> bool: + """ + Returns whether the given shaped type is ranked. + """ + @property + def has_static_shape(self) -> bool: + """ + Returns whether the given shaped type has a static shape. + """ + @property + def rank(self) -> int: + """ + Returns the rank of the given ranked shaped type. + """ + @property + def shape(self) -> List[int]: + """ + Returns the shape of the ranked shaped type as a List of integers. + """ + @property + def static_typeid(self) -> TypeID: ... + @property + def typeid(self) -> TypeID: ... + +class ShapedTypeComponents: @staticmethod - def isinstance(arg: Any) -> bool: ... - @property - def element_type(self) -> Type: ... + @overload + def get(element_type: Type) -> ShapedTypeComponents: + """ + Create an shaped type components object with only the element type. + """ + @staticmethod + @overload + def get(shape: List, element_type: Type) -> ShapedTypeComponents: + """ + Create a ranked shaped type components object. + """ + @staticmethod + @overload + def get( + shape: List, element_type: Type, attribute: Attribute + ) -> ShapedTypeComponents: + """ + Create a ranked shaped type components object with attribute. + """ + @property + def element_type(self) -> Type: + """ + Returns the element type of the shaped type components. + """ + @property + def has_rank(self) -> bool: + """ + Returns whether the given shaped type component is ranked. + """ + @property + def rank(self) -> int: + """ + Returns the rank of the given ranked shaped type components. 
If the shaped type components does not have a rank, None is returned. + """ + @property + def shape(self) -> List[int]: + """ + Returns the shape of the ranked shaped type components as a List of integers. Returns none if the shaped type component does not have a rank. + """ + +class StridedLayoutAttr(Attribute): + static_typeid: ClassVar[TypeID] # value = + @staticmethod + def get( + offset: int, strides: List[int], context: Optional[Context] = None + ) -> StridedLayoutAttr: + """ + Gets a strided layout attribute. + """ + @staticmethod + def get_fully_dynamic( + rank: int, context: Optional[Context] = None + ) -> StridedLayoutAttr: + """ + Gets a strided layout attribute with dynamic offset and strides of a given rank. + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __repr__(self) -> str: ... @property - def has_rank(self) -> bool: ... + def offset(self) -> int: + """ + Returns the value of the float point attribute + """ @property - def has_static_shape(self) -> bool: ... + def strides(self) -> List[int]: + """ + Returns the value of the float point attribute + """ @property - def rank(self) -> int: ... + def type(self) -> Type: ... @property - def shape(self) -> List[int]: ... + def typeid(self) -> TypeID: ... -class ShapedTypeComponents: - @property - def element_type(self) -> Type: ... +class StringAttr(Attribute): + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> ShapedTypeComponents: ... + def get(value: str, context: Optional[Context] = None) -> StringAttr: + """ + Gets a uniqued string attribute + """ + @staticmethod + def get_typed(type: Type, value: str) -> StringAttr: + """ + Gets a uniqued string attribute associated to a type + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __repr__(self) -> str: ... 
+ @property + def type(self) -> Type: ... @property - def has_rank(self) -> bool: ... + def typeid(self) -> TypeID: ... @property - def rank(self) -> int: ... + def value(self) -> str: + """ + Returns the value of the string attribute + """ @property - def shape(self) -> List[int]: ... + def value_bytes(self) -> bytes: + """ + Returns the value of the string attribute as `bytes` + """ -# TODO: Auto-generated. Audit and fix. -class StringAttr(Attribute): - def __init__(self, cast_from_attr: Attribute) -> None: ... - @staticmethod - def get(*args, **kwargs) -> Any: ... +class SymbolRefAttr(Attribute): @staticmethod - def get_typed(*args, **kwargs) -> Any: ... + def get(symbols: List[str], context: Optional[Context] = None) -> Attribute: + """ + Gets a uniqued SymbolRef attribute from a List of symbol names + """ @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __repr__(self) -> str: ... + @property + def static_typeid(self) -> TypeID: ... @property def type(self) -> Type: ... @property - def value(self) -> str: ... + def typeid(self) -> TypeID: ... + @property + def value(self) -> List[str]: + """ + Returns the value of the SymbolRef attribute as a List[str] + """ class SymbolTable: - def __init__(self, arg0: _OperationBase) -> None: ... - def erase(self, operation: _OperationBase) -> None: ... @staticmethod def get_symbol_name(symbol: _OperationBase) -> Attribute: ... @staticmethod def get_visibility(symbol: _OperationBase) -> Attribute: ... - def insert(self, operation: _OperationBase) -> Attribute: ... @staticmethod - def replace_all_symbol_uses(old_symbol: str, new_symbol: str, from_op: _OperationBase) -> None: ... + def replace_all_symbol_uses( + old_symbol: str, new_symbol: str, from_op: _OperationBase + ) -> None: ... @staticmethod def set_symbol_name(symbol: _OperationBase, name: str) -> None: ... 
@staticmethod def set_visibility(symbol: _OperationBase, visibility: str) -> None: ... @staticmethod - def walk_symbol_tables(from_op: _OperationBase, all_sym_uses_visible: bool, callback: Callable[[_OperationBase, bool], None]) -> None: ... + def walk_symbol_tables( + from_op: _OperationBase, + all_sym_uses_visible: bool, + callback: Callable[[_OperationBase, bool], None], + ) -> None: ... def __contains__(self, arg0: str) -> bool: ... def __delitem__(self, arg0: str) -> None: ... def __getitem__(self, arg0: str) -> OpView: ... + def __init__(self, arg0: _OperationBase) -> None: ... + def erase(self, operation: _OperationBase) -> None: ... + def insert(self, operation: _OperationBase) -> Attribute: ... -# TODO: Auto-generated. Audit and fix. class TupleType(Type): - def __init__(self, cast_from_type: Type) -> None: ... + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get_tuple(*args, **kwargs) -> TupleType: ... - def get_type(self, pos: int) -> Type: ... + def get_Tuple(elements: List[Type], context: Optional[Context] = None) -> TupleType: + """ + Create a Tuple type + """ @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + def __repr__(self) -> str: ... + def get_type(self, pos: int) -> Type: + """ + Returns the pos-th type in the Tuple type. + """ @property - def num_types(self) -> int: ... + def num_types(self) -> int: + """ + Returns the number of types contained in a Tuple. + """ + @property + def typeid(self) -> TypeID: ... -# TODO: Auto-generated. Audit and fix. class TypeAttr(Attribute): - def __init__(self, cast_from_attr: Attribute) -> None: ... + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> Any: ... + def get(value: Type, context: Optional[Context] = None) -> TypeAttr: + """ + Gets a uniqued Type attribute + """ @staticmethod - def isinstance(arg: Any) -> bool: ... 
+ def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __repr__(self) -> str: ... @property def type(self) -> Type: ... @property + def typeid(self) -> TypeID: ... + @property def value(self) -> Type: ... -# TODO: Auto-generated. Audit and fix. +class TypeID: + def _CAPICreate(self) -> TypeID: ... + @overload + def __eq__(self, arg0: TypeID) -> bool: ... + @overload + def __eq__(self, arg0: Any) -> bool: ... + def __hash__(self) -> int: ... + @property + def _CAPIPtr(self) -> object: ... + class UnitAttr(Attribute): - def __init__(self, cast_from_attr: Attribute) -> None: ... + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> Any: ... + def get(context: Optional[Context] = None) -> UnitAttr: + """ + Create a Unit attribute. + """ @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + def __repr__(self) -> str: ... @property def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... -# TODO: Auto-generated. Audit and fix. class UnrankedMemRefType(ShapedType): - def __init__(self, cast_from_type: Type) -> None: ... + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> UnrankedMemRefType: ... + def get( + element_type: Type, memory_space: Attribute, loc: Optional[Location] = None + ) -> UnrankedMemRefType: + """ + Create a unranked memref type + """ @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + def __repr__(self) -> str: ... + @property + def memory_space(self) -> Optional[Attribute]: + """ + Returns the memory space of the given Unranked MemRef type. + """ @property - def memory_space(self) -> Attribute: ... + def typeid(self) -> TypeID: ... -# TODO: Auto-generated. Audit and fix. 
class UnrankedTensorType(ShapedType): - def __init__(self, cast_from_type: Type) -> None: ... + static_typeid: ClassVar[TypeID] # value = @staticmethod - def get(*args, **kwargs) -> UnrankedTensorType: ... + def get(element_type: Type, loc: Optional[Location] = None) -> UnrankedTensorType: + """ + Create a unranked tensor type + """ @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + def __repr__(self) -> str: ... + @property + def typeid(self) -> TypeID: ... -# TODO: Auto-generated. Audit and fix. class VectorType(ShapedType): - def __init__(self, cast_from_type: Type) -> None: ... - @staticmethod - def get(*args, **kwargs) -> VectorType: ... + static_typeid: ClassVar[TypeID] # value = + @staticmethod + def get( + shape: List[int], + element_type: Type, + *, + scalable: Optional[List] = None, + scalable_dims: Optional[List[int]] = None, + loc: Optional[Location] = None, + ) -> VectorType: + """ + Create a vector type + """ @staticmethod - def isinstance(arg: Any) -> bool: ... + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + def __repr__(self) -> str: ... + @property + def scalable(self) -> bool: ... + @property + def scalable_dims(self) -> List[bool]: ... + @property + def typeid(self) -> TypeID: ... class _GlobalDebug: - flag: ClassVar[bool] = ... + flag: ClassVar[bool] = False From 5ef50f24ea3034803bf644207bd4da45890eef7e Mon Sep 17 00:00:00 2001 From: martin-luecke Date: Fri, 15 Dec 2023 13:04:43 +0100 Subject: [PATCH 655/915] [MLIR][transform][python] add sugared python abstractions for transform dialect (#75073) This adds Python abstractions for the different handle types of the transform dialect The abstractions allow for straightforward chaining of transforms by calling their member functions. As an initial PR for this infrastructure, only a single transform is included: `transform.structured.match`. 
With a future `tile` transform abstraction an example of the usage is: ```Python def script(module: OpHandle): module.match_ops(MatchInterfaceEnum.TilingInterface).tile(tile_sizes=[32,32]) ``` to generate the following IR: ```mlir %0 = transform.structured.match interface{TilingInterface} in %arg0 %tiled_op, %loops = transform.structured.tile_using_for %0 [32, 32] ``` These abstractions are intended to enhance the usability and flexibility of the transform dialect by providing an accessible interface that allows for easy assembly of complex transformation chains. --- mlir/include/mlir-c/Dialect/Transform.h | 8 + .../mlir/Bindings/Python/PybindAdaptors.h | 7 + mlir/lib/Bindings/Python/DialectTransform.cpp | 12 +- mlir/lib/CAPI/Dialect/Transform.cpp | 18 ++- mlir/python/CMakeLists.txt | 8 + .../extras/dialects/transform/__init__.py | 148 ++++++++++++++++++ 6 files changed, 196 insertions(+), 5 deletions(-) create mode 100644 mlir/python/mlir/extras/dialects/transform/__init__.py diff --git a/mlir/include/mlir-c/Dialect/Transform.h b/mlir/include/mlir-c/Dialect/Transform.h index 91c99b1f8..02c99b592 100644 --- a/mlir/include/mlir-c/Dialect/Transform.h +++ b/mlir/include/mlir-c/Dialect/Transform.h @@ -25,6 +25,8 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Transform, transform); MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyOpType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyOpTypeGetTypeID(void); + MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx); //===---------------------------------------------------------------------===// @@ -33,6 +35,8 @@ MLIR_CAPI_EXPORTED MlirType mlirTransformAnyOpTypeGet(MlirContext ctx); MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyParamType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyParamTypeGetTypeID(void); + MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx); //===---------------------------------------------------------------------===// @@ -41,6 +45,8 @@ 
MLIR_CAPI_EXPORTED MlirType mlirTransformAnyParamTypeGet(MlirContext ctx); MLIR_CAPI_EXPORTED bool mlirTypeIsATransformAnyValueType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirTransformAnyValueTypeGetTypeID(void); + MLIR_CAPI_EXPORTED MlirType mlirTransformAnyValueTypeGet(MlirContext ctx); //===---------------------------------------------------------------------===// @@ -63,6 +69,8 @@ mlirTransformOperationTypeGetOperationName(MlirType type); MLIR_CAPI_EXPORTED bool mlirTypeIsATransformParamType(MlirType type); +MLIR_CAPI_EXPORTED MlirTypeID mlirTransformParamTypeGetTypeID(void); + MLIR_CAPI_EXPORTED MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type); diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 5e0e56fc0..66cf20e1c 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -495,6 +495,13 @@ class mlir_type_subclass : public pure_subclass { .attr("replace")(superCls.attr("__name__"), captureTypeName); }); if (getTypeIDFunction) { + // 'get_static_typeid' method. + // This is modeled as a static method instead of a static property because + // `def_property_readonly_static` is not available in `pure_subclass` and + // we do not want to introduce the complexity that pybind uses to + // implement it. 
+ def_staticmethod("get_static_typeid", + [getTypeIDFunction]() { return getTypeIDFunction(); }); py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( getTypeIDFunction())(pybind11::cpp_function( diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp index c7764f4e7..6b57e652a 100644 --- a/mlir/lib/Bindings/Python/DialectTransform.cpp +++ b/mlir/lib/Bindings/Python/DialectTransform.cpp @@ -27,7 +27,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) { //===-------------------------------------------------------------------===// auto anyOpType = - mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType); + mlir_type_subclass(m, "AnyOpType", mlirTypeIsATransformAnyOpType, + mlirTransformAnyOpTypeGetTypeID); anyOpType.def_classmethod( "get", [](py::object cls, MlirContext ctx) { @@ -41,7 +42,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) { //===-------------------------------------------------------------------===// auto anyParamType = - mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType); + mlir_type_subclass(m, "AnyParamType", mlirTypeIsATransformAnyParamType, + mlirTransformAnyParamTypeGetTypeID); anyParamType.def_classmethod( "get", [](py::object cls, MlirContext ctx) { @@ -55,7 +57,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) { //===-------------------------------------------------------------------===// auto anyValueType = - mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType); + mlir_type_subclass(m, "AnyValueType", mlirTypeIsATransformAnyValueType, + mlirTransformAnyValueTypeGetTypeID); anyValueType.def_classmethod( "get", [](py::object cls, MlirContext ctx) { @@ -96,7 +99,8 @@ void populateDialectTransformSubmodule(const pybind11::module &m) { //===-------------------------------------------------------------------===// auto paramType = - 
mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType); + mlir_type_subclass(m, "ParamType", mlirTypeIsATransformParamType, + mlirTransformParamTypeGetTypeID); paramType.def_classmethod( "get", [](py::object cls, MlirType type, MlirContext ctx) { diff --git a/mlir/lib/CAPI/Dialect/Transform.cpp b/mlir/lib/CAPI/Dialect/Transform.cpp index 3f7f8b8e2..5fd773572 100644 --- a/mlir/lib/CAPI/Dialect/Transform.cpp +++ b/mlir/lib/CAPI/Dialect/Transform.cpp @@ -25,6 +25,10 @@ bool mlirTypeIsATransformAnyOpType(MlirType type) { return isa(unwrap(type)); } +MlirTypeID mlirTransformAnyOpTypeGetTypeID(void) { + return wrap(transform::AnyOpType::getTypeID()); +} + MlirType mlirTransformAnyOpTypeGet(MlirContext ctx) { return wrap(transform::AnyOpType::get(unwrap(ctx))); } @@ -37,6 +41,10 @@ bool mlirTypeIsATransformAnyParamType(MlirType type) { return isa(unwrap(type)); } +MlirTypeID mlirTransformAnyParamTypeGetTypeID(void) { + return wrap(transform::AnyParamType::getTypeID()); +} + MlirType mlirTransformAnyParamTypeGet(MlirContext ctx) { return wrap(transform::AnyParamType::get(unwrap(ctx))); } @@ -49,6 +57,10 @@ bool mlirTypeIsATransformAnyValueType(MlirType type) { return isa(unwrap(type)); } +MlirTypeID mlirTransformAnyValueTypeGetTypeID(void) { + return wrap(transform::AnyValueType::getTypeID()); +} + MlirType mlirTransformAnyValueTypeGet(MlirContext ctx) { return wrap(transform::AnyValueType::get(unwrap(ctx))); } @@ -76,13 +88,17 @@ MlirStringRef mlirTransformOperationTypeGetOperationName(MlirType type) { } //===---------------------------------------------------------------------===// -// AnyOpType +// ParamType //===---------------------------------------------------------------------===// bool mlirTypeIsATransformParamType(MlirType type) { return isa(unwrap(type)); } +MlirTypeID mlirTransformParamTypeGetTypeID(void) { + return wrap(transform::ParamType::getTypeID()); +} + MlirType mlirTransformParamTypeGet(MlirContext ctx, MlirType type) { return 
wrap(transform::ParamType::get(unwrap(ctx), unwrap(type))); } diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 585918afc..41d91cf67 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -166,6 +166,14 @@ declare_mlir_dialect_python_bindings( "../../include/mlir/Dialect/Transform/IR/TransformAttrs.td" ) +declare_mlir_python_sources( + MLIRPythonSources.Dialects.transform.extras + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + GEN_ENUM_BINDINGS + SOURCES + extras/dialects/transform/__init__.py) + declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/extras/dialects/transform/__init__.py b/mlir/python/mlir/extras/dialects/transform/__init__.py new file mode 100644 index 000000000..9e3133243 --- /dev/null +++ b/mlir/python/mlir/extras/dialects/transform/__init__.py @@ -0,0 +1,148 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from __future__ import annotations +from typing import Callable, Optional, Sequence + +from .... import ir +from ....dialects import transform +from ....dialects.transform import structured + + +class Handle(ir.Value): + """ + Base class for wrappers around different types of transform handle with + methods to chain further transforms. + + The fields `children` and `parent` are used to capture the relation of + handles statically in order to enable further analysis. The payload + operation of a child handle is nested into a region of the payload operation + of the corresponding parent handle. 
+ """ + + def __init__( + self, + v: ir.Value, + *, + parent: Optional[Handle] = None, + children: Optional[Sequence[Handle]] = None, + ): + super().__init__(v) + self.parent = parent + self.children = children if children is not None else [] + + +@ir.register_value_caster(transform.AnyOpType.get_static_typeid()) +@ir.register_value_caster(transform.OperationType.get_static_typeid()) +class OpHandle(Handle): + """ + Wrapper around a transform operation handle with methods to chain further + transforms. + """ + + def __init__( + self, + v: ir.Value, + *, + parent: Optional[Handle] = None, + children: Optional[Sequence[Handle]] = None, + ): + super().__init__(v, parent=parent, children=children) + + def match_ops( + self, + ops: str + | ir.OpView + | structured.MatchInterfaceEnum + | Sequence[str | ir.OpView], + ) -> OpHandle: + """ + Emits a `transform.structured.MatchOp`. + Returns a handle to payload ops that match the given names, types, or + interface. If only a single type is given, the value wrapped by the + resulting handle is populated with the respective type. + """ + # Handle interface. + if isinstance(ops, structured.MatchInterfaceEnum) or ( + isinstance(ops, str) and ops in structured.MatchInterfaceEnum.__members__ + ): + if isinstance(ops, str): + ops = structured.MatchInterfaceEnum[ops] + match_op = structured.MatchOp( + transform.AnyOpType.get(), + self, + interface=ops, + ) + + # Handle op name(s), either given directly as string or given as op. 
+ else: + if isinstance(ops, str): + op_type = transform.OperationType.get(ops) + op_names = [ops] + elif isinstance(ops, Sequence): + op_type = transform.AnyOpType.get() + op_names = [ + op if isinstance(op, str) else op.OPERATION_NAME for op in ops + ] + else: + op_type = transform.OperationType.get(ops.OPERATION_NAME) + op_names = [ops.OPERATION_NAME] + match_op = structured.MatchOp.match_op_names( + op_type, + self, + op_names, + ) + + handle = OpHandle(match_op.results_, parent=self) + self.children.append(handle) + return handle + + +def insert_transform_script( + block_or_insertion_point: ir.Block | ir.InsertionPoint, + script: Callable[[OpHandle], None], + dump_script: bool = False, +) -> None: + """ + Inserts the transform script of the schedule into the module. The script + should accept an instance of OpHandle as argument, which will be called with + the block arg of the newly created named_sequence op. + + Example: + This python code + ``` + module = ir.Module.create() + def test_match_ops_single(module: OpHandle): + module.match_ops(scf.ForOp) + insert_transform_script(module.body, script) + ``` + generates the following IR: + ``` + module { + transform.named_sequence @__transform_main(%arg0: !transform.any_op) { + ^bb0(%arg0: !transform.any_op): + %0 = transform.structured.match ops{["scf.for"]} in %arg0 + : (!transform.any_op) -> !transform.op<"scf.for"> + } + } + ``` + """ + if isinstance(block_or_insertion_point, ir.Block): + context = block_or_insertion_point.owner.context + insertion_point = ir.InsertionPoint.at_block_begin(block_or_insertion_point) + else: + context = block_or_insertion_point.block.owner.context + insertion_point = block_or_insertion_point + + with context, ir.Location.unknown(context): + with insertion_point: + named_sequence_op = transform.NamedSequenceOp( + "__transform_main", [transform.AnyOpType.get()], [] + ) + with ir.InsertionPoint(named_sequence_op.body): + script(named_sequence_op.bodyTarget) + transform.YieldOp([]) + 
+ if dump_script: + print(named_sequence_op) From 5e1866d78313b0af87856b48b14d6a2ffa4577da Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 20 Dec 2023 12:18:58 -0800 Subject: [PATCH 656/915] [mlir][python] Make the Context/Operation capsule creation methods work as documented. (#76010) This fixes a longstanding bug in the `Context._CAPICreate` method whereby it was not taking ownership of the PyMlirContext wrapper when casting to a Python object. The result was minimally that all such contexts transferred in that way would leak. In addition, counter to the documentation for the `_CAPICreate` helper (see `mlir-c/Bindings/Python/Interop.h`) and the `forContext` / `forOperation` methods, we were silently upgrading any unknown context/operation pointer to steal-ownership semantics. This is dangerous and was causing some subtle bugs downstream where this facility is getting the most use. This patch corrects the semantics and will only do an ownership transfer for `_CAPICreate`, and it will further require that it is an ownership transfer (if already transferred, it was just silently succeeding). Removing the mis-aligned behavior made it clear where the downstream was doing the wrong thing. It also adds some `_testing_` functions to create unowned context and operation capsules so that this can be fully tested upstream, reworking the tests to verify the behavior. In some torture testing downstream, I was not able to trigger any memory corruption with the newly enforced semantics. When getting it wrong, a regular exception is raised. 
--- mlir/lib/Bindings/Python/IRCore.cpp | 78 +++++++++++++++++++++++++---- mlir/lib/Bindings/Python/IRModule.h | 19 ++++++- 2 files changed, 86 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 5412c3dec..39757dfad 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -602,7 +602,7 @@ py::object PyMlirContext::createFromCapsule(py::object capsule) { MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); if (mlirContextIsNull(rawContext)) throw py::error_already_set(); - return forContext(rawContext).releaseObject(); + return stealExternalContext(rawContext).releaseObject(); } PyMlirContext *PyMlirContext::createNewContextForInit() { @@ -615,18 +615,35 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) { auto &liveContexts = getLiveContexts(); auto it = liveContexts.find(context.ptr); if (it == liveContexts.end()) { - // Create. - PyMlirContext *unownedContextWrapper = new PyMlirContext(context); - py::object pyRef = py::cast(unownedContextWrapper); - assert(pyRef && "cast to py::object failed"); - liveContexts[context.ptr] = unownedContextWrapper; - return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); + throw std::runtime_error( + "Cannot use a context that is not owned by the Python bindings."); } + // Use existing. 
py::object pyRef = py::cast(it->second); return PyMlirContextRef(it->second, std::move(pyRef)); } +PyMlirContextRef PyMlirContext::stealExternalContext(MlirContext context) { + py::gil_scoped_acquire acquire; + auto &liveContexts = getLiveContexts(); + auto it = liveContexts.find(context.ptr); + if (it != liveContexts.end()) { + throw std::runtime_error( + "Cannot transfer ownership of the context to Python " + "as it is already owned by Python."); + } + + PyMlirContext *unownedContextWrapper = new PyMlirContext(context); + // Note that the default return value policy on cast is automatic_reference, + // which does not take ownership (delete will not be called). + // Just be explicit. + py::object pyRef = + py::cast(unownedContextWrapper, py::return_value_policy::take_ownership); + assert(pyRef && "cast to py::object failed"); + return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); +} + PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { static LiveContextMap liveContexts; return liveContexts; @@ -1145,6 +1162,18 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, return PyOperationRef(existing, std::move(pyRef)); } +PyOperationRef PyOperation::stealExternalOperation(PyMlirContextRef contextRef, + MlirOperation operation) { + auto &liveOperations = contextRef->liveOperations; + auto it = liveOperations.find(operation.ptr); + if (it != liveOperations.end()) { + throw std::runtime_error( + "Cannot transfer ownership of the operation to Python " + "as it is already owned by Python."); + } + return createInstance(std::move(contextRef), operation, py::none()); +} + PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, MlirOperation operation, py::object parentKeepAlive) { @@ -1316,7 +1345,8 @@ py::object PyOperation::createFromCapsule(py::object capsule) { if (mlirOperationIsNull(rawOperation)) throw py::error_already_set(); MlirContext rawCtxt = mlirOperationGetContext(rawOperation); - return 
forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) + return stealExternalOperation(PyMlirContext::forContext(rawCtxt), + rawOperation) .releaseObject(); } @@ -2548,6 +2578,16 @@ void mlir::python::populateIRCore(py::module &m) { .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) + .def_static("_testing_create_raw_context_capsule", + []() { + // Creates an MlirContext not known to the Python bindings + // and puts it in a capsule. Used to test interop. Using + // this without passing it back to the capsule creation + // API will leak. + return py::reinterpret_steal( + mlirPythonContextToCapsule( + mlirContextCreateWithThreading(false))); + }) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) @@ -2973,8 +3013,7 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("binary") = false, kOperationPrintStateDocstring) .def("print", py::overload_cast, bool, bool, bool, bool, - bool, py::object, bool>( - &PyOperationBase::print), + bool, py::object, bool>(&PyOperationBase::print), // Careful: Lots of arguments must match up with print method. py::arg("large_elements_limit") = py::none(), py::arg("enable_debug_info") = false, @@ -3046,6 +3085,25 @@ void mlir::python::populateIRCore(py::module &m) { .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) + .def_static( + "_testing_create_raw_capsule", + [](std::string sourceStr) { + // Creates a raw context and an operation via parsing the given + // source and returns them in a capsule. Error handling is + // minimal as this is purely intended for testing interop with + // operation creation from capsule functions. 
+ MlirContext context = mlirContextCreateWithThreading(false); + MlirOperation op = mlirOperationCreateParse( + context, toMlirStringRef(sourceStr), toMlirStringRef("temp")); + if (mlirOperationIsNull(op)) { + mlirContextDestroy(context); + throw std::invalid_argument("Failed to parse"); + } + return py::make_tuple(py::reinterpret_steal( + mlirPythonContextToCapsule(context)), + py::reinterpret_steal( + mlirPythonOperationToCapsule(op))); + }) .def_property_readonly("operation", [](py::object self) { return self; }) .def_property_readonly("opview", &PyOperation::createOpView) .def_property_readonly( diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 79b7e0c96..04164b78b 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -176,8 +176,19 @@ class PyMlirContext { static PyMlirContext *createNewContextForInit(); /// Returns a context reference for the singleton PyMlirContext wrapper for - /// the given context. + /// the given context. It is only valid to call this on an MlirContext that + /// is already owned by the Python bindings. Typically this will be because + /// it came in some fashion from createNewContextForInit(). However, it + /// is also possible to explicitly transfer ownership of an existing + /// MlirContext to the Python bindings via stealExternalContext(). static PyMlirContextRef forContext(MlirContext context); + + /// Explicitly takes ownership of an MlirContext that must not already be + /// known to the Python bindings. Once done, the life-cycle of the context + /// will be controlled by the Python bindings, and it will be destroyed + /// when the reference count goes to zero. + static PyMlirContextRef stealExternalContext(MlirContext context); + ~PyMlirContext(); /// Accesses the underlying MlirContext. 
@@ -606,6 +617,12 @@ class PyOperation : public PyOperationBase, public BaseContextObject { forOperation(PyMlirContextRef contextRef, MlirOperation operation, pybind11::object parentKeepAlive = pybind11::object()); + /// Explicitly takes ownership of an operation that must not already be known + /// to the Python bindings. Once done, the life-cycle of the operation + /// will be controlled by the Python bindings. + static PyOperationRef stealExternalOperation(PyMlirContextRef contextRef, + MlirOperation operation); + /// Creates a detached operation. The operation must not be associated with /// any existing live operation. static PyOperationRef From b3904c6d3113423da5564ab4edc181e1fa6ae952 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Wed, 20 Dec 2023 17:29:11 -0600 Subject: [PATCH 657/915] [mlir][python] move transform extras (#76102) --- mlir/python/CMakeLists.txt | 2 +- .../mlir/dialects/transform/__init__.py | 1 + .../transform/extras}/__init__.py | 43 ++++++++++--------- 3 files changed, 24 insertions(+), 22 deletions(-) rename mlir/python/mlir/{extras/dialects/transform => dialects/transform/extras}/__init__.py (80%) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 41d91cf67..55c5973e4 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -172,7 +172,7 @@ declare_mlir_python_sources( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" GEN_ENUM_BINDINGS SOURCES - extras/dialects/transform/__init__.py) + dialects/transform/extras/__init__.py) declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index 7ae4fefba..175634c7d 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -6,6 +6,7 @@ from .._transform_ops_gen import * from .._transform_ops_gen import _Dialect from ..._mlir_libs._mlirDialectsTransform 
import * +from ..._mlir_libs._mlirDialectsTransform import AnyOpType, OperationType try: from ...ir import * diff --git a/mlir/python/mlir/extras/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py similarity index 80% rename from mlir/python/mlir/extras/dialects/transform/__init__.py rename to mlir/python/mlir/dialects/transform/extras/__init__.py index 9e3133243..c715dac1e 100644 --- a/mlir/python/mlir/extras/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/extras/__init__.py @@ -2,12 +2,11 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from __future__ import annotations -from typing import Callable, Optional, Sequence +from typing import Callable, Optional, Sequence, Union from .... import ir -from ....dialects import transform -from ....dialects.transform import structured +from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp +from .. import structured class Handle(ir.Value): @@ -25,16 +24,16 @@ def __init__( self, v: ir.Value, *, - parent: Optional[Handle] = None, - children: Optional[Sequence[Handle]] = None, + parent: Optional["Handle"] = None, + children: Optional[Sequence["Handle"]] = None, ): super().__init__(v) self.parent = parent self.children = children if children is not None else [] -@ir.register_value_caster(transform.AnyOpType.get_static_typeid()) -@ir.register_value_caster(transform.OperationType.get_static_typeid()) +@ir.register_value_caster(AnyOpType.get_static_typeid()) +@ir.register_value_caster(OperationType.get_static_typeid()) class OpHandle(Handle): """ Wrapper around a transform operation handle with methods to chain further @@ -52,11 +51,13 @@ def __init__( def match_ops( self, - ops: str - | ir.OpView - | structured.MatchInterfaceEnum - | Sequence[str | ir.OpView], - ) -> OpHandle: + ops: Union[ + str, + ir.OpView, + structured.MatchInterfaceEnum, + Sequence[Union[str, ir.OpView]], + ], + 
) -> "OpHandle": """ Emits a `transform.structured.MatchOp`. Returns a handle to payload ops that match the given names, types, or @@ -70,7 +71,7 @@ def match_ops( if isinstance(ops, str): ops = structured.MatchInterfaceEnum[ops] match_op = structured.MatchOp( - transform.AnyOpType.get(), + AnyOpType.get(), self, interface=ops, ) @@ -78,15 +79,15 @@ def match_ops( # Handle op name(s), either given directly as string or given as op. else: if isinstance(ops, str): - op_type = transform.OperationType.get(ops) + op_type = OperationType.get(ops) op_names = [ops] elif isinstance(ops, Sequence): - op_type = transform.AnyOpType.get() + op_type = AnyOpType.get() op_names = [ op if isinstance(op, str) else op.OPERATION_NAME for op in ops ] else: - op_type = transform.OperationType.get(ops.OPERATION_NAME) + op_type = OperationType.get(ops.OPERATION_NAME) op_names = [ops.OPERATION_NAME] match_op = structured.MatchOp.match_op_names( op_type, @@ -100,7 +101,7 @@ def match_ops( def insert_transform_script( - block_or_insertion_point: ir.Block | ir.InsertionPoint, + block_or_insertion_point: Union[ir.Block, ir.InsertionPoint], script: Callable[[OpHandle], None], dump_script: bool = False, ) -> None: @@ -137,12 +138,12 @@ def test_match_ops_single(module: OpHandle): with context, ir.Location.unknown(context): with insertion_point: - named_sequence_op = transform.NamedSequenceOp( - "__transform_main", [transform.AnyOpType.get()], [] + named_sequence_op = NamedSequenceOp( + "__transform_main", [AnyOpType.get()], [] ) with ir.InsertionPoint(named_sequence_op.body): script(named_sequence_op.bodyTarget) - transform.YieldOp([]) + YieldOp([]) if dump_script: print(named_sequence_op) From 822db6dc413eaae16e4fee1189599a7681c72cd4 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Thu, 21 Dec 2023 10:01:44 +0000 Subject: [PATCH 658/915] Revert "[mlir][python] Make the Context/Operation capsule creation methods work as documented. 
(#76010)" This reverts commit 5e1866d78313b0af87856b48b14d6a2ffa4577da. This change seems to be at odds with the non-owning part semantics of MlirOperation in C API. Since downstream clients can only take and return MlirOperation, it does not sound correct to force all returns of MlirOperation transfer ownership. Specifically, this makes it impossible for downstreams to implement IR-traversing functions that, e.g., look at neighbors of an operation. The following patch triggers the exception, and there does not seem to be an alternative way for a downstream binding writer to express this: ``` diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 39757dfad5be..2ce640674245 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3071,6 +3071,11 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("successors") = py::none(), py::arg("regions") = 0, py::arg("loc") = py::none(), py::arg("ip") = py::none(), py::arg("infer_type") = false, kOperationCreateDocstring) + .def("_get_first_in_block", [](PyOperation &self) -> MlirOperation { + MlirBlock block = mlirOperationGetBlock(self.get()); + MlirOperation first = mlirBlockGetFirstOperation(block); + return first; + }) .def_static( "parse", [](const std::string &sourceStr, const std::string &sourceName, diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index f59b1a26ba48..6b12b8da5c24 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -24,6 +24,25 @@ def expect_index_error(callback): except IndexError: pass +@run +def testCustomBind(): + ctx = Context() + ctx.allow_unregistered_dialects = True + module = Module.parse( + r""" + func.func @f1(%arg0: i32) -> i32 { + %1 = "custom.addi"(%arg0, %arg0) : (i32, i32) -> i32 + return %1 : i32 + } + """, + ctx, + ) + add = module.body.operations[0].regions[0].blocks[0].operations[0] + op = add.operation + # This will get a reference 
to itself. + f1 = op._get_first_in_block() + + # Verify iterator based traversal of the op/region/block hierarchy. # CHECK-LABEL: TEST: testTraverseOpRegionBlockIterators ``` --- mlir/lib/Bindings/Python/IRCore.cpp | 78 ++++------------------------- mlir/lib/Bindings/Python/IRModule.h | 19 +------ 2 files changed, 11 insertions(+), 86 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 39757dfad..5412c3dec 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -602,7 +602,7 @@ py::object PyMlirContext::createFromCapsule(py::object capsule) { MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); if (mlirContextIsNull(rawContext)) throw py::error_already_set(); - return stealExternalContext(rawContext).releaseObject(); + return forContext(rawContext).releaseObject(); } PyMlirContext *PyMlirContext::createNewContextForInit() { @@ -615,35 +615,18 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) { auto &liveContexts = getLiveContexts(); auto it = liveContexts.find(context.ptr); if (it == liveContexts.end()) { - throw std::runtime_error( - "Cannot use a context that is not owned by the Python bindings."); + // Create. + PyMlirContext *unownedContextWrapper = new PyMlirContext(context); + py::object pyRef = py::cast(unownedContextWrapper); + assert(pyRef && "cast to py::object failed"); + liveContexts[context.ptr] = unownedContextWrapper; + return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); } - // Use existing. 
py::object pyRef = py::cast(it->second); return PyMlirContextRef(it->second, std::move(pyRef)); } -PyMlirContextRef PyMlirContext::stealExternalContext(MlirContext context) { - py::gil_scoped_acquire acquire; - auto &liveContexts = getLiveContexts(); - auto it = liveContexts.find(context.ptr); - if (it != liveContexts.end()) { - throw std::runtime_error( - "Cannot transfer ownership of the context to Python " - "as it is already owned by Python."); - } - - PyMlirContext *unownedContextWrapper = new PyMlirContext(context); - // Note that the default return value policy on cast is automatic_reference, - // which does not take ownership (delete will not be called). - // Just be explicit. - py::object pyRef = - py::cast(unownedContextWrapper, py::return_value_policy::take_ownership); - assert(pyRef && "cast to py::object failed"); - return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); -} - PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { static LiveContextMap liveContexts; return liveContexts; @@ -1162,18 +1145,6 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, return PyOperationRef(existing, std::move(pyRef)); } -PyOperationRef PyOperation::stealExternalOperation(PyMlirContextRef contextRef, - MlirOperation operation) { - auto &liveOperations = contextRef->liveOperations; - auto it = liveOperations.find(operation.ptr); - if (it != liveOperations.end()) { - throw std::runtime_error( - "Cannot transfer ownership of the operation to Python " - "as it is already owned by Python."); - } - return createInstance(std::move(contextRef), operation, py::none()); -} - PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, MlirOperation operation, py::object parentKeepAlive) { @@ -1345,8 +1316,7 @@ py::object PyOperation::createFromCapsule(py::object capsule) { if (mlirOperationIsNull(rawOperation)) throw py::error_already_set(); MlirContext rawCtxt = mlirOperationGetContext(rawOperation); - return 
stealExternalOperation(PyMlirContext::forContext(rawCtxt), - rawOperation) + return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) .releaseObject(); } @@ -2578,16 +2548,6 @@ void mlir::python::populateIRCore(py::module &m) { .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) - .def_static("_testing_create_raw_context_capsule", - []() { - // Creates an MlirContext not known to the Python bindings - // and puts it in a capsule. Used to test interop. Using - // this without passing it back to the capsule creation - // API will leak. - return py::reinterpret_steal( - mlirPythonContextToCapsule( - mlirContextCreateWithThreading(false))); - }) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) @@ -3013,7 +2973,8 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("binary") = false, kOperationPrintStateDocstring) .def("print", py::overload_cast, bool, bool, bool, bool, - bool, py::object, bool>(&PyOperationBase::print), + bool, py::object, bool>( + &PyOperationBase::print), // Careful: Lots of arguments must match up with print method. py::arg("large_elements_limit") = py::none(), py::arg("enable_debug_info") = false, @@ -3085,25 +3046,6 @@ void mlir::python::populateIRCore(py::module &m) { .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) - .def_static( - "_testing_create_raw_capsule", - [](std::string sourceStr) { - // Creates a raw context and an operation via parsing the given - // source and returns them in a capsule. Error handling is - // minimal as this is purely intended for testing interop with - // operation creation from capsule functions. 
- MlirContext context = mlirContextCreateWithThreading(false); - MlirOperation op = mlirOperationCreateParse( - context, toMlirStringRef(sourceStr), toMlirStringRef("temp")); - if (mlirOperationIsNull(op)) { - mlirContextDestroy(context); - throw std::invalid_argument("Failed to parse"); - } - return py::make_tuple(py::reinterpret_steal( - mlirPythonContextToCapsule(context)), - py::reinterpret_steal( - mlirPythonOperationToCapsule(op))); - }) .def_property_readonly("operation", [](py::object self) { return self; }) .def_property_readonly("opview", &PyOperation::createOpView) .def_property_readonly( diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 04164b78b..79b7e0c96 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -176,19 +176,8 @@ class PyMlirContext { static PyMlirContext *createNewContextForInit(); /// Returns a context reference for the singleton PyMlirContext wrapper for - /// the given context. It is only valid to call this on an MlirContext that - /// is already owned by the Python bindings. Typically this will be because - /// it came in some fashion from createNewContextForInit(). However, it - /// is also possible to explicitly transfer ownership of an existing - /// MlirContext to the Python bindings via stealExternalContext(). + /// the given context. static PyMlirContextRef forContext(MlirContext context); - - /// Explicitly takes ownership of an MlirContext that must not already be - /// known to the Python bindings. Once done, the life-cycle of the context - /// will be controlled by the Python bindings, and it will be destroyed - /// when the reference count goes to zero. - static PyMlirContextRef stealExternalContext(MlirContext context); - ~PyMlirContext(); /// Accesses the underlying MlirContext. 
@@ -617,12 +606,6 @@ class PyOperation : public PyOperationBase, public BaseContextObject { forOperation(PyMlirContextRef contextRef, MlirOperation operation, pybind11::object parentKeepAlive = pybind11::object()); - /// Explicitly takes ownership of an operation that must not already be known - /// to the Python bindings. Once done, the life-cycle of the operation - /// will be controlled by the Python bindings. - static PyOperationRef stealExternalOperation(PyMlirContextRef contextRef, - MlirOperation operation); - /// Creates a detached operation. The operation must not be associated with /// any existing live operation. static PyOperationRef From a4b4883d908f8194628996853d2fa82d01229e1c Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Thu, 21 Dec 2023 11:20:29 -0600 Subject: [PATCH 659/915] [mlir][python] meta region_op (#75673) --- mlir/python/CMakeLists.txt | 9 +- mlir/python/mlir/dialects/arith.py | 8 ++ mlir/python/mlir/dialects/builtin.py | 23 +++++ mlir/python/mlir/dialects/func.py | 3 + mlir/python/mlir/dialects/pdl.py | 10 ++- mlir/python/mlir/dialects/scf.py | 2 +- mlir/python/mlir/dialects/tensor.py | 7 ++ .../mlir/dialects/transform/__init__.py | 13 ++- .../dialects/transform/extras/__init__.py | 15 +++- mlir/python/mlir/extras/meta.py | 83 +++++++++++++++++++ 10 files changed, 166 insertions(+), 7 deletions(-) create mode 100644 mlir/python/mlir/extras/meta.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 55c5973e4..3c9cf304d 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -21,7 +21,6 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python _mlir_libs/__init__.py ir.py passmanager.py - extras/types.py dialects/_ods_common.py # The main _mlir module has submodules: include stubs from each. 
@@ -30,6 +29,14 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python _mlir_libs/_mlir/passmanager.pyi ) +declare_mlir_python_sources(MLIRPythonSources.Core.Python.Extras + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + ADD_TO_PARENT MLIRPythonSources.Core.Python + SOURCES + extras/types.py + extras/meta.py +) + declare_mlir_python_sources(MLIRPythonSources.ExecutionEngine ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" ADD_TO_PARENT MLIRPythonSources diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py index 83aca0d58..663a53660 100644 --- a/mlir/python/mlir/dialects/arith.py +++ b/mlir/python/mlir/dialects/arith.py @@ -11,6 +11,8 @@ from ._ods_common import ( get_default_loc_context as _get_default_loc_context, _cext as _ods_cext, + get_op_result_or_op_results as _get_op_result_or_op_results, + SubClassValueT as _SubClassValueT, ) from typing import Any, List, Union @@ -75,3 +77,9 @@ def literal_value(self) -> Union[int, float]: return FloatAttr(self.value).value else: raise ValueError("only integer and float constants have literal values") + + +def constant( + result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None +) -> _SubClassValueT: + return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip)) diff --git a/mlir/python/mlir/dialects/builtin.py b/mlir/python/mlir/dialects/builtin.py index b71cc2466..1c69d6d7c 100644 --- a/mlir/python/mlir/dialects/builtin.py +++ b/mlir/python/mlir/dialects/builtin.py @@ -2,8 +2,11 @@ # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from typing import Dict, Optional + from ._builtin_ops_gen import * from ._builtin_ops_gen import _Dialect +from ..extras.meta import region_op try: from ..ir import * @@ -23,3 +26,23 @@ def __init__(self, *, loc=None, ip=None): @property def body(self): return self.regions[0].blocks[0] + + +@region_op +def module( + *, + sym_name=None, + sym_visibility=None, + attrs: Optional[Dict[str, Attribute]] = None, + loc=None, + ip=None, +): + mod = ModuleOp.__base__( + sym_name=sym_name, sym_visibility=sym_visibility, loc=loc, ip=ip + ) + if attrs is None: + attrs = {} + for attr_name, attr in attrs.items(): + mod.operation.attributes[attr_name] = attr + + return mod diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py index 6599f67b7..24fdcbcd8 100644 --- a/mlir/python/mlir/dialects/func.py +++ b/mlir/python/mlir/dialects/func.py @@ -243,6 +243,9 @@ def emit_call_op(*call_args): return decorator +func = FuncOp.from_py_func + + @_ods_cext.register_operation(_Dialect, replace=True) class CallOp(CallOp): """Specialization for the call op class.""" diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py index 90d7d7062..db07dc50a 100644 --- a/mlir/python/mlir/dialects/pdl.py +++ b/mlir/python/mlir/dialects/pdl.py @@ -5,6 +5,7 @@ from ._pdl_ops_gen import * from ._pdl_ops_gen import _Dialect from .._mlir_libs._mlirDialectsPDL import * +from .._mlir_libs._mlirDialectsPDL import OperationType try: @@ -13,7 +14,7 @@ except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e -from typing import Union, Optional, Sequence, Mapping +from typing import Union, Optional, Sequence, Mapping, NewType from ._ods_common import ( get_op_result_or_value as _get_value, get_op_results_or_values as _get_values, @@ -220,3 +221,10 @@ def __init__( constantTypes = [] result = pdl.RangeType.get(pdl.TypeType.get()) super().__init__(result, 
constantTypes=constantTypes, loc=loc, ip=ip) + + +OperationTypeT = NewType("OperationType", OperationType) + + +def op_t() -> OperationTypeT: + return OperationTypeT(OperationType.get()) diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py index 20bbed9bc..dad737798 100644 --- a/mlir/python/mlir/dialects/scf.py +++ b/mlir/python/mlir/dialects/scf.py @@ -120,7 +120,7 @@ def for_( params = [start, stop, step] for i, p in enumerate(params): if isinstance(p, int): - p = constant(IntegerAttr.get(IndexType.get(), p)) + p = constant(IndexType.get(), p) elif isinstance(p, float): raise ValueError(f"{p=} must be int.") params[i] = p diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py index 67248748e..79dd9476a 100644 --- a/mlir/python/mlir/dialects/tensor.py +++ b/mlir/python/mlir/dialects/tensor.py @@ -4,6 +4,7 @@ from ._tensor_ops_gen import * from ._tensor_ops_gen import _Dialect +from ..extras.meta import region_op try: from ..ir import * @@ -40,3 +41,9 @@ def __init__( dynamic_sizes.append(s) result_type = RankedTensorType.get(static_sizes, element_type) super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip) + + +generate = region_op( + lambda result, dynamic_extents: GenerateOp(result, dynamic_extents), + terminator=lambda args: YieldOp(args[0]), +) diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index 175634c7d..5b158ec6b 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -18,7 +18,7 @@ except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e -from typing import Optional, Sequence, Union +from typing import Optional, Sequence, Union, NewType @_ods_cext.register_operation(_Dialect, replace=True) @@ -175,7 +175,7 @@ def __init__( result_types: Sequence[Type], sym_visibility=None, arg_attrs=None, - res_attrs=None + res_attrs=None, ): 
function_type = FunctionType.get(input_types, result_types) super().__init__( @@ -183,7 +183,7 @@ def __init__( function_type=TypeAttr.get(function_type), sym_visibility=sym_visibility, arg_attrs=arg_attrs, - res_attrs=res_attrs + res_attrs=res_attrs, ) self.regions[0].blocks.append(*input_types) @@ -212,3 +212,10 @@ def __init__( if operands is None: operands = [] super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip) + + +AnyOpTypeT = NewType("AnyOpType", AnyOpType) + + +def any_op_t() -> AnyOpTypeT: + return AnyOpTypeT(AnyOpType.get()) diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py index c715dac1e..e4d47e906 100644 --- a/mlir/python/mlir/dialects/transform/extras/__init__.py +++ b/mlir/python/mlir/dialects/transform/extras/__init__.py @@ -4,8 +4,16 @@ from typing import Callable, Optional, Sequence, Union +from ....extras.meta import region_op from .... import ir -from .. import AnyOpType, OperationType, NamedSequenceOp, YieldOp +from .. import ( + AnyOpType, + OperationType, + NamedSequenceOp, + YieldOp, + SequenceOp, + ApplyPatternsOp, +) from .. import structured @@ -147,3 +155,8 @@ def test_match_ops_single(module: OpHandle): if dump_script: print(named_sequence_op) + + +sequence = region_op(SequenceOp.__base__, terminator=YieldOp) +named_sequence = region_op(NamedSequenceOp, terminator=YieldOp) +apply_patterns = region_op(ApplyPatternsOp) diff --git a/mlir/python/mlir/extras/meta.py b/mlir/python/mlir/extras/meta.py new file mode 100644 index 000000000..3f2defadf --- /dev/null +++ b/mlir/python/mlir/extras/meta.py @@ -0,0 +1,83 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import inspect +from functools import wraps + +from ..dialects._ods_common import get_op_result_or_op_results +from ..ir import Type, InsertionPoint + + +def op_region_builder(op, op_region, terminator=None): + def builder_wrapper(body_builder): + # Add a block with block args having types determined by type hints on the wrapped function. + if len(op_region.blocks) == 0: + sig = inspect.signature(body_builder) + types = [p.annotation for p in sig.parameters.values()] + if not ( + len(types) == len(sig.parameters) + and all(isinstance(t, Type) for t in types) + ): + raise ValueError( + f"for {body_builder=} either missing a type annotation or type annotation isn't a mlir type: {sig}" + ) + + op_region.blocks.append(*types) + + with InsertionPoint(op_region.blocks[0]): + results = body_builder(*list(op_region.blocks[0].arguments)) + + with InsertionPoint(list(op_region.blocks)[-1]): + if terminator is not None: + res = [] + if isinstance(results, (tuple, list)): + res.extend(results) + elif results is not None: + res.append(results) + terminator(res) + + return get_op_result_or_op_results(op) + + return builder_wrapper + + +def region_op(op_constructor, terminator=None): + """Decorator to define an MLIR Op specified as a python function. + + Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are + active for the current thread (i.e. established in a `with` block). + + Supports "naked" usage i.e., no parens if no args need to be passed to the Op constructor. + + When applied as a decorator to a Python function, an entry block will + be constructed for the Op with types as specified **as type hints on the args of the function**. + The block arguments will be passed positionally to the Python function. + + If a terminator is specified then the return from the decorated function will be passed + to the terminator as the last statement in the entry block. 
Note, the API for the terminator + is a (possibly empty) list; terminator accepting single values should be wrapped in a + `lambda args: term(args[0])` + + The identifier (name) of the function will become: + 1. A single value result if the Op returns a single value; + 2. An OpResultList (as a list) if the Op returns multiple values; + 3. The Operation if the Op returns no results. + + See examples in tensor.py and transform.extras. + """ + + def op_decorator(*args, **kwargs): + op = op_constructor(*args, **kwargs) + op_region = op.regions[0] + + return op_region_builder(op, op_region, terminator) + + @wraps(op_decorator) + def maybe_no_args(*args, **kwargs): + if len(args) == 1 and len(kwargs) == 0 and callable(args[0]): + return op_decorator()(args[0]) + else: + return op_decorator(*args, **kwargs) + + return maybe_no_args From 01e8c3167bda12f6489aab1f4e1c1f8a92990715 Mon Sep 17 00:00:00 2001 From: Jungwook Park Date: Tue, 2 Jan 2024 16:11:44 +0000 Subject: [PATCH 660/915] [mlir][spirv] Add support for C-API/python binding to SPIR-V dialect (#76055) Enable bindings. 
--------- Co-authored-by: jungpark-mlir --- mlir/include/mlir-c/Dialect/SPIRV.h | 26 ++++++++++++++++++++++++++ mlir/lib/CAPI/Dialect/CMakeLists.txt | 9 +++++++++ mlir/lib/CAPI/Dialect/SPIRV.cpp | 13 +++++++++++++ mlir/python/CMakeLists.txt | 7 +++++++ mlir/python/mlir/dialects/SPIRVOps.td | 14 ++++++++++++++ mlir/python/mlir/dialects/spirv.py | 5 +++++ 6 files changed, 74 insertions(+) create mode 100644 mlir/include/mlir-c/Dialect/SPIRV.h create mode 100644 mlir/lib/CAPI/Dialect/SPIRV.cpp create mode 100644 mlir/python/mlir/dialects/SPIRVOps.td create mode 100644 mlir/python/mlir/dialects/spirv.py diff --git a/mlir/include/mlir-c/Dialect/SPIRV.h b/mlir/include/mlir-c/Dialect/SPIRV.h new file mode 100644 index 000000000..f22708c9d --- /dev/null +++ b/mlir/include/mlir-c/Dialect/SPIRV.h @@ -0,0 +1,26 @@ +//===-- mlir-c/Dialect/SPIRV.h - C API for SPIRV dialect ----------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_SPIRV_H +#define MLIR_C_DIALECT_SPIRV_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SPIRV, spirv); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_SPIRV_H diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index d815eba48..b2952da17 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -171,6 +171,15 @@ add_mlir_upstream_c_api_library(MLIRCAPIFunc MLIRFuncDialect ) +add_mlir_upstream_c_api_library(MLIRCAPISPIRV + SPIRV.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRSPIRVDialect +) + add_mlir_upstream_c_api_library(MLIRCAPITensor Tensor.cpp diff --git a/mlir/lib/CAPI/Dialect/SPIRV.cpp b/mlir/lib/CAPI/Dialect/SPIRV.cpp new file mode 100644 index 000000000..9bfe26b95 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/SPIRV.cpp @@ -0,0 +1,13 @@ +//===- SPIRV.cpp - C Interface for SPIRV dialect --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/SPIRV.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SPIRV, spirv, mlir::spirv::SPIRVDialect) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 3c9cf304d..266b86090 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -378,6 +378,13 @@ declare_mlir_dialect_python_bindings( "../../include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td" ) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/SPIRVOps.td + SOURCES dialects/spirv.py + DIALECT_NAME spirv) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/SPIRVOps.td b/mlir/python/mlir/dialects/SPIRVOps.td new file mode 100644 index 000000000..eaae0e609 --- /dev/null +++ b/mlir/python/mlir/dialects/SPIRVOps.td @@ -0,0 +1,14 @@ +//===-- SPIRVOps.td - Entry point for SPIRVOps bind --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_SPIRV_OPS +#define PYTHON_BINDINGS_SPIRV_OPS + +include "mlir/Dialect/SPIRV/IR/SPIRVOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/spirv.py b/mlir/python/mlir/dialects/spirv.py new file mode 100644 index 000000000..269678a20 --- /dev/null +++ b/mlir/python/mlir/dialects/spirv.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._spirv_ops_gen import * From 21245f0d4b09b5d7b3cf41c9b27e3e7396df7b62 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Wed, 3 Jan 2024 16:33:27 +0100 Subject: [PATCH 661/915] [mlir] fix Operation::getDiscardableAttrs in absence of properties (#76816) When properties are not enabled in an operation, inherent attributes are stored in the common dictionary with discardable attributes. However, `getDiscardableAttrs` and `getDiscardableAttrDictionary` were returning the entire dictionary, making the caller mistakenly believe that all inherent attributes are discardable. Fix this by filtering out attributes whose names are registered with the operation, i.e., inherent attributes. This requires an API change so `getDiscardableAttrs` returns a filter range. --- mlir/lib/CAPI/IR/IR.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index ac9889df1..a97cfe5b0 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -613,12 +613,14 @@ void mlirOperationSetInherentAttributeByName(MlirOperation op, } intptr_t mlirOperationGetNumDiscardableAttributes(MlirOperation op) { - return static_cast(unwrap(op)->getDiscardableAttrs().size()); + return static_cast( + llvm::range_size(unwrap(op)->getDiscardableAttrs())); } MlirNamedAttribute mlirOperationGetDiscardableAttribute(MlirOperation op, intptr_t pos) { - NamedAttribute attr = unwrap(op)->getDiscardableAttrs()[pos]; + NamedAttribute attr = + *std::next(unwrap(op)->getDiscardableAttrs().begin(), pos); return MlirNamedAttribute{wrap(attr.getName()), wrap(attr.getValue())}; } From 9a627b15f3c40b237043401e08c7bac88c6e79e7 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev <185856+superbobry@users.noreply.github.com> Date: Thu, 4 Jan 2024 08:49:57 +0000 Subject: [PATCH 662/915] Slightly improved ir.pyi type 
annotations (#76728) * Replaced `Any` with static types where appropriate * Removed undocumented `__str__` and `__repr__` -- these are always defined via `object` --- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 84 +++--------------------- 1 file changed, 10 insertions(+), 74 deletions(-) diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index fa591e5f1..57a85990f 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -460,11 +460,9 @@ class AffineExpr: @overload def __mul__(self, arg0: int) -> AffineMulExpr: ... def __radd__(self, arg0: int) -> AffineAddExpr: ... - def __repr__(self) -> str: ... def __rmod__(self, arg0: int) -> AffineModExpr: ... def __rmul__(self, arg0: int) -> AffineMulExpr: ... def __rsub__(self, arg0: int) -> AffineAddExpr: ... - def __str__(self) -> str: ... @overload def __sub__(self, arg0: AffineExpr) -> AffineAddExpr: ... @overload @@ -495,7 +493,6 @@ class Attribute: """ Casts the passed attribute to the generic Attribute """ - def __repr__(self) -> str: ... def __str__(self) -> str: """ Returns the assembly form of the Attribute. @@ -541,7 +538,6 @@ class Type: """ Casts the passed type to the generic Type """ - def __repr__(self) -> str: ... def __str__(self) -> str: """ Returns the assembly form of the type. @@ -710,8 +706,6 @@ class AffineMap: @overload def __eq__(self, arg0: object) -> bool: ... def __hash__(self) -> int: ... - def __repr__(self) -> str: ... - def __str__(self) -> str: ... def dump(self) -> None: """ Dumps a debug representation of the object to stderr. @@ -756,7 +750,6 @@ class AffineMapAttr(Attribute): @staticmethod def isinstance(other: Attribute) -> bool: ... def __init__(self, cast_from_attr: Attribute) -> None: ... - def __repr__(self) -> str: ... @property def type(self) -> Type: ... @property @@ -801,7 +794,6 @@ class ArrayAttr(Attribute): self, ) -> ArrayAttributeIterator: ... def __len__(self) -> int: ... 
- def __repr__(self) -> str: ... @property def type(self) -> Type: ... @property @@ -840,7 +832,6 @@ class BF16Type(Type): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... @property def typeid(self) -> TypeID: ... @@ -951,7 +942,6 @@ class BoolAttr(Attribute): Converts the value of the bool attribute to a Python bool """ def __init__(self, cast_from_attr: Attribute) -> None: ... - def __repr__(self) -> str: ... @property def static_typeid(self) -> TypeID: ... @property @@ -974,7 +964,6 @@ class ComplexType(Type): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... @property def element_type(self) -> Type: """ @@ -989,7 +978,7 @@ class Context: @staticmethod def _get_live_count() -> int: ... def _CAPICreate(self) -> object: ... - def __enter__(self) -> Any: ... + def __enter__(self) -> Context: ... def __exit__(self, arg0: Any, arg1: Any, arg2: Any) -> None: ... def __init__(self) -> None: ... def _clear_live_operations(self) -> int: ... @@ -1040,7 +1029,6 @@ class DenseBoolArrayAttr(Attribute): self, ) -> DenseBoolArrayIterator: ... def __len__(self) -> int: ... - def __repr__(self) -> str: ... @property def static_typeid(self) -> TypeID: ... @property @@ -1111,7 +1099,6 @@ class DenseElementsAttr(Attribute): def isinstance(other: Attribute) -> bool: ... def __init__(self, cast_from_attr: Attribute) -> None: ... def __len__(self) -> int: ... - def __repr__(self) -> str: ... def get_splat_value(self) -> Attribute: ... @property def is_splat(self) -> bool: ... @@ -1139,7 +1126,6 @@ class DenseF32ArrayAttr(Attribute): self, ) -> DenseF32ArrayIterator: ... def __len__(self) -> int: ... - def __repr__(self) -> str: ... @property def static_typeid(self) -> TypeID: ... @property @@ -1168,7 +1154,6 @@ class DenseF64ArrayAttr(Attribute): self, ) -> DenseF64ArrayIterator: ... 
def __len__(self) -> int: ... - def __repr__(self) -> str: ... @property def static_typeid(self) -> TypeID: ... @property @@ -1185,7 +1170,6 @@ class DenseFPElementsAttr(DenseElementsAttr): def isinstance(other: Attribute) -> bool: ... def __getitem__(self, arg0: int) -> float: ... def __init__(self, cast_from_attr: Attribute) -> None: ... - def __repr__(self) -> str: ... @property def static_typeid(self) -> TypeID: ... @property @@ -1208,7 +1192,6 @@ class DenseI16ArrayAttr(Attribute): self, ) -> DenseI16ArrayIterator: ... def __len__(self) -> int: ... - def __repr__(self) -> str: ... @property def static_typeid(self) -> TypeID: ... @property @@ -1235,7 +1218,6 @@ class DenseI32ArrayAttr(Attribute): self, ) -> DenseI32ArrayIterator: ... def __len__(self) -> int: ... - def __repr__(self) -> str: ... @property def static_typeid(self) -> TypeID: ... @property @@ -1262,7 +1244,6 @@ class DenseI64ArrayAttr(Attribute): self, ) -> DenseI16ArrayIterator: ... def __len__(self) -> int: ... - def __repr__(self) -> str: ... @property def static_typeid(self) -> TypeID: ... @property @@ -1289,7 +1270,6 @@ class DenseI8ArrayAttr(Attribute): self, ) -> DenseI8ArrayIterator: ... def __len__(self) -> int: ... - def __repr__(self) -> str: ... @property def static_typeid(self) -> TypeID: ... @property @@ -1306,7 +1286,6 @@ class DenseIntElementsAttr(DenseElementsAttr): def isinstance(other: Attribute) -> bool: ... def __getitem__(self, arg0: int) -> int: ... def __init__(self, cast_from_attr: Attribute) -> None: ... - def __repr__(self) -> str: ... @property def static_typeid(self) -> TypeID: ... @property @@ -1352,7 +1331,6 @@ class DenseResourceElementsAttr(Attribute): @staticmethod def isinstance(other: Attribute) -> bool: ... def __init__(self, cast_from_attr: Attribute) -> None: ... - def __repr__(self) -> str: ... @property def static_typeid(self) -> TypeID: ... @property @@ -1361,7 +1339,6 @@ class DenseResourceElementsAttr(Attribute): def typeid(self) -> TypeID: ... 
class Diagnostic: - def __str__(self) -> str: ... @property def location(self) -> Location: ... @property @@ -1382,7 +1359,6 @@ class DiagnosticHandler: class DiagnosticInfo: def __init__(self, arg0: Diagnostic) -> None: ... - def __str__(self) -> str: ... @property def location(self) -> Location: ... @property @@ -1419,9 +1395,7 @@ class DiagnosticSeverity: def __init__(self, value: int) -> None: ... def __int__(self) -> int: ... def __ne__(self, other: Any) -> bool: ... - def __repr__(self) -> str: ... def __setstate__(self, state: int) -> None: ... - def __str__(self) -> str: ... @property def name(self) -> str: ... @property @@ -1429,12 +1403,10 @@ class DiagnosticSeverity: class Dialect: def __init__(self, descriptor: DialectDescriptor) -> None: ... - def __repr__(self) -> Any: ... @property def descriptor(self) -> DialectDescriptor: ... class DialectDescriptor: - def __repr__(self) -> str: ... @property def namespace(self) -> str: ... @@ -1445,8 +1417,8 @@ class DialectRegistry: def _CAPIPtr(self) -> object: ... class Dialects: - def __getattr__(self, arg0: str) -> Any: ... - def __getitem__(self, arg0: str) -> Any: ... + def __getattr__(self, arg0: str) -> Dialect: ... + def __getitem__(self, arg0: str) -> Dialect: ... class DictAttr(Attribute): static_typeid: ClassVar[TypeID] # value = @@ -1464,7 +1436,6 @@ class DictAttr(Attribute): def __getitem__(self, arg0: int) -> NamedAttribute: ... def __init__(self, cast_from_attr: Attribute) -> None: ... def __len__(self) -> int: ... - def __repr__(self) -> str: ... @property def type(self) -> Type: ... @property @@ -1480,7 +1451,6 @@ class F16Type(Type): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... @property def typeid(self) -> TypeID: ... @@ -1494,7 +1464,6 @@ class F32Type(Type): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... 
- def __repr__(self) -> str: ... @property def typeid(self) -> TypeID: ... @@ -1508,7 +1477,6 @@ class F64Type(Type): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... @property def typeid(self) -> TypeID: ... @@ -1521,7 +1489,6 @@ class FlatSymbolRefAttr(Attribute): @staticmethod def isinstance(other: Attribute) -> bool: ... def __init__(self, cast_from_attr: Attribute) -> None: ... - def __repr__(self) -> str: ... @property def static_typeid(self) -> TypeID: ... @property @@ -1544,7 +1511,6 @@ class Float8E4M3B11FNUZType(Type): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... @property def typeid(self) -> TypeID: ... @@ -1558,7 +1524,6 @@ class Float8E4M3FNType(Type): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... @property def typeid(self) -> TypeID: ... @@ -1572,7 +1537,6 @@ class Float8E4M3FNUZType(Type): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... @property def typeid(self) -> TypeID: ... @@ -1586,7 +1550,6 @@ class Float8E5M2FNUZType(Type): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... @property def typeid(self) -> TypeID: ... @@ -1600,7 +1563,6 @@ class Float8E5M2Type(Type): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... @property def typeid(self) -> TypeID: ... @@ -1628,7 +1590,6 @@ class FloatAttr(Attribute): Converts the value of the float attribute to a Python float """ def __init__(self, cast_from_attr: Attribute) -> None: ... - def __repr__(self) -> str: ... 
@property def type(self) -> Type: ... @property @@ -1649,7 +1610,6 @@ class FloatTF32Type(Type): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... @property def typeid(self) -> TypeID: ... @@ -1665,7 +1625,6 @@ class FunctionType(Type): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... @property def inputs(self) -> List: """ @@ -1689,7 +1648,6 @@ class IndexType(Type): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... @property def typeid(self) -> TypeID: ... @@ -1769,7 +1727,7 @@ class InsertionPoint: """ Inserts before the block terminator. """ - def __enter__(self) -> Any: ... + def __enter__(self) -> InsertionPoint: ... def __exit__(self, arg0: Any, arg1: Any, arg2: Any) -> None: ... @overload def __init__(self, block: Block) -> None: @@ -1810,7 +1768,6 @@ class IntegerAttr(Attribute): """ Converts the value of the integer attribute to a Python int """ - def __repr__(self) -> str: ... @property def type(self) -> Type: ... @property @@ -1840,8 +1797,6 @@ class IntegerSet: @overload def __eq__(self, arg0: object) -> bool: ... def __hash__(self) -> int: ... - def __repr__(self) -> str: ... - def __str__(self) -> str: ... def dump(self) -> None: """ Dumps a debug representation of the object to stderr. @@ -1908,7 +1863,6 @@ class IntegerType(Type): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... @property def is_signed(self) -> bool: """ @@ -1984,7 +1938,6 @@ class Location: @overload def __eq__(self, arg0: Location) -> bool: ... def __exit__(self, arg0: object, arg1: object, arg2: object) -> None: ... - def __repr__(self) -> str: ... 
def emit_error(self, message: str) -> None: """ Emits an error at this location @@ -2018,7 +1971,6 @@ class MemRefType(ShapedType): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... @property def affine_map(self) -> AffineMap: """ @@ -2039,12 +1991,12 @@ class MemRefType(ShapedType): class Module: @staticmethod - def create(loc: Optional[Location] = None) -> Any: + def create(loc: Optional[Location] = None) -> Module: """ Creates an empty module """ @staticmethod - def parse(asm: str, context: Optional[Context] = None) -> Any: + def parse(asm: str, context: Optional[Context] = None) -> Module: """ Parses a module's assembly format from a string. @@ -2053,7 +2005,7 @@ class Module: See also: https://mlir.llvm.org/docs/LangRef/ """ def _CAPICreate(self) -> Any: ... - def __str__(self) -> Any: + def __str__(self) -> str: """ Gets the assembly form of the operation with default options. @@ -2078,7 +2030,7 @@ class Module: Context that created the Module """ @property - def operation(self) -> Any: + def operation(self) -> Operation: """ Accesses the module as an operation """ @@ -2089,7 +2041,6 @@ class MLIRError(Exception): ) -> None: ... class NamedAttribute: - def __repr__(self) -> str: ... @property def attr(self) -> Attribute: """ @@ -2111,7 +2062,6 @@ class NoneType(Type): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... @property def typeid(self) -> TypeID: ... @@ -2202,7 +2152,6 @@ class OpView(_OperationBase): Parses a specific, generated OpView based on class level attributes """ def __init__(self, operation: _OperationBase) -> None: ... - def __str__(self) -> str: ... @property def operation(self) -> _OperationBase: ... @property @@ -2228,7 +2177,6 @@ class OpaqueAttr(Attribute): @staticmethod def isinstance(other: Attribute) -> bool: ... 
def __init__(self, cast_from_attr: Attribute) -> None: ... - def __repr__(self) -> str: ... @property def data(self) -> bytes: """ @@ -2256,7 +2204,6 @@ class OpaqueType(Type): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... @property def data(self) -> str: """ @@ -2305,7 +2252,7 @@ class Operation(_OperationBase): @staticmethod def parse( source: str, *, source_name: str = "", context: Optional[Context] = None - ) -> Any: + ) -> Operation: """ Parses an operation. Supports both text assembly format and binary bytecode format. """ @@ -2327,7 +2274,7 @@ class OperationIterator: def __next__(self) -> OpView: ... class OperationList: - def __getitem__(self, arg0: int) -> Any: ... + def __getitem__(self, arg0: int) -> OpView: ... def __iter__(self) -> OperationIterator: ... def __len__(self) -> int: ... @@ -2346,7 +2293,6 @@ class RankedTensorType(ShapedType): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... @property def encoding(self) -> Optional[Attribute]: ... @property @@ -2401,7 +2347,6 @@ class ShapedType(Type): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... def get_dim_size(self, dim: int) -> int: """ Returns the dim-th dimension of the given ranked shaped type. @@ -2505,7 +2450,6 @@ class StridedLayoutAttr(Attribute): @staticmethod def isinstance(other: Attribute) -> bool: ... def __init__(self, cast_from_attr: Attribute) -> None: ... - def __repr__(self) -> str: ... @property def offset(self) -> int: """ @@ -2536,7 +2480,6 @@ class StringAttr(Attribute): @staticmethod def isinstance(other: Attribute) -> bool: ... def __init__(self, cast_from_attr: Attribute) -> None: ... - def __repr__(self) -> str: ... @property def type(self) -> Type: ... 
@property @@ -2561,7 +2504,6 @@ class SymbolRefAttr(Attribute): @staticmethod def isinstance(other: Attribute) -> bool: ... def __init__(self, cast_from_attr: Attribute) -> None: ... - def __repr__(self) -> str: ... @property def static_typeid(self) -> TypeID: ... @property @@ -2610,7 +2552,6 @@ class TupleType(Type): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... def get_type(self, pos: int) -> Type: """ Returns the pos-th type in the Tuple type. @@ -2633,7 +2574,6 @@ class TypeAttr(Attribute): @staticmethod def isinstance(other: Attribute) -> bool: ... def __init__(self, cast_from_attr: Attribute) -> None: ... - def __repr__(self) -> str: ... @property def type(self) -> Type: ... @property @@ -2661,7 +2601,6 @@ class UnitAttr(Attribute): @staticmethod def isinstance(other: Attribute) -> bool: ... def __init__(self, cast_from_attr: Attribute) -> None: ... - def __repr__(self) -> str: ... @property def type(self) -> Type: ... @property @@ -2679,7 +2618,6 @@ class UnrankedMemRefType(ShapedType): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... @property def memory_space(self) -> Optional[Attribute]: """ @@ -2698,7 +2636,6 @@ class UnrankedTensorType(ShapedType): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... @property def typeid(self) -> TypeID: ... @@ -2719,7 +2656,6 @@ class VectorType(ShapedType): @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... - def __repr__(self) -> str: ... @property def scalable(self) -> bool: ... 
@property From 7e55a8e3301f59734d85772cf60ac692ce322101 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Sat, 6 Jan 2024 16:42:14 -0600 Subject: [PATCH 663/915] [mlir][python] add MemRefTypeAttr attr builder (#76371) --- mlir/python/mlir/ir.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 6d21da3b4..eb7f035fe 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -263,6 +263,11 @@ def _typeArrayAttr(x, context): return _arrayAttr([TypeAttr.get(t, context=context) for t in x], context) +@register_attribute_builder("MemRefTypeAttr") +def _memref_type_attr(x, context): + return _typeAttr(x, context) + + try: import numpy as np From a3c7a42518fdb8e6bd7b150adada817fbdf215f0 Mon Sep 17 00:00:00 2001 From: martin-luecke Date: Mon, 15 Jan 2024 10:31:22 +0100 Subject: [PATCH 664/915] [MLIR][transform][python] Introduce abstractions for handles to values and parameters (#77305) In addition to the existing `OpHandle` which provides an abstraction to emit transform ops targeting operations this introduces a similar concept for _values_ and _parameters_ in form of `ValueHandle` and `ParamHandle`. New core transform abstractions: - `constant_param` - `OpHandle.get_result` - `OpHandle.print` - `ValueHandle.get_defining_op` --- .../dialects/transform/extras/__init__.py | 85 +++++++++++++++++++ 1 file changed, 85 insertions(+) diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py index e4d47e906..ba51c400f 100644 --- a/mlir/python/mlir/dialects/transform/extras/__init__.py +++ b/mlir/python/mlir/dialects/transform/extras/__init__.py @@ -6,9 +6,13 @@ from ....extras.meta import region_op from .... import ir +from ... import transform from .. 
import ( AnyOpType, + AnyParamType, + AnyValueType, OperationType, + ParamType, NamedSequenceOp, YieldOp, SequenceOp, @@ -57,6 +61,19 @@ def __init__( ): super().__init__(v, parent=parent, children=children) + def get_result(self, idx: int = 0) -> "ValueHandle": + """ + Emits a `transform.GetResultOp`. + Returns a handle to the result of the payload operation at the given + index. + """ + get_result_op = transform.GetResultOp( + AnyValueType.get(), + self, + idx, + ) + return get_result_op.result + def match_ops( self, ops: Union[ @@ -107,6 +124,74 @@ def match_ops( self.children.append(handle) return handle + def print(self, name: Optional[str] = None) -> "OpHandle": + """ + Emits a `transform.PrintOp` to print this handle and an optional message. + Returns the existing handle to facilitate further chaining. + """ + transform.PrintOp(target=self, name=name) + return self + + +@ir.register_value_caster(AnyParamType.get_static_typeid()) +@ir.register_value_caster(ParamType.get_static_typeid()) +class ParamHandle(Handle): + """Wrapper around a transform param handle.""" + + def __init__( + self, + v: ir.Value, + *, + parent: Optional[Handle] = None, + children: Optional[Sequence[Handle]] = None, + ): + super().__init__(v, parent=parent, children=children) + + +@ir.register_value_caster(AnyValueType.get_static_typeid()) +class ValueHandle(Handle): + """ + Wrapper around a transform value handle with methods to chain further + transforms. + """ + + def __init__( + self, + v: ir.Value, + *, + parent: Optional[Handle] = None, + children: Optional[Sequence[Handle]] = None, + ): + super().__init__(v, parent=parent, children=children) + + def get_defining_op(self) -> OpHandle: + """ + Emits a `transform.GetDefiningOpOp`. + Returns a handle to the defining op of the wrapped value. 
+ """ + get_defining_op = transform.GetDefiningOp( + AnyOpType.get(), + self, + ) + return get_defining_op.result + + +def constant_param(value: Union[ir.Attribute, int]) -> ParamHandle: + """ + Emits a `transform.ParamConstantOp`. + Returns a handle to the newly created parameter. The type of the parameter + is `transfrom.any_param` if the value is not an integer, otherwise the type + is `transform.param` parametrized with the according integer type. + """ + if isinstance(value, int): + value = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), value) + if isinstance(value.type, ir.IntegerType): + param_type = ParamType.get(value.type) + else: + param_type = AnyParamType.get() + op = transform.ParamConstantOp(param_type, value) + return op.param + def insert_transform_script( block_or_insertion_point: Union[ir.Block, ir.InsertionPoint], From e6e78dffb2b0810fda4751add4bc66dc1c3dd3d5 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Thu, 18 Jan 2024 06:33:14 -0800 Subject: [PATCH 665/915] [mlir][transform] Add transform.get_operand op (#78397) Similar to `transform.get_result`, except it returns a handle to the operand indicated by a positional specification, same as is defined for the linalg match ops. Additionally updates `get_result` to take the same positional specification. This makes the use case of wanting to get all of the results of an operation easier by no longer requiring the user to reconstruct the list of results one-by-one. 
--- mlir/python/mlir/dialects/transform/extras/__init__.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mlir/python/mlir/dialects/transform/extras/__init__.py b/mlir/python/mlir/dialects/transform/extras/__init__.py index ba51c400f..8d045cad7 100644 --- a/mlir/python/mlir/dialects/transform/extras/__init__.py +++ b/mlir/python/mlir/dialects/transform/extras/__init__.py @@ -43,7 +43,6 @@ def __init__( self.parent = parent self.children = children if children is not None else [] - @ir.register_value_caster(AnyOpType.get_static_typeid()) @ir.register_value_caster(OperationType.get_static_typeid()) class OpHandle(Handle): @@ -61,16 +60,16 @@ def __init__( ): super().__init__(v, parent=parent, children=children) - def get_result(self, idx: int = 0) -> "ValueHandle": + def get_result(self, indices: Sequence[int] = [0]) -> "ValueHandle": """ Emits a `transform.GetResultOp`. Returns a handle to the result of the payload operation at the given - index. + indices. """ get_result_op = transform.GetResultOp( AnyValueType.get(), self, - idx, + indices, ) return get_result_op.result From 098943d251a4ebcf5a5de3058182d550417dbbba Mon Sep 17 00:00:00 2001 From: Rageking8 <106309953+Rageking8@users.noreply.github.com> Date: Sun, 28 Jan 2024 14:20:08 +0800 Subject: [PATCH 666/915] Fix unsigned typos (#76670) --- mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index d9698e8ab..23d6d26b7 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -318,7 +318,7 @@ class BinaryFn: Examples: - max -> `arith.MaxSIOp` - - max_unsinged -> `arith.MaxUIOp` + - max_unsigned -> `arith.MaxUIOp` """ add = BinaryFnType("add") From e8b2ab0790c7dabc23a8a0ea3c1cba48913168da Mon Sep 17 
00:00:00 2001 From: Maksim Levental Date: Tue, 30 Jan 2024 16:21:56 -0600 Subject: [PATCH 667/915] [mlir][python] enable memref.subview (#79393) --- mlir/include/mlir-c/BuiltinTypes.h | 6 + mlir/lib/Bindings/Python/IRTypes.cpp | 14 ++ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 16 ++ mlir/python/mlir/dialects/_ods_common.py | 174 +++++++++++++++++- mlir/python/mlir/dialects/memref.py | 130 +++++++++++++ .../mlir/dialects/transform/structured.py | 169 ++--------------- 6 files changed, 352 insertions(+), 157 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index 1fd5691f4..881b6dad2 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -408,6 +408,12 @@ MLIR_CAPI_EXPORTED MlirAffineMap mlirMemRefTypeGetAffineMap(MlirType type); /// Returns the memory space of the given MemRef type. MLIR_CAPI_EXPORTED MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type); +/// Returns the strides of the MemRef if the layout map is in strided form. +/// Both strides and offset are out params. strides must point to pre-allocated +/// memory of length equal to the rank of the memref. +MLIR_CAPI_EXPORTED MlirLogicalResult mlirMemRefTypeGetStridesAndOffset( + MlirType type, int64_t *strides, int64_t *offset); + /// Returns the memory spcae of the given Unranked MemRef type. 
MLIR_CAPI_EXPORTED MlirAttribute mlirUnrankedMemrefGetMemorySpace(MlirType type); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 56e895d30..820992de6 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -12,6 +12,8 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" +#include "mlir-c/Support.h" + #include namespace py = pybind11; @@ -618,6 +620,18 @@ class PyMemRefType : public PyConcreteType { return mlirMemRefTypeGetLayout(self); }, "The layout of the MemRef type.") + .def( + "get_strides_and_offset", + [](PyMemRefType &self) -> std::pair, int64_t> { + std::vector strides(mlirShapedTypeGetRank(self)); + int64_t offset; + if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset( + self, strides.data(), &offset))) + throw std::runtime_error( + "Failed to extract strides and offset from memref."); + return {strides, offset}; + }, + "The strides and offset of the MemRef type.") .def_property_readonly( "affine_map", [](PyMemRefType &self) -> PyAffineMap { diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 6e645188d..18c9414c5 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -9,12 +9,16 @@ #include "mlir-c/BuiltinTypes.h" #include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" +#include "mlir-c/Support.h" #include "mlir/CAPI/AffineMap.h" #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Types.h" +#include "mlir/Support/LogicalResult.h" + +#include using namespace mlir; @@ -426,6 +430,18 @@ MlirAttribute mlirMemRefTypeGetMemorySpace(MlirType type) { return wrap(llvm::cast(unwrap(type)).getMemorySpace()); } +MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type, + int64_t *strides, + int64_t *offset) { + MemRefType memrefType = llvm::cast(unwrap(type)); + SmallVector strides_; 
+ if (failed(getStridesAndOffset(memrefType, strides_, *offset))) + return mlirLogicalResultFailure(); + + (void)std::copy(strides_.begin(), strides_.end(), strides); + return mlirLogicalResultSuccess(); +} + MlirTypeID mlirUnrankedMemRefTypeGetTypeID() { return wrap(UnrankedMemRefType::getTypeID()); } diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 1685124fb..3af3b5ce7 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -2,16 +2,30 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -# Provide a convenient name for sub-packages to resolve the main C-extension -# with a relative import. -from .._mlir_libs import _mlir as _cext from typing import ( + List as _List, + Optional as _Optional, Sequence as _Sequence, + Tuple as _Tuple, Type as _Type, TypeVar as _TypeVar, Union as _Union, ) +from .._mlir_libs import _mlir as _cext +from ..ir import ( + ArrayAttr, + Attribute, + BoolAttr, + DenseI64ArrayAttr, + IntegerAttr, + IntegerType, + OpView, + Operation, + ShapedType, + Value, +) + __all__ = [ "equally_sized_accessor", "get_default_loc_context", @@ -138,3 +152,157 @@ def get_op_result_or_op_results( ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value ResultValueT = _Union[ResultValueTypeTuple] VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]] + +StaticIntLike = _Union[int, IntegerAttr] +ValueLike = _Union[Operation, OpView, Value] +MixedInt = _Union[StaticIntLike, ValueLike] + +IntOrAttrList = _Sequence[_Union[IntegerAttr, int]] +OptionalIntList = _Optional[_Union[ArrayAttr, IntOrAttrList]] + +BoolOrAttrList = _Sequence[_Union[BoolAttr, bool]] +OptionalBoolList = _Optional[_Union[ArrayAttr, BoolOrAttrList]] + +MixedValues = _Union[_Sequence[_Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike] + +DynamicIndexList = _Sequence[_Union[MixedInt, 
_Sequence[MixedInt]]] + + +def _dispatch_dynamic_index_list( + indices: _Union[DynamicIndexList, ArrayAttr], +) -> _Tuple[_List[ValueLike], _Union[_List[int], ArrayAttr], _List[bool]]: + """Dispatches a list of indices to the appropriate form. + + This is similar to the custom `DynamicIndexList` directive upstream: + provided indices may be in the form of dynamic SSA values or static values, + and they may be scalable (i.e., as a singleton list) or not. This function + dispatches each index into its respective form. It also extracts the SSA + values and static indices from various similar structures, respectively. + """ + dynamic_indices = [] + static_indices = [ShapedType.get_dynamic_size()] * len(indices) + scalable_indices = [False] * len(indices) + + # ArrayAttr: Extract index values. + if isinstance(indices, ArrayAttr): + indices = [idx for idx in indices] + + def process_nonscalable_index(i, index): + """Processes any form of non-scalable index. + + Returns False if the given index was scalable and thus remains + unprocessed; True otherwise. + """ + if isinstance(index, int): + static_indices[i] = index + elif isinstance(index, IntegerAttr): + static_indices[i] = index.value # pytype: disable=attribute-error + elif isinstance(index, (Operation, Value, OpView)): + dynamic_indices.append(index) + else: + return False + return True + + # Process each index at a time. + for i, index in enumerate(indices): + if not process_nonscalable_index(i, index): + # If it wasn't processed, it must be a scalable index, which is + # provided as a _Sequence of one value, so extract and process that. 
+ scalable_indices[i] = True + assert len(index) == 1 + ret = process_nonscalable_index(i, index[0]) + assert ret + + return dynamic_indices, static_indices, scalable_indices + + +# Dispatches `MixedValues` that all represents integers in various forms into +# the following three categories: +# - `dynamic_values`: a list of `Value`s, potentially from op results; +# - `packed_values`: a value handle, potentially from an op result, associated +# to one or more payload operations of integer type; +# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python +# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`. +# The input is in the form for `packed_values`, only that result is set and the +# other two are empty. Otherwise, the input can be a mix of the other two forms, +# and for each dynamic value, a special value is added to the `static_values`. +def _dispatch_mixed_values( + values: MixedValues, +) -> _Tuple[_List[Value], _Union[Operation, Value, OpView], DenseI64ArrayAttr]: + dynamic_values = [] + packed_values = None + static_values = None + if isinstance(values, ArrayAttr): + static_values = values + elif isinstance(values, (Operation, Value, OpView)): + packed_values = values + else: + static_values = [] + for size in values or []: + if isinstance(size, int): + static_values.append(size) + else: + static_values.append(ShapedType.get_dynamic_size()) + dynamic_values.append(size) + static_values = DenseI64ArrayAttr.get(static_values) + + return (dynamic_values, packed_values, static_values) + + +def _get_value_or_attribute_value( + value_or_attr: _Union[any, Attribute, ArrayAttr] +) -> any: + if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"): + return value_or_attr.value + if isinstance(value_or_attr, ArrayAttr): + return _get_value_list(value_or_attr) + return value_or_attr + + +def _get_value_list( + sequence_or_array_attr: _Union[_Sequence[any], ArrayAttr] +) -> _Sequence[any]: + return 
[_get_value_or_attribute_value(v) for v in sequence_or_array_attr] + + +def _get_int_array_attr( + values: _Optional[_Union[ArrayAttr, IntOrAttrList]] +) -> ArrayAttr: + if values is None: + return None + + # Turn into a Python list of Python ints. + values = _get_value_list(values) + + # Make an ArrayAttr of IntegerAttrs out of it. + return ArrayAttr.get( + [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values] + ) + + +def _get_int_array_array_attr( + values: _Optional[_Union[ArrayAttr, _Sequence[_Union[ArrayAttr, IntOrAttrList]]]] +) -> ArrayAttr: + """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs. + + The input has to be a collection of a collection of integers, where any + Python _Sequence and ArrayAttr are admissible collections and Python ints and + any IntegerAttr are admissible integers. Both levels of collections are + turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s. + If the input is None, an empty ArrayAttr is returned. + """ + if values is None: + return None + + # Make sure the outer level is a list. + values = _get_value_list(values) + + # The inner level is now either invalid or a mixed sequence of ArrayAttrs and + # Sequences. Make sure the nested values are all lists. + values = [_get_value_list(nested) for nested in values] + + # Turn each nested list into an ArrayAttr. + values = [_get_int_array_attr(nested) for nested in values] + + # Turn the outer list into an ArrayAttr. + return ArrayAttr.get(values) diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py index 3afb6a70c..a3d783415 100644 --- a/mlir/python/mlir/dialects/memref.py +++ b/mlir/python/mlir/dialects/memref.py @@ -1,5 +1,135 @@ # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +import operator +from itertools import accumulate +from typing import Optional from ._memref_ops_gen import * +from ._ods_common import _dispatch_mixed_values, MixedValues +from .arith import ConstantOp, _is_integer_like_type +from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType + + +def _is_constant_int_like(i): + return ( + isinstance(i, Value) + and isinstance(i.owner.opview, ConstantOp) + and _is_integer_like_type(i.type) + ) + + +def _is_static_int_like(i): + return ( + isinstance(i, int) and not ShapedType.is_dynamic_size(i) + ) or _is_constant_int_like(i) + + +def _infer_memref_subview_result_type( + source_memref_type, offsets, static_sizes, static_strides +): + source_strides, source_offset = source_memref_type.get_strides_and_offset() + # "canonicalize" from tuple|list -> list + offsets, static_sizes, static_strides, source_strides = map( + list, (offsets, static_sizes, static_strides, source_strides) + ) + + if not all( + all(_is_static_int_like(i) for i in s) + for s in [ + static_sizes, + static_strides, + source_strides, + ] + ): + raise ValueError( + "Only inferring from python or mlir integer constant is supported." + ) + + for s in [offsets, static_sizes, static_strides]: + for idx, i in enumerate(s): + if _is_constant_int_like(i): + s[idx] = i.owner.opview.literal_value + + if any(not _is_static_int_like(i) for i in offsets + [source_offset]): + target_offset = ShapedType.get_dynamic_size() + else: + target_offset = source_offset + for offset, target_stride in zip(offsets, source_strides): + target_offset += offset * target_stride + + target_strides = [] + for source_stride, static_stride in zip(source_strides, static_strides): + target_strides.append(source_stride * static_stride) + + # If default striding then no need to complicate things for downstream ops (e.g., expand_shape). 
+ default_strides = list(accumulate(static_sizes[1:][::-1], operator.mul))[::-1] + [1] + if target_strides == default_strides and target_offset == 0: + layout = None + else: + layout = StridedLayoutAttr.get(target_offset, target_strides) + return ( + offsets, + static_sizes, + static_strides, + MemRefType.get( + static_sizes, + source_memref_type.element_type, + layout, + source_memref_type.memory_space, + ), + ) + + +_generated_subview = subview + + +def subview( + source: Value, + offsets: MixedValues, + sizes: MixedValues, + strides: MixedValues, + *, + result_type: Optional[MemRefType] = None, + loc=None, + ip=None, +): + if offsets is None: + offsets = [] + if sizes is None: + sizes = [] + if strides is None: + strides = [] + source_strides, source_offset = source.type.get_strides_and_offset() + if result_type is None and all( + all(_is_static_int_like(i) for i in s) for s in [sizes, strides, source_strides] + ): + # If any are arith.constant results then this will canonicalize to python int + # (which can then be used to fully specify the subview). + ( + offsets, + sizes, + strides, + result_type, + ) = _infer_memref_subview_result_type(source.type, offsets, sizes, strides) + elif result_type is None: + raise ValueError( + "mixed static/dynamic offset/sizes/strides requires explicit result type." 
+ ) + + offsets, _packed_offsets, static_offsets = _dispatch_mixed_values(offsets) + sizes, _packed_sizes, static_sizes = _dispatch_mixed_values(sizes) + strides, _packed_strides, static_strides = _dispatch_mixed_values(strides) + + return _generated_subview( + result_type, + source, + offsets, + sizes, + strides, + static_offsets, + static_sizes, + static_strides, + loc=loc, + ip=ip, + ) diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py index 284c93823..d7b41c0bd 100644 --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -9,163 +9,24 @@ try: from ...ir import * from ...dialects import transform - from .._ods_common import _cext as _ods_cext + from .._ods_common import ( + DynamicIndexList, + IntOrAttrList, + MixedValues, + OptionalBoolList, + OptionalIntList, + _cext as _ods_cext, + _dispatch_dynamic_index_list, + _dispatch_mixed_values, + _get_int_array_array_attr, + _get_int_array_attr, + _get_value_list, + _get_value_or_attribute_value, + ) except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e -from typing import List, Optional, Sequence, Tuple, Union, overload - -StaticIntLike = Union[int, IntegerAttr] -ValueLike = Union[Operation, OpView, Value] -MixedInt = Union[StaticIntLike, ValueLike] - -IntOrAttrList = Sequence[Union[IntegerAttr, int]] -OptionalIntList = Optional[Union[ArrayAttr, IntOrAttrList]] - -BoolOrAttrList = Sequence[Union[BoolAttr, bool]] -OptionalBoolList = Optional[Union[ArrayAttr, BoolOrAttrList]] - -MixedValues = Union[Sequence[Union[StaticIntLike, ValueLike]], ArrayAttr, ValueLike] - -DynamicIndexList = Sequence[Union[MixedInt, Sequence[MixedInt]]] - - -def _dispatch_dynamic_index_list( - indices: Union[DynamicIndexList, ArrayAttr], -) -> Tuple[List[ValueLike], Union[List[int], ArrayAttr], List[bool]]: - """Dispatches a list of indices to the appropriate form. 
- - This is similar to the custom `DynamicIndexList` directive upstream: - provided indices may be in the form of dynamic SSA values or static values, - and they may be scalable (i.e., as a singleton list) or not. This function - dispatches each index into its respective form. It also extracts the SSA - values and static indices from various similar structures, respectively. - """ - dynamic_indices = [] - static_indices = [ShapedType.get_dynamic_size()] * len(indices) - scalable_indices = [False] * len(indices) - - # ArrayAttr: Extract index values. - if isinstance(indices, ArrayAttr): - indices = [idx for idx in indices] - - def process_nonscalable_index(i, index): - """Processes any form of non-scalable index. - - Returns False if the given index was scalable and thus remains - unprocessed; True otherwise. - """ - if isinstance(index, int): - static_indices[i] = index - elif isinstance(index, IntegerAttr): - static_indices[i] = index.value # pytype: disable=attribute-error - elif isinstance(index, (Operation, Value, OpView)): - dynamic_indices.append(index) - else: - return False - return True - - # Process each index at a time. - for i, index in enumerate(indices): - if not process_nonscalable_index(i, index): - # If it wasn't processed, it must be a scalable index, which is - # provided as a Sequence of one value, so extract and process that. 
- scalable_indices[i] = True - assert len(index) == 1 - ret = process_nonscalable_index(i, index[0]) - assert ret - - return dynamic_indices, static_indices, scalable_indices - - -# Dispatches `MixedValues` that all represents integers in various forms into -# the following three categories: -# - `dynamic_values`: a list of `Value`s, potentially from op results; -# - `packed_values`: a value handle, potentially from an op result, associated -# to one or more payload operations of integer type; -# - `static_values`: an `ArrayAttr` of `i64`s with static values, from Python -# `int`s, from `IntegerAttr`s, or from an `ArrayAttr`. -# The input is in the form for `packed_values`, only that result is set and the -# other two are empty. Otherwise, the input can be a mix of the other two forms, -# and for each dynamic value, a special value is added to the `static_values`. -def _dispatch_mixed_values( - values: MixedValues, -) -> Tuple[List[Value], Union[Operation, Value, OpView], DenseI64ArrayAttr]: - dynamic_values = [] - packed_values = None - static_values = None - if isinstance(values, ArrayAttr): - static_values = values - elif isinstance(values, (Operation, Value, OpView)): - packed_values = values - else: - static_values = [] - for size in values or []: - if isinstance(size, int): - static_values.append(size) - else: - static_values.append(ShapedType.get_dynamic_size()) - dynamic_values.append(size) - static_values = DenseI64ArrayAttr.get(static_values) - - return (dynamic_values, packed_values, static_values) - - -def _get_value_or_attribute_value( - value_or_attr: Union[any, Attribute, ArrayAttr] -) -> any: - if isinstance(value_or_attr, Attribute) and hasattr(value_or_attr, "value"): - return value_or_attr.value - if isinstance(value_or_attr, ArrayAttr): - return _get_value_list(value_or_attr) - return value_or_attr - - -def _get_value_list( - sequence_or_array_attr: Union[Sequence[any], ArrayAttr] -) -> Sequence[any]: - return [_get_value_or_attribute_value(v) 
for v in sequence_or_array_attr] - - -def _get_int_array_attr(values: Optional[Union[ArrayAttr, IntOrAttrList]]) -> ArrayAttr: - if values is None: - return None - - # Turn into a Python list of Python ints. - values = _get_value_list(values) - - # Make an ArrayAttr of IntegerAttrs out of it. - return ArrayAttr.get( - [IntegerAttr.get(IntegerType.get_signless(64), v) for v in values] - ) - - -def _get_int_array_array_attr( - values: Optional[Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]]] -) -> ArrayAttr: - """Creates an ArrayAttr of ArrayAttrs of IntegerAttrs. - - The input has to be a collection of collection of integers, where any - Python Sequence and ArrayAttr are admissible collections and Python ints and - any IntegerAttr are admissible integers. Both levels of collections are - turned into ArrayAttr; the inner level is turned into IntegerAttrs of i64s. - If the input is None, an empty ArrayAttr is returned. - """ - if values is None: - return None - - # Make sure the outer level is a list. - values = _get_value_list(values) - - # The inner level is now either invalid or a mixed sequence of ArrayAttrs and - # Sequences. Make sure the nested values are all lists. - values = [_get_value_list(nested) for nested in values] - - # Turn each nested list into an ArrayAttr. - values = [_get_int_array_attr(nested) for nested in values] - - # Turn the outer list into an ArrayAttr. - return ArrayAttr.get(values) +from typing import List, Optional, Sequence, Union, overload @_ods_cext.register_operation(_Dialect, replace=True) From 2a89e738bf30c00d8abc7ad969a0cdddcdc67232 Mon Sep 17 00:00:00 2001 From: Yinying Li <107574043+yinying-lisa-li@users.noreply.github.com> Date: Mon, 5 Feb 2024 22:00:52 +0000 Subject: [PATCH 668/915] [mlir][sparse] Change LevelType enum to 64 bit (#80501) 1. C++ enum is set through enum class LevelType : uint_64. 2. C enum is set through typedef uint_64 level_type. 
It is due to the limitations in Windows build: setting enum width to ui64 is not supported in C. --- mlir/include/mlir-c/Dialect/SparseTensor.h | 8 +++++--- mlir/lib/Bindings/Python/DialectSparseTensor.cpp | 2 +- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index 41d024db0..42d8400cb 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -25,7 +25,9 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor); /// These correspond to SparseTensorEncodingAttr::LevelType in the C++ API. /// If updating, keep them in sync and update the static_assert in the impl /// file. -enum MlirSparseTensorLevelType { +typedef uint64_t MlirSparseTensorLevelType; + +enum MlirBaseSparseTensorLevelType { MLIR_SPARSE_TENSOR_LEVEL_DENSE = 4, // 0b00001_00 MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 8, // 0b00010_00 MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 9, // 0b00010_01 @@ -53,7 +55,7 @@ mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr); /// Creates a `sparse_tensor.encoding` attribute with the given parameters. MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet( MlirContext ctx, intptr_t lvlRank, - enum MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl, + MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl, MlirAffineMap lvlTodim, int posWidth, int crdWidth); /// Returns the level-rank of the `sparse_tensor.encoding` attribute. @@ -61,7 +63,7 @@ MLIR_CAPI_EXPORTED intptr_t mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr); /// Returns a specified level-type of the `sparse_tensor.encoding` attribute. 
-MLIR_CAPI_EXPORTED enum MlirSparseTensorLevelType +MLIR_CAPI_EXPORTED MlirSparseTensorLevelType mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl); /// Returns the dimension-to-level mapping of the `sparse_tensor.encoding` diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 8706c5239..698367a1a 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -23,7 +23,7 @@ using namespace mlir; using namespace mlir::python::adaptors; static void populateDialectSparseTensorSubmodule(const py::module &m) { - py::enum_(m, "LevelType", py::module_local()) + py::enum_(m, "LevelType", py::module_local()) .value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE) .value("compressed24", MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR) .value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) From f1b4531f6ebec1185c4005d074afb6ede128f860 Mon Sep 17 00:00:00 2001 From: Yinying Li <107574043+yinying-lisa-li@users.noreply.github.com> Date: Thu, 8 Feb 2024 19:38:42 +0000 Subject: [PATCH 669/915] [mlir][sparse] Implement parsing n out of m (#79935) 1. Add parsing methods for block[n, m]. 2. Encode n and m with the newly extended 64-bit LevelType enum. 3. Update 2:4 methods names/comments to n:m. 
--- mlir/include/mlir-c/Dialect/SparseTensor.h | 28 +++++------ .../Bindings/Python/DialectSparseTensor.cpp | 2 +- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 49 ++++++++++++------- 3 files changed, 45 insertions(+), 34 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index 42d8400cb..2c71b0008 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -28,20 +28,20 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor); typedef uint64_t MlirSparseTensorLevelType; enum MlirBaseSparseTensorLevelType { - MLIR_SPARSE_TENSOR_LEVEL_DENSE = 4, // 0b00001_00 - MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 8, // 0b00010_00 - MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 9, // 0b00010_01 - MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 10, // 0b00010_10 - MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 11, // 0b00010_11 - MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 16, // 0b00100_00 - MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 17, // 0b00100_01 - MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 18, // 0b00100_10 - MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 19, // 0b00100_11 - MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 32, // 0b01000_00 - MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 33, // 0b01000_01 - MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 34, // 0b01000_10 - MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 35, // 0b01000_11 - MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR = 64, // 0b10000_00 + MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000, + MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000020000, + MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 0x000000020001, + MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 0x000000020002, + MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 0x000000020003, + MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000040000, + MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 0x000000040001, + MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 0x000000040002, + MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO 
= 0x000000040003, + MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000080000, + MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 0x000000080001, + MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 0x000000080002, + MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 0x000000080003, + MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000, }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 698367a1a..607534c61 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -25,7 +25,7 @@ using namespace mlir::python::adaptors; static void populateDialectSparseTensorSubmodule(const py::module &m) { py::enum_(m, "LevelType", py::module_local()) .value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE) - .value("compressed24", MLIR_SPARSE_TENSOR_LEVEL_TWO_OUT_OF_FOUR) + .value("n_out_of_m", MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) .value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) .value("compressed_nu", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) .value("compressed_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index e4534ad13..a34b9a29b 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -20,25 +20,36 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor, mlir::sparse_tensor::SparseTensorDialect) // Ensure the C-API enums are int-castable to C++ equivalents. 
-static_assert(static_cast(MLIR_SPARSE_TENSOR_LEVEL_DENSE) == - static_cast(LevelType::Dense) && - static_cast(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) == - static_cast(LevelType::Compressed) && - static_cast(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) == - static_cast(LevelType::CompressedNu) && - static_cast(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) == - static_cast(LevelType::CompressedNo) && - static_cast(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) == - static_cast(LevelType::CompressedNuNo) && - static_cast(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) == - static_cast(LevelType::Singleton) && - static_cast(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) == - static_cast(LevelType::SingletonNu) && - static_cast(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) == - static_cast(LevelType::SingletonNo) && - static_cast(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) == - static_cast(LevelType::SingletonNuNo), - "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch"); +static_assert( + static_cast(MLIR_SPARSE_TENSOR_LEVEL_DENSE) == + static_cast(LevelType::Dense) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) == + static_cast(LevelType::Compressed) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) == + static_cast(LevelType::CompressedNu) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) == + static_cast(LevelType::CompressedNo) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) == + static_cast(LevelType::CompressedNuNo) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) == + static_cast(LevelType::Singleton) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) == + static_cast(LevelType::SingletonNu) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) == + static_cast(LevelType::SingletonNo) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) == + static_cast(LevelType::SingletonNuNo) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED) == + static_cast(LevelType::LooseCompressed) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU) == 
+ static_cast(LevelType::LooseCompressedNu) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO) == + static_cast(LevelType::LooseCompressedNo) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO) == + static_cast(LevelType::LooseCompressedNuNo) && + static_cast(MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) == + static_cast(LevelType::NOutOfM), + "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch"); bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { return isa(unwrap(attr)); From b76dfead5cd822382fb756aab717a50614dc8b98 Mon Sep 17 00:00:00 2001 From: John Demme Date: Thu, 8 Feb 2024 11:39:06 -0800 Subject: [PATCH 670/915] [MLIR][Python] Add method for getting the live operation objects (#78663) Currently, a method exists to get the count of the operation objects which are still alive. This helps for sanity checking, but isn't terribly useful for debugging. This new method returns the actual operation objects which are still alive. This allows Python code like the following: ``` gc.collect() live_ops = ir.Context.current._get_live_operation_objects() for op in live_ops: print(f"Warning: {op} is still live. 
Referrers:") for referrer in gc.get_referrers(op)[0]: print(f" {referrer}") ``` --- mlir/lib/Bindings/Python/IRCore.cpp | 9 +++++++++ mlir/lib/Bindings/Python/IRModule.h | 3 +++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 1 + 3 files changed, 13 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 5412c3dec..8a7951dc2 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -636,6 +636,13 @@ size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } +std::vector PyMlirContext::getLiveOperationObjects() { + std::vector liveObjects; + for (auto &entry : liveOperations) + liveObjects.push_back(entry.second.second); + return liveObjects; +} + size_t PyMlirContext::clearLiveOperations() { for (auto &op : liveOperations) op.second.second->setInvalid(); @@ -2546,6 +2553,8 @@ void mlir::python::populateIRCore(py::module &m) { return ref.releaseObject(); }) .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) + .def("_get_live_operation_objects", + &PyMlirContext::getLiveOperationObjects) .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 79b7e0c96..48f39c939 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -201,6 +201,9 @@ class PyMlirContext { /// Gets the count of live context objects. Used for testing. static size_t getLiveCount(); + /// Get a list of Python objects which are still in the live context map. + std::vector getLiveOperationObjects(); + /// Gets the count of live operations associated with this context. /// Used for testing. 
size_t getLiveOperationCount(); diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 57a85990f..344abb64a 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -985,6 +985,7 @@ class Context: def _get_context_again(self) -> Context: ... def _get_live_module_count(self) -> int: ... def _get_live_operation_count(self) -> int: ... + def _get_live_operation_objects(self) -> List[Operation]: ... def append_dialect_registry(self, registry: DialectRegistry) -> None: ... def attach_diagnostic_handler( self, callback: Callable[[Diagnostic], bool] From 0a07d1b158d8223aa0daf1042349f43c0ec692ad Mon Sep 17 00:00:00 2001 From: Yinying Li <107574043+yinying-lisa-li@users.noreply.github.com> Date: Fri, 9 Feb 2024 14:34:36 -0500 Subject: [PATCH 671/915] [mlir][sparse] Add more tests and verification for n:m (#81186) 1. Add python test for n out of m 2. Add more methods for python binding 3. Add verification for n:m and invalid encoding tests 4. 
Add e2e test for n:m Previous PRs for n:m #80501 #79935 --- mlir/include/mlir-c/Dialect/SparseTensor.h | 10 +++++ .../Bindings/Python/DialectSparseTensor.cpp | 38 ++++++++++++++++++- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 18 +++++++++ 3 files changed, 65 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index 2c71b0008..d549f5ddd 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -84,6 +84,16 @@ mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr); MLIR_CAPI_EXPORTED int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr); +MLIR_CAPI_EXPORTED unsigned +mlirSparseTensorEncodingAttrGetStructuredN(MlirSparseTensorLevelType lvlType); + +MLIR_CAPI_EXPORTED unsigned +mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType); + +MLIR_CAPI_EXPORTED MlirSparseTensorLevelType +mlirSparseTensorEncodingAttrBuildLvlType( + enum MlirBaseSparseTensorLevelType lvlType, unsigned n, unsigned m); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 607534c61..74f4d2413 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -60,6 +60,15 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { py::arg("lvl_to_dim"), py::arg("pos_width"), py::arg("crd_width"), py::arg("context") = py::none(), "Gets a sparse_tensor.encoding from parameters.") + .def_classmethod( + "build_level_type", + [](py::object cls, MlirBaseSparseTensorLevelType lvlType, unsigned n, + unsigned m) { + return mlirSparseTensorEncodingAttrBuildLvlType(lvlType, n, m); + }, + py::arg("cls"), py::arg("lvl_type"), py::arg("n") = 0, + py::arg("m") = 0, + "Builds a sparse_tensor.encoding.level_type from parameters.") .def_property_readonly( "lvl_types", [](MlirAttribute self) { 
@@ -89,7 +98,34 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { .def_property_readonly("pos_width", mlirSparseTensorEncodingAttrGetPosWidth) .def_property_readonly("crd_width", - mlirSparseTensorEncodingAttrGetCrdWidth); + mlirSparseTensorEncodingAttrGetCrdWidth) + .def_property_readonly( + "structured_n", + [](MlirAttribute self) -> unsigned { + const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self); + return mlirSparseTensorEncodingAttrGetStructuredN( + mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1)); + }) + .def_property_readonly( + "structured_m", + [](MlirAttribute self) -> unsigned { + const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self); + return mlirSparseTensorEncodingAttrGetStructuredM( + mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1)); + }) + .def_property_readonly("lvl_types_enum", [](MlirAttribute self) { + const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self); + std::vector ret; + ret.reserve(lvlRank); + for (int l = 0; l < lvlRank; l++) { + // Convert level type to 32 bits to ignore n and m for n_out_of_m + // format. 
+ ret.push_back( + static_cast(static_cast( + mlirSparseTensorEncodingAttrGetLvlType(self, l)))); + } + return ret; + }); } PYBIND11_MODULE(_mlirDialectsSparseTensor, m) { diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index a34b9a29b..4e1bd4586 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -94,3 +94,21 @@ int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) { int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) { return cast(unwrap(attr)).getCrdWidth(); } + +MlirSparseTensorLevelType +mlirSparseTensorEncodingAttrBuildLvlType(MlirBaseSparseTensorLevelType lvlType, + unsigned n, unsigned m) { + LevelType lt = static_cast(lvlType); + return static_cast(*buildLevelType( + *getLevelFormat(lt), isOrderedLT(lt), isUniqueLT(lt), n, m)); +} + +unsigned +mlirSparseTensorEncodingAttrGetStructuredN(MlirSparseTensorLevelType lvlType) { + return getN(static_cast(lvlType)); +} + +unsigned +mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType) { + return getM(static_cast(lvlType)); +} From 8d05afd24cdd65220ca45a62c7e67056f83e9b29 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Mon, 12 Feb 2024 17:35:43 +0000 Subject: [PATCH 672/915] [MLIR][Python] Add missing peel_front argument to LoopPeelOp's extension class (#81424) --- mlir/python/mlir/dialects/transform/loop.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mlir/python/mlir/dialects/transform/loop.py b/mlir/python/mlir/dialects/transform/loop.py index 3bdd9ca3b..c4770b1c4 100644 --- a/mlir/python/mlir/dialects/transform/loop.py +++ b/mlir/python/mlir/dialects/transform/loop.py @@ -55,6 +55,7 @@ def __init__( remainder_loop_type: Type, target: Union[Operation, Value], *, + peel_front: Union[bool, BoolAttr] = False, fail_if_already_divisible: Union[bool, BoolAttr] = False, ip=None, loc=None, @@ -63,6 +64,11 @@ def __init__( main_loop_type, remainder_loop_type, 
_get_op_result_or_value(target), + peel_front=( + peel_front + if isinstance(peel_front, BoolAttr) + else BoolAttr.get(peel_front) + ), fail_if_already_divisible=( fail_if_already_divisible if isinstance(fail_if_already_divisible, BoolAttr) From 013d7498ba589f4e2abaeab44a87fb261703a46d Mon Sep 17 00:00:00 2001 From: Peiming Liu <36770114+PeimingLiu@users.noreply.github.com> Date: Tue, 13 Feb 2024 18:45:22 -0600 Subject: [PATCH 673/915] =?UTF-8?q?[mlir][sparse][pybind][CAPI]=20remove?= =?UTF-8?q?=20LevelType=20enum=20from=20CAPI,=20constru=E2=80=A6=20(#81682?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ct LevelType from LevelFormat and properties instead. **Rationale** We used to explicitly declare every possible combination between `LevelFormat` and `LevelProperties`, and it now becomes difficult to scale as more properties/level formats are going to be introduced. --- mlir/include/mlir-c/Dialect/SparseTensor.h | 24 ++++---- .../Bindings/Python/DialectSparseTensor.cpp | 49 +++++++-------- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 61 ++++++++++--------- 3 files changed, 65 insertions(+), 69 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index d549f5ddd..898d2f127 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -27,23 +27,19 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor); /// file. 
typedef uint64_t MlirSparseTensorLevelType; -enum MlirBaseSparseTensorLevelType { +enum MlirSparseTensorLevelFormat { MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000, MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000020000, - MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU = 0x000000020001, - MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO = 0x000000020002, - MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO = 0x000000020003, MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000040000, - MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU = 0x000000040001, - MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO = 0x000000040002, - MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO = 0x000000040003, MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000080000, - MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU = 0x000000080001, - MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO = 0x000000080002, - MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO = 0x000000080003, MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000, }; +enum MlirSparseTensorLevelPropertyNondefault { + MLIR_SPARSE_PROPERTY_NON_UNIQUE = 0x0001, + MLIR_SPARSE_PROPERTY_NON_ORDERED = 0x0002, +}; + //===----------------------------------------------------------------------===// // SparseTensorEncodingAttr //===----------------------------------------------------------------------===// @@ -66,6 +62,10 @@ mlirSparseTensorEncodingGetLvlRank(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirSparseTensorLevelType mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl); +/// Returns a specified level-format of the `sparse_tensor.encoding` attribute. +MLIR_CAPI_EXPORTED enum MlirSparseTensorLevelFormat +mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl); + /// Returns the dimension-to-level mapping of the `sparse_tensor.encoding` /// attribute. 
MLIR_CAPI_EXPORTED MlirAffineMap @@ -92,7 +92,9 @@ mlirSparseTensorEncodingAttrGetStructuredM(MlirSparseTensorLevelType lvlType); MLIR_CAPI_EXPORTED MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType( - enum MlirBaseSparseTensorLevelType lvlType, unsigned n, unsigned m); + enum MlirSparseTensorLevelFormat lvlFmt, + const enum MlirSparseTensorLevelPropertyNondefault *properties, + unsigned propSize, unsigned n, unsigned m); #ifdef __cplusplus } diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 74f4d2413..171faf9e0 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -23,24 +23,17 @@ using namespace mlir; using namespace mlir::python::adaptors; static void populateDialectSparseTensorSubmodule(const py::module &m) { - py::enum_(m, "LevelType", py::module_local()) + py::enum_(m, "LevelFormat", py::module_local()) .value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE) .value("n_out_of_m", MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) .value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) - .value("compressed_nu", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) - .value("compressed_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) - .value("compressed_nu_no", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) .value("singleton", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) - .value("singleton_nu", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) - .value("singleton_no", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) - .value("singleton_nu_no", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) - .value("loose_compressed", MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED) - .value("loose_compressed_nu", - MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU) - .value("loose_compressed_no", - MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO) - .value("loose_compressed_nu_no", - MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO); + .value("loose_compressed", MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED); + + 
py::enum_(m, "LevelProperty", + py::module_local()) + .value("non_ordered", MLIR_SPARSE_PROPERTY_NON_ORDERED) + .value("non_unique", MLIR_SPARSE_PROPERTY_NON_UNIQUE); mlir_attribute_subclass(m, "EncodingAttr", mlirAttributeIsASparseTensorEncodingAttr) @@ -62,12 +55,17 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { "Gets a sparse_tensor.encoding from parameters.") .def_classmethod( "build_level_type", - [](py::object cls, MlirBaseSparseTensorLevelType lvlType, unsigned n, - unsigned m) { - return mlirSparseTensorEncodingAttrBuildLvlType(lvlType, n, m); + [](py::object cls, MlirSparseTensorLevelFormat lvlFmt, + const std::vector + &properties, + unsigned n, unsigned m) { + return mlirSparseTensorEncodingAttrBuildLvlType( + lvlFmt, properties.data(), properties.size(), n, m); }, - py::arg("cls"), py::arg("lvl_type"), py::arg("n") = 0, - py::arg("m") = 0, + py::arg("cls"), py::arg("lvl_fmt"), + py::arg("properties") = + std::vector(), + py::arg("n") = 0, py::arg("m") = 0, "Builds a sparse_tensor.encoding.level_type from parameters.") .def_property_readonly( "lvl_types", @@ -113,17 +111,12 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { return mlirSparseTensorEncodingAttrGetStructuredM( mlirSparseTensorEncodingAttrGetLvlType(self, lvlRank - 1)); }) - .def_property_readonly("lvl_types_enum", [](MlirAttribute self) { + .def_property_readonly("lvl_formats_enum", [](MlirAttribute self) { const int lvlRank = mlirSparseTensorEncodingGetLvlRank(self); - std::vector ret; + std::vector ret; ret.reserve(lvlRank); - for (int l = 0; l < lvlRank; l++) { - // Convert level type to 32 bits to ignore n and m for n_out_of_m - // format. 
- ret.push_back( - static_cast(static_cast( - mlirSparseTensorEncodingAttrGetLvlType(self, l)))); - } + for (int l = 0; l < lvlRank; l++) + ret.push_back(mlirSparseTensorEncodingAttrGetLvlFmt(self, l)); return ret; }); } diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index 4e1bd4586..55af8becb 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -22,34 +22,23 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SparseTensor, sparse_tensor, // Ensure the C-API enums are int-castable to C++ equivalents. static_assert( static_cast(MLIR_SPARSE_TENSOR_LEVEL_DENSE) == - static_cast(LevelType::Dense) && + static_cast(LevelFormat::Dense) && static_cast(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) == - static_cast(LevelType::Compressed) && - static_cast(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU) == - static_cast(LevelType::CompressedNu) && - static_cast(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NO) == - static_cast(LevelType::CompressedNo) && - static_cast(MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED_NU_NO) == - static_cast(LevelType::CompressedNuNo) && + static_cast(LevelFormat::Compressed) && static_cast(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) == - static_cast(LevelType::Singleton) && - static_cast(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU) == - static_cast(LevelType::SingletonNu) && - static_cast(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NO) == - static_cast(LevelType::SingletonNo) && - static_cast(MLIR_SPARSE_TENSOR_LEVEL_SINGLETON_NU_NO) == - static_cast(LevelType::SingletonNuNo) && + static_cast(LevelFormat::Singleton) && static_cast(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED) == - static_cast(LevelType::LooseCompressed) && - static_cast(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU) == - static_cast(LevelType::LooseCompressedNu) && - static_cast(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NO) == - static_cast(LevelType::LooseCompressedNo) && - static_cast(MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED_NU_NO) == - 
static_cast(LevelType::LooseCompressedNuNo) && + static_cast(LevelFormat::LooseCompressed) && static_cast(MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) == - static_cast(LevelType::NOutOfM), - "MlirSparseTensorLevelType (C-API) and LevelType (C++) mismatch"); + static_cast(LevelFormat::NOutOfM), + "MlirSparseTensorLevelFormat (C-API) and LevelFormat (C++) mismatch"); + +static_assert(static_cast(MLIR_SPARSE_PROPERTY_NON_ORDERED) == + static_cast(LevelPropertyNondefault::Nonordered) && + static_cast(MLIR_SPARSE_PROPERTY_NON_UNIQUE) == + static_cast(LevelPropertyNondefault::Nonunique), + "MlirSparseTensorLevelProperty (C-API) and " + "LevelPropertyNondefault (C++) mismatch"); bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { return isa(unwrap(attr)); @@ -87,6 +76,13 @@ mlirSparseTensorEncodingAttrGetLvlType(MlirAttribute attr, intptr_t lvl) { cast(unwrap(attr)).getLvlType(lvl)); } +enum MlirSparseTensorLevelFormat +mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl) { + LevelType lt = + static_cast(mlirSparseTensorEncodingAttrGetLvlType(attr, lvl)); + return static_cast(*getLevelFormat(lt)); +} + int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) { return cast(unwrap(attr)).getPosWidth(); } @@ -95,12 +91,17 @@ int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) { return cast(unwrap(attr)).getCrdWidth(); } -MlirSparseTensorLevelType -mlirSparseTensorEncodingAttrBuildLvlType(MlirBaseSparseTensorLevelType lvlType, - unsigned n, unsigned m) { - LevelType lt = static_cast(lvlType); - return static_cast(*buildLevelType( - *getLevelFormat(lt), isOrderedLT(lt), isUniqueLT(lt), n, m)); +MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType( + enum MlirSparseTensorLevelFormat lvlFmt, + const enum MlirSparseTensorLevelPropertyNondefault *properties, + unsigned size, unsigned n, unsigned m) { + + std::vector props; + for (unsigned i = 0; i < size; i++) + props.push_back(static_cast(properties[i])); + + return 
static_cast( + *buildLevelType(static_cast(lvlFmt), props, n, m)); } unsigned From 27d34fe3b16a0820e7cffbefb0709302dce98468 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev <185856+superbobry@users.noreply.github.com> Date: Wed, 14 Feb 2024 12:02:49 +0000 Subject: [PATCH 674/915] [MLIR][Python] Added a base class to all builtin floating point types (#81720) This allows to * check if a given ir.Type is a floating point type via isinstance() or issubclass() * get the bitwidth of a floating point type See motivation and discussion in https://discourse.llvm.org/t/add-floattype-to-mlir-python-bindings/76959. --- mlir/include/mlir-c/BuiltinTypes.h | 6 ++++ mlir/lib/Bindings/Python/IRTypes.cpp | 38 +++++++++++++++++------- mlir/lib/CAPI/IR/BuiltinTypes.cpp | 8 +++++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 28 +++++++++++------ 4 files changed, 61 insertions(+), 19 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index 881b6dad2..99c5e3f46 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -73,6 +73,12 @@ MLIR_CAPI_EXPORTED MlirType mlirIndexTypeGet(MlirContext ctx); // Floating-point types. //===----------------------------------------------------------------------===// +/// Checks whether the given type is a floating-point type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type); + +/// Returns the bitwidth of a floating-point type. +MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type); + /// Returns the typeID of an Float8E5M2 type. 
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 820992de6..e1e4eb999 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -109,8 +109,22 @@ class PyIndexType : public PyConcreteType { } }; +class PyFloatType : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat; + static constexpr const char *pyClassName = "FloatType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_property_readonly( + "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); }, + "Returns the width of the floating-point type"); + } +}; + /// Floating Point Type subclass - Float8E4M3FNType. -class PyFloat8E4M3FNType : public PyConcreteType { +class PyFloat8E4M3FNType + : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; static constexpr GetTypeIDFunctionTy getTypeIdFunction = @@ -130,7 +144,7 @@ class PyFloat8E4M3FNType : public PyConcreteType { }; /// Floating Point Type subclass - Float8M5E2Type. -class PyFloat8E5M2Type : public PyConcreteType { +class PyFloat8E5M2Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; static constexpr GetTypeIDFunctionTy getTypeIdFunction = @@ -150,7 +164,8 @@ class PyFloat8E5M2Type : public PyConcreteType { }; /// Floating Point Type subclass - Float8E4M3FNUZ. -class PyFloat8E4M3FNUZType : public PyConcreteType { +class PyFloat8E4M3FNUZType + : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; static constexpr GetTypeIDFunctionTy getTypeIdFunction = @@ -170,7 +185,8 @@ class PyFloat8E4M3FNUZType : public PyConcreteType { }; /// Floating Point Type subclass - Float8E4M3B11FNUZ. 
-class PyFloat8E4M3B11FNUZType : public PyConcreteType { +class PyFloat8E4M3B11FNUZType + : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; static constexpr GetTypeIDFunctionTy getTypeIdFunction = @@ -190,7 +206,8 @@ class PyFloat8E4M3B11FNUZType : public PyConcreteType { }; /// Floating Point Type subclass - Float8E5M2FNUZ. -class PyFloat8E5M2FNUZType : public PyConcreteType { +class PyFloat8E5M2FNUZType + : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; static constexpr GetTypeIDFunctionTy getTypeIdFunction = @@ -210,7 +227,7 @@ class PyFloat8E5M2FNUZType : public PyConcreteType { }; /// Floating Point Type subclass - BF16Type. -class PyBF16Type : public PyConcreteType { +class PyBF16Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; static constexpr GetTypeIDFunctionTy getTypeIdFunction = @@ -230,7 +247,7 @@ class PyBF16Type : public PyConcreteType { }; /// Floating Point Type subclass - F16Type. -class PyF16Type : public PyConcreteType { +class PyF16Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; static constexpr GetTypeIDFunctionTy getTypeIdFunction = @@ -250,7 +267,7 @@ class PyF16Type : public PyConcreteType { }; /// Floating Point Type subclass - TF32Type. -class PyTF32Type : public PyConcreteType { +class PyTF32Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32; static constexpr GetTypeIDFunctionTy getTypeIdFunction = @@ -270,7 +287,7 @@ class PyTF32Type : public PyConcreteType { }; /// Floating Point Type subclass - F32Type. 
-class PyF32Type : public PyConcreteType { +class PyF32Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; static constexpr GetTypeIDFunctionTy getTypeIdFunction = @@ -290,7 +307,7 @@ class PyF32Type : public PyConcreteType { }; /// Floating Point Type subclass - F64Type. -class PyF64Type : public PyConcreteType { +class PyF64Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; static constexpr GetTypeIDFunctionTy getTypeIdFunction = @@ -819,6 +836,7 @@ class PyOpaqueType : public PyConcreteType { void mlir::python::populateIRTypes(py::module &m) { PyIntegerType::bind(m); + PyFloatType::bind(m); PyIndexType::bind(m); PyFloat8E4M3FNType::bind(m); PyFloat8E5M2Type::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 18c9414c5..e1a5d8258 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -78,6 +78,14 @@ MlirType mlirIndexTypeGet(MlirContext ctx) { // Floating-point types. //===----------------------------------------------------------------------===// +bool mlirTypeIsAFloat(MlirType type) { + return llvm::isa(unwrap(type)); +} + +unsigned mlirFloatTypeGetWidth(MlirType type) { + return llvm::cast(unwrap(type)).getWidth(); +} + MlirTypeID mlirFloat8E5M2TypeGetTypeID() { return wrap(Float8E5M2Type::getTypeID()); } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 344abb64a..586bf7f8e 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -1442,7 +1442,17 @@ class DictAttr(Attribute): @property def typeid(self) -> TypeID: ... -class F16Type(Type): +class FloatType(Type): + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def width(self) -> int: + """ + Returns the width of the floating-point type. 
+ """ + +class F16Type(FloatType): static_typeid: ClassVar[TypeID] # value = @staticmethod def get(context: Optional[Context] = None) -> F16Type: @@ -1455,7 +1465,7 @@ class F16Type(Type): @property def typeid(self) -> TypeID: ... -class F32Type(Type): +class F32Type(FloatType): static_typeid: ClassVar[TypeID] # value = @staticmethod def get(context: Optional[Context] = None) -> F32Type: @@ -1468,7 +1478,7 @@ class F32Type(Type): @property def typeid(self) -> TypeID: ... -class F64Type(Type): +class F64Type(FloatType): static_typeid: ClassVar[TypeID] # value = @staticmethod def get(context: Optional[Context] = None) -> F64Type: @@ -1502,7 +1512,7 @@ class FlatSymbolRefAttr(Attribute): Returns the value of the FlatSymbolRef attribute as a string """ -class Float8E4M3B11FNUZType(Type): +class Float8E4M3B11FNUZType(FloatType): static_typeid: ClassVar[TypeID] # value = @staticmethod def get(context: Optional[Context] = None) -> Float8E4M3B11FNUZType: @@ -1515,7 +1525,7 @@ class Float8E4M3B11FNUZType(Type): @property def typeid(self) -> TypeID: ... -class Float8E4M3FNType(Type): +class Float8E4M3FNType(FloatType): static_typeid: ClassVar[TypeID] # value = @staticmethod def get(context: Optional[Context] = None) -> Float8E4M3FNType: @@ -1528,7 +1538,7 @@ class Float8E4M3FNType(Type): @property def typeid(self) -> TypeID: ... -class Float8E4M3FNUZType(Type): +class Float8E4M3FNUZType(FloatType): static_typeid: ClassVar[TypeID] # value = @staticmethod def get(context: Optional[Context] = None) -> Float8E4M3FNUZType: @@ -1541,7 +1551,7 @@ class Float8E4M3FNUZType(Type): @property def typeid(self) -> TypeID: ... -class Float8E5M2FNUZType(Type): +class Float8E5M2FNUZType(FloatType): static_typeid: ClassVar[TypeID] # value = @staticmethod def get(context: Optional[Context] = None) -> Float8E5M2FNUZType: @@ -1554,7 +1564,7 @@ class Float8E5M2FNUZType(Type): @property def typeid(self) -> TypeID: ... 
-class Float8E5M2Type(Type): +class Float8E5M2Type(FloatType): static_typeid: ClassVar[TypeID] # value = @staticmethod def get(context: Optional[Context] = None) -> Float8E5M2Type: @@ -1601,7 +1611,7 @@ class FloatAttr(Attribute): Returns the value of the float attribute """ -class FloatTF32Type(Type): +class FloatTF32Type(FloatType): static_typeid: ClassVar[TypeID] # value = @staticmethod def get(context: Optional[Context] = None) -> FloatTF32Type: From 11f54769748ca97039a13ef330cdc946339c2f26 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Wed, 14 Feb 2024 15:03:04 +0100 Subject: [PATCH 675/915] [mlir][python] expose LLVMStructType API (#81672) Expose the API for constructing and inspecting StructTypes from the LLVM dialect. Separate constructor methods are used instead of overloads for better readability, similarly to IntegerType. --- mlir/include/mlir-c/Dialect/LLVM.h | 61 +++++++++- mlir/lib/Bindings/Python/DialectLLVM.cpp | 145 +++++++++++++++++++++++ mlir/lib/CAPI/Dialect/LLVM.cpp | 68 ++++++++++- mlir/python/CMakeLists.txt | 13 ++ mlir/python/mlir/dialects/llvm.py | 1 + 5 files changed, 286 insertions(+), 2 deletions(-) create mode 100644 mlir/lib/Bindings/Python/DialectLLVM.cpp diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index 72701a822..ac216b01f 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -34,11 +34,70 @@ MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes, MlirType const *argumentTypes, bool isVarArg); -/// Creates an LLVM literal (unnamed) struct type. +/// Returns `true` if the type is an LLVM dialect struct type. +MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type); + +/// Returns `true` if the type is a literal (unnamed) LLVM struct type. +MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsLiteral(MlirType type); + +/// Returns the number of fields in the struct. 
Asserts if the struct is opaque +/// or not yet initialized. +MLIR_CAPI_EXPORTED intptr_t mlirLLVMStructTypeGetNumElementTypes(MlirType type); + +/// Returns the `positions`-th field of the struct. Asserts if the struct is +/// opaque, not yet initialized or if the position is out of range. +MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeGetElementType(MlirType type, + intptr_t position); + +/// Returns `true` if the struct is packed. +MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsPacked(MlirType type); + +/// Returns the identifier of the identified struct. Asserts that the struct is +/// identified, i.e., not literal. +MLIR_CAPI_EXPORTED MlirStringRef mlirLLVMStructTypeGetIdentifier(MlirType type); + +/// Returns `true` is the struct is explicitly opaque (will not have a body) or +/// uninitialized (will eventually have a body). +MLIR_CAPI_EXPORTED bool mlirLLVMStructTypeIsOpaque(MlirType type); + +/// Creates an LLVM literal (unnamed) struct type. This may assert if the fields +/// have types not compatible with the LLVM dialect. For a graceful failure, use +/// the checked version. MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeLiteralGet(MlirContext ctx, intptr_t nFieldTypes, MlirType const *fieldTypes, bool isPacked); +/// Creates an LLVM literal (unnamed) struct type if possible. Emits a +/// diagnostic at the given location and returns null otherwise. +MLIR_CAPI_EXPORTED MlirType +mlirLLVMStructTypeLiteralGetChecked(MlirLocation loc, intptr_t nFieldTypes, + MlirType const *fieldTypes, bool isPacked); + +/// Creates an LLVM identified struct type with no body. If a struct type with +/// this name already exists in the context, returns that type. Use +/// mlirLLVMStructTypeIdentifiedNewGet to create a fresh struct type, +/// potentially renaming it. The body should be set separatelty by calling +/// mlirLLVMStructTypeSetBody, if it isn't set already. 
+MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeIdentifiedGet(MlirContext ctx, + MlirStringRef name); + +/// Creates an LLVM identified struct type with no body and a name starting with +/// the given prefix. If a struct with the exact name as the given prefix +/// already exists, appends an unspecified suffix to the name so that the name +/// is unique in context. +MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeIdentifiedNewGet( + MlirContext ctx, MlirStringRef name, intptr_t nFieldTypes, + MlirType const *fieldTypes, bool isPacked); + +MLIR_CAPI_EXPORTED MlirType mlirLLVMStructTypeOpaqueGet(MlirContext ctx, + MlirStringRef name); + +/// Sets the body of the identified struct if it hasn't been set yet. Returns +/// whether the operation was successful. +MLIR_CAPI_EXPORTED MlirLogicalResult +mlirLLVMStructTypeSetBody(MlirType structType, intptr_t nFieldTypes, + MlirType const *fieldTypes, bool isPacked); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp new file mode 100644 index 000000000..780f5eacf --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp @@ -0,0 +1,145 @@ +//===- DialectLLVM.cpp - Pybind module for LLVM dialect API support -------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Diagnostics.h" +#include "mlir-c/Dialect/LLVM.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" +#include + +namespace py = pybind11; +using namespace llvm; +using namespace mlir; +using namespace mlir::python; +using namespace mlir::python::adaptors; + +/// RAII scope intercepting all diagnostics into a string. The message must be +/// checked before this goes out of scope. 
+class CollectDiagnosticsToStringScope { +public: + explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) { + handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage, + /*deleteUserData=*/nullptr); + } + ~CollectDiagnosticsToStringScope() { + assert(errorMessage.empty() && "unchecked error message"); + mlirContextDetachDiagnosticHandler(context, handlerID); + } + + [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); } + +private: + static MlirLogicalResult handler(MlirDiagnostic diag, void *data) { + auto printer = +[](MlirStringRef message, void *data) { + *static_cast(data) += + StringRef(message.data, message.length); + }; + mlirDiagnosticPrint(diag, printer, data); + return mlirLogicalResultSuccess(); + } + + MlirContext context; + MlirDiagnosticHandlerID handlerID; + std::string errorMessage = ""; +}; + +void populateDialectLLVMSubmodule(const pybind11::module &m) { + auto llvmStructType = + mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType); + + llvmStructType.def_classmethod( + "get_literal", + [](py::object cls, const std::vector &elements, bool packed, + MlirLocation loc) { + CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc)); + + MlirType type = mlirLLVMStructTypeLiteralGetChecked( + loc, elements.size(), elements.data(), packed); + if (mlirTypeIsNull(type)) { + throw py::value_error(scope.takeMessage()); + } + return cls(type); + }, + py::arg("cls"), py::arg("elements"), py::kw_only(), + py::arg("packed") = false, py::arg("loc") = py::none()); + + llvmStructType.def_classmethod( + "get_identified", + [](py::object cls, const std::string &name, MlirContext context) { + return cls(mlirLLVMStructTypeIdentifiedGet( + context, mlirStringRefCreate(name.data(), name.size()))); + }, + py::arg("cls"), py::arg("name"), py::kw_only(), + py::arg("context") = py::none()); + + llvmStructType.def_classmethod( + "get_opaque", + [](py::object cls, const std::string &name, MlirContext 
context) { + return cls(mlirLLVMStructTypeOpaqueGet( + context, mlirStringRefCreate(name.data(), name.size()))); + }, + py::arg("cls"), py::arg("name"), py::arg("context") = py::none()); + + llvmStructType.def( + "set_body", + [](MlirType self, const std::vector &elements, bool packed) { + MlirLogicalResult result = mlirLLVMStructTypeSetBody( + self, elements.size(), elements.data(), packed); + if (!mlirLogicalResultIsSuccess(result)) { + throw py::value_error( + "Struct body already set to different content."); + } + }, + py::arg("elements"), py::kw_only(), py::arg("packed") = false); + + llvmStructType.def_classmethod( + "new_identified", + [](py::object cls, const std::string &name, + const std::vector &elements, bool packed, MlirContext ctx) { + return cls(mlirLLVMStructTypeIdentifiedNewGet( + ctx, mlirStringRefCreate(name.data(), name.length()), + elements.size(), elements.data(), packed)); + }, + py::arg("cls"), py::arg("name"), py::arg("elements"), py::kw_only(), + py::arg("packed") = false, py::arg("context") = py::none()); + + llvmStructType.def_property_readonly( + "name", [](MlirType type) -> std::optional { + if (mlirLLVMStructTypeIsLiteral(type)) + return std::nullopt; + + MlirStringRef stringRef = mlirLLVMStructTypeGetIdentifier(type); + return StringRef(stringRef.data, stringRef.length).str(); + }); + + llvmStructType.def_property_readonly("body", [](MlirType type) -> py::object { + // Don't crash in absence of a body. 
+ if (mlirLLVMStructTypeIsOpaque(type)) + return py::none(); + + py::list body; + for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e; + ++i) { + body.append(mlirLLVMStructTypeGetElementType(type, i)); + } + return body; + }); + + llvmStructType.def_property_readonly( + "packed", [](MlirType type) { return mlirLLVMStructTypeIsPacked(type); }); + + llvmStructType.def_property_readonly( + "opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); }); +} + +PYBIND11_MODULE(_mlirDialectsLLVM, m) { + m.doc() = "MLIR LLVM Dialect"; + + populateDialectLLVMSubmodule(m); +} diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index b4405f7aa..642018a81 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -36,11 +36,77 @@ MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes, unwrapList(nArgumentTypes, argumentTypes, argumentStorage), isVarArg)); } +bool mlirTypeIsALLVMStructType(MlirType type) { + return isa(unwrap(type)); +} + +bool mlirLLVMStructTypeIsLiteral(MlirType type) { + return !cast(unwrap(type)).isIdentified(); +} + +intptr_t mlirLLVMStructTypeGetNumElementTypes(MlirType type) { + return cast(unwrap(type)).getBody().size(); +} + +MlirType mlirLLVMStructTypeGetElementType(MlirType type, intptr_t position) { + return wrap(cast(unwrap(type)).getBody()[position]); +} + +bool mlirLLVMStructTypeIsPacked(MlirType type) { + return cast(unwrap(type)).isPacked(); +} + +MlirStringRef mlirLLVMStructTypeGetIdentifier(MlirType type) { + return wrap(cast(unwrap(type)).getName()); +} + +bool mlirLLVMStructTypeIsOpaque(MlirType type) { + return cast(unwrap(type)).isOpaque(); +} + MlirType mlirLLVMStructTypeLiteralGet(MlirContext ctx, intptr_t nFieldTypes, MlirType const *fieldTypes, bool isPacked) { - SmallVector fieldStorage; + SmallVector fieldStorage; return wrap(LLVMStructType::getLiteral( unwrap(ctx), unwrapList(nFieldTypes, fieldTypes, fieldStorage), isPacked)); 
} + +MlirType mlirLLVMStructTypeLiteralGetChecked(MlirLocation loc, + intptr_t nFieldTypes, + MlirType const *fieldTypes, + bool isPacked) { + SmallVector fieldStorage; + return wrap(LLVMStructType::getLiteralChecked( + [loc]() { return emitError(unwrap(loc)); }, unwrap(loc)->getContext(), + unwrapList(nFieldTypes, fieldTypes, fieldStorage), isPacked)); +} + +MlirType mlirLLVMStructTypeOpaqueGet(MlirContext ctx, MlirStringRef name) { + return wrap(LLVMStructType::getOpaque(unwrap(name), unwrap(ctx))); +} + +MlirType mlirLLVMStructTypeIdentifiedGet(MlirContext ctx, MlirStringRef name) { + return wrap(LLVMStructType::getIdentified(unwrap(ctx), unwrap(name))); +} + +MlirType mlirLLVMStructTypeIdentifiedNewGet(MlirContext ctx, MlirStringRef name, + intptr_t nFieldTypes, + MlirType const *fieldTypes, + bool isPacked) { + SmallVector fields; + return wrap(LLVMStructType::getNewIdentified( + unwrap(ctx), unwrap(name), unwrapList(nFieldTypes, fieldTypes, fields), + isPacked)); +} + +MlirLogicalResult mlirLLVMStructTypeSetBody(MlirType structType, + intptr_t nFieldTypes, + MlirType const *fieldTypes, + bool isPacked) { + SmallVector fields; + return wrap( + cast(unwrap(structType)) + .setBody(unwrapList(nFieldTypes, fieldTypes, fields), isPacked)); +} diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 266b86090..ed167afeb 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -482,6 +482,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind MLIRCAPILinalg ) +declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind + MODULE_NAME _mlirDialectsLLVM + ADD_TO_PARENT MLIRPythonSources.Dialects.llvm + ROOT_DIR "${PYTHON_SOURCE_DIR}" + SOURCES + DialectLLVM.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPILLVM +) + declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind MODULE_NAME _mlirDialectsQuant ADD_TO_PARENT MLIRPythonSources.Dialects.quant 
diff --git a/mlir/python/mlir/dialects/llvm.py b/mlir/python/mlir/dialects/llvm.py index 77025438c..8aa16e4a2 100644 --- a/mlir/python/mlir/dialects/llvm.py +++ b/mlir/python/mlir/dialects/llvm.py @@ -4,3 +4,4 @@ from ._llvm_ops_gen import * from ._llvm_enum_gen import * +from .._mlir_libs._mlirDialectsLLVM import * From a247ecb3f221281d9d31f5e40572045ecc48ad89 Mon Sep 17 00:00:00 2001 From: Peiming Liu <36770114+PeimingLiu@users.noreply.github.com> Date: Thu, 15 Feb 2024 14:31:03 -0600 Subject: [PATCH 676/915] =?UTF-8?q?[mlir][sparse]=20remove=20LevelType=20e?= =?UTF-8?q?num,=20construct=20LevelType=20from=20LevelF=E2=80=A6=20(#81799?= =?UTF-8?q?)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …ormat and properties instead. --- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index 55af8becb..3ae06f220 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -34,9 +34,9 @@ static_assert( "MlirSparseTensorLevelFormat (C-API) and LevelFormat (C++) mismatch"); static_assert(static_cast(MLIR_SPARSE_PROPERTY_NON_ORDERED) == - static_cast(LevelPropertyNondefault::Nonordered) && + static_cast(LevelPropNonDefault::Nonordered) && static_cast(MLIR_SPARSE_PROPERTY_NON_UNIQUE) == - static_cast(LevelPropertyNondefault::Nonunique), + static_cast(LevelPropNonDefault::Nonunique), "MlirSparseTensorLevelProperty (C-API) and " "LevelPropertyNondefault (C++) mismatch"); @@ -80,7 +80,7 @@ enum MlirSparseTensorLevelFormat mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl) { LevelType lt = static_cast(mlirSparseTensorEncodingAttrGetLvlType(attr, lvl)); - return static_cast(*getLevelFormat(lt)); + return static_cast(lt.getLvlFmt()); } int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) { @@ -96,9 +96,9 @@ 
MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType( const enum MlirSparseTensorLevelPropertyNondefault *properties, unsigned size, unsigned n, unsigned m) { - std::vector props; + std::vector props; for (unsigned i = 0; i < size; i++) - props.push_back(static_cast(properties[i])); + props.push_back(static_cast(properties[i])); return static_cast( *buildLevelType(static_cast(lvlFmt), props, n, m)); From 37a58918b829427848b737649fe0483992c367de Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Thu, 15 Feb 2024 13:26:44 -0800 Subject: [PATCH 677/915] =?UTF-8?q?Revert=20"[mlir][sparse]=20remove=20Lev?= =?UTF-8?q?elType=20enum,=20construct=20LevelType=20from=20LevelF=E2=80=A6?= =?UTF-8?q?"=20(#81923)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reverts llvm/llvm-project#81799 ; this broke the mlir gcc7 bot. --- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index 3ae06f220..55af8becb 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -34,9 +34,9 @@ static_assert( "MlirSparseTensorLevelFormat (C-API) and LevelFormat (C++) mismatch"); static_assert(static_cast(MLIR_SPARSE_PROPERTY_NON_ORDERED) == - static_cast(LevelPropNonDefault::Nonordered) && + static_cast(LevelPropertyNondefault::Nonordered) && static_cast(MLIR_SPARSE_PROPERTY_NON_UNIQUE) == - static_cast(LevelPropNonDefault::Nonunique), + static_cast(LevelPropertyNondefault::Nonunique), "MlirSparseTensorLevelProperty (C-API) and " "LevelPropertyNondefault (C++) mismatch"); @@ -80,7 +80,7 @@ enum MlirSparseTensorLevelFormat mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl) { LevelType lt = static_cast(mlirSparseTensorEncodingAttrGetLvlType(attr, lvl)); - return static_cast(lt.getLvlFmt()); + return static_cast(*getLevelFormat(lt)); } int 
mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) { @@ -96,9 +96,9 @@ MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType( const enum MlirSparseTensorLevelPropertyNondefault *properties, unsigned size, unsigned n, unsigned m) { - std::vector props; + std::vector props; for (unsigned i = 0; i < size; i++) - props.push_back(static_cast(properties[i])); + props.push_back(static_cast(properties[i])); return static_cast( *buildLevelType(static_cast(lvlFmt), props, n, m)); From fd8bcfe04c1e932c87465c1a9e66bed1216933da Mon Sep 17 00:00:00 2001 From: Peiming Liu <36770114+PeimingLiu@users.noreply.github.com> Date: Thu, 15 Feb 2024 16:48:52 -0600 Subject: [PATCH 678/915] Reapply "[mlir][sparse] remove LevelType enum, construct LevelType from LevelFormat and Properties" (#81923) (#81934) --- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index 55af8becb..3ae06f220 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -34,9 +34,9 @@ static_assert( "MlirSparseTensorLevelFormat (C-API) and LevelFormat (C++) mismatch"); static_assert(static_cast(MLIR_SPARSE_PROPERTY_NON_ORDERED) == - static_cast(LevelPropertyNondefault::Nonordered) && + static_cast(LevelPropNonDefault::Nonordered) && static_cast(MLIR_SPARSE_PROPERTY_NON_UNIQUE) == - static_cast(LevelPropertyNondefault::Nonunique), + static_cast(LevelPropNonDefault::Nonunique), "MlirSparseTensorLevelProperty (C-API) and " "LevelPropertyNondefault (C++) mismatch"); @@ -80,7 +80,7 @@ enum MlirSparseTensorLevelFormat mlirSparseTensorEncodingAttrGetLvlFmt(MlirAttribute attr, intptr_t lvl) { LevelType lt = static_cast(mlirSparseTensorEncodingAttrGetLvlType(attr, lvl)); - return static_cast(*getLevelFormat(lt)); + return static_cast(lt.getLvlFmt()); } int mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr) { @@ 
-96,9 +96,9 @@ MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType( const enum MlirSparseTensorLevelPropertyNondefault *properties, unsigned size, unsigned n, unsigned m) { - std::vector props; + std::vector props; for (unsigned i = 0; i < size; i++) - props.push_back(static_cast(properties[i])); + props.push_back(static_cast(properties[i])); return static_cast( *buildLevelType(static_cast(lvlFmt), props, n, m)); From 7c0c85b0d19840fa3633a9ae2a3ed70af3904b00 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Wed, 21 Feb 2024 11:01:00 +0100 Subject: [PATCH 679/915] [mlir] expose transform interpreter to Python (#82365) Transform interpreter functionality can be used standalone without going through the interpreter pass, make it available in Python. --- .../mlir-c/Dialect/Transform/Interpreter.h | 77 ++++++++++++++++ .../mlir/Bindings/Python/PybindAdaptors.h | 36 ++++++++ mlir/lib/Bindings/Python/DialectLLVM.cpp | 31 ------- mlir/lib/Bindings/Python/IRCore.cpp | 7 ++ mlir/lib/Bindings/Python/IRModule.h | 1 + .../Bindings/Python/TransformInterpreter.cpp | 90 +++++++++++++++++++ mlir/lib/CAPI/Dialect/CMakeLists.txt | 9 ++ .../lib/CAPI/Dialect/TransformInterpreter.cpp | 74 +++++++++++++++ mlir/python/CMakeLists.txt | 19 ++++ .../transform/interpreter/__init__.py | 33 +++++++ 10 files changed, 346 insertions(+), 31 deletions(-) create mode 100644 mlir/include/mlir-c/Dialect/Transform/Interpreter.h create mode 100644 mlir/lib/Bindings/Python/TransformInterpreter.cpp create mode 100644 mlir/lib/CAPI/Dialect/TransformInterpreter.cpp create mode 100644 mlir/python/mlir/dialects/transform/interpreter/__init__.py diff --git a/mlir/include/mlir-c/Dialect/Transform/Interpreter.h b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h new file mode 100644 index 000000000..00095d504 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h @@ -0,0 +1,77 @@ +//===-- mlir-c/Dialect/Transform/Interpreter.h --------------------*- C -*-===// 
+// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// C interface to the transform dialect interpreter. +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define DEFINE_C_API_STRUCT(name, storage) \ + struct name { \ + storage *ptr; \ + }; \ + typedef struct name name + +DEFINE_C_API_STRUCT(MlirTransformOptions, void); + +#undef DEFINE_C_API_STRUCT + +//----------------------------------------------------------------------------// +// MlirTransformOptions +//----------------------------------------------------------------------------// + +/// Creates a default-initialized transform options object. +MLIR_CAPI_EXPORTED MlirTransformOptions mlirTransformOptionsCreate(void); + +/// Enables or disables expensive checks in transform options. +MLIR_CAPI_EXPORTED void +mlirTransformOptionsEnableExpensiveChecks(MlirTransformOptions transformOptions, + bool enable); + +/// Returns true if expensive checks are enabled in transform options. +MLIR_CAPI_EXPORTED bool mlirTransformOptionsGetExpensiveChecksEnabled( + MlirTransformOptions transformOptions); + +/// Enables or disables the enforcement of the top-level transform op being +/// single in transform options. +MLIR_CAPI_EXPORTED void mlirTransformOptionsEnforceSingleTopLevelTransformOp( + MlirTransformOptions transformOptions, bool enable); + +/// Returns true if the enforcement of the top-level transform op being single +/// is enabled in transform options. 
+MLIR_CAPI_EXPORTED bool mlirTransformOptionsGetEnforceSingleTopLevelTransformOp( + MlirTransformOptions transformOptions); + +/// Destroys a transform options object previously created by +/// mlirTransformOptionsCreate. +MLIR_CAPI_EXPORTED void +mlirTransformOptionsDestroy(MlirTransformOptions transformOptions); + +//----------------------------------------------------------------------------// +// Transform interpreter. +//----------------------------------------------------------------------------// + +/// Applies the transformation script starting at the given transform root +/// operation to the given payload operation. The module containing the +/// transform root as well as the transform options should be provided. The +/// transform operation must implement TransformOpInterface and the module must +/// be a ModuleOp. Returns the status of the application. +MLIR_CAPI_EXPORTED MlirLogicalResult mlirTransformApplyNamedSequence( + MlirOperation payload, MlirOperation transformRoot, + MlirOperation transformModule, MlirTransformOptions transformOptions); + +#ifdef __cplusplus +} +#endif diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 66cf20e1c..52f632125 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -23,6 +23,7 @@ #include #include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "llvm/ADT/Twine.h" @@ -569,6 +570,41 @@ class mlir_value_subclass : public pure_subclass { }; } // namespace adaptors + +/// RAII scope intercepting all diagnostics into a string. The message must be +/// checked before this goes out of scope. 
+class CollectDiagnosticsToStringScope { +public: + explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) { + handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage, + /*deleteUserData=*/nullptr); + } + ~CollectDiagnosticsToStringScope() { + assert(errorMessage.empty() && "unchecked error message"); + mlirContextDetachDiagnosticHandler(context, handlerID); + } + + [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); } + +private: + static MlirLogicalResult handler(MlirDiagnostic diag, void *data) { + auto printer = +[](MlirStringRef message, void *data) { + *static_cast(data) += + llvm::StringRef(message.data, message.length); + }; + MlirLocation loc = mlirDiagnosticGetLocation(diag); + *static_cast(data) += "at "; + mlirLocationPrint(loc, printer, data); + *static_cast(data) += ": "; + mlirDiagnosticPrint(diag, printer, data); + return mlirLogicalResultSuccess(); + } + + MlirContext context; + MlirDiagnosticHandlerID handlerID; + std::string errorMessage = ""; +}; + } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp index 780f5eacf..843707751 100644 --- a/mlir/lib/Bindings/Python/DialectLLVM.cpp +++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp @@ -6,7 +6,6 @@ // //===----------------------------------------------------------------------===// -#include "mlir-c/Diagnostics.h" #include "mlir-c/Dialect/LLVM.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" @@ -19,36 +18,6 @@ using namespace mlir; using namespace mlir::python; using namespace mlir::python::adaptors; -/// RAII scope intercepting all diagnostics into a string. The message must be -/// checked before this goes out of scope. 
-class CollectDiagnosticsToStringScope { -public: - explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) { - handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage, - /*deleteUserData=*/nullptr); - } - ~CollectDiagnosticsToStringScope() { - assert(errorMessage.empty() && "unchecked error message"); - mlirContextDetachDiagnosticHandler(context, handlerID); - } - - [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); } - -private: - static MlirLogicalResult handler(MlirDiagnostic diag, void *data) { - auto printer = +[](MlirStringRef message, void *data) { - *static_cast(data) += - StringRef(message.data, message.length); - }; - mlirDiagnosticPrint(diag, printer, data); - return mlirLogicalResultSuccess(); - } - - MlirContext context; - MlirDiagnosticHandlerID handlerID; - std::string errorMessage = ""; -}; - void populateDialectLLVMSubmodule(const pybind11::module &m) { auto llvmStructType = mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 8a7951dc2..734f2f7f3 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -678,6 +678,10 @@ void PyMlirContext::clearOperationsInside(PyOperationBase &op) { mlirOperationWalk(op.getOperation(), invalidatingCallback, static_cast(&data), MlirWalkPreOrder); } +void PyMlirContext::clearOperationsInside(MlirOperation op) { + PyOperationRef opRef = PyOperation::forOperation(getRef(), op); + clearOperationsInside(opRef->getOperation()); +} size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } @@ -2556,6 +2560,9 @@ void mlir::python::populateIRCore(py::module &m) { .def("_get_live_operation_objects", &PyMlirContext::getLiveOperationObjects) .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) + .def("_clear_live_operations_inside", + py::overload_cast( + &PyMlirContext::clearOperationsInside)) 
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 48f39c939..9acfdde25 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -223,6 +223,7 @@ class PyMlirContext { /// Clears all operations nested inside the given op using /// `clearOperation(MlirOperation)`. void clearOperationsInside(PyOperationBase &op); + void clearOperationsInside(MlirOperation op); /// Gets the count of live modules associated with this context. /// Used for testing. diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp new file mode 100644 index 000000000..6517f8c39 --- /dev/null +++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp @@ -0,0 +1,90 @@ +//===- TransformInterpreter.cpp -------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Pybind classes for the transform dialect interpreter. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Transform/Interpreter.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" + +#include +#include + +namespace py = pybind11; + +namespace { +struct PyMlirTransformOptions { + PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); }; + PyMlirTransformOptions(PyMlirTransformOptions &&other) { + options = other.options; + other.options.ptr = nullptr; + } + PyMlirTransformOptions(const PyMlirTransformOptions &) = delete; + + ~PyMlirTransformOptions() { mlirTransformOptionsDestroy(options); } + + MlirTransformOptions options; +}; +} // namespace + +static void populateTransformInterpreterSubmodule(py::module &m) { + py::class_(m, "TransformOptions", py::module_local()) + .def(py::init()) + .def_property( + "expensive_checks", + [](const PyMlirTransformOptions &self) { + return mlirTransformOptionsGetExpensiveChecksEnabled(self.options); + }, + [](PyMlirTransformOptions &self, bool value) { + mlirTransformOptionsEnableExpensiveChecks(self.options, value); + }) + .def_property( + "enforce_single_top_level_transform_op", + [](const PyMlirTransformOptions &self) { + return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp( + self.options); + }, + [](PyMlirTransformOptions &self, bool value) { + mlirTransformOptionsEnforceSingleTopLevelTransformOp(self.options, + value); + }); + + m.def( + "apply_named_sequence", + [](MlirOperation payloadRoot, MlirOperation transformRoot, + MlirOperation transformModule, const PyMlirTransformOptions &options) { + mlir::python::CollectDiagnosticsToStringScope scope( + mlirOperationGetContext(transformRoot)); + + // Calling back into Python to invalidate everything under the payload + // root. This is awkward, but we don't have access to PyMlirContext + // object here otherwise. 
+ py::object obj = py::cast(payloadRoot); + obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot); + + MlirLogicalResult result = mlirTransformApplyNamedSequence( + payloadRoot, transformRoot, transformModule, options.options); + if (mlirLogicalResultIsSuccess(result)) + return; + + throw py::value_error( + "Failed to apply named transform sequence.\nDiagnostic message " + + scope.takeMessage()); + }, + py::arg("payload_root"), py::arg("transform_root"), + py::arg("transform_module"), + py::arg("transform_options") = PyMlirTransformOptions()); +} + +PYBIND11_MODULE(_mlirTransformInterpreter, m) { + m.doc() = "MLIR Transform dialect interpreter functionality."; + populateTransformInterpreterSubmodule(m); +} diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index b2952da17..58b873904 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -198,6 +198,15 @@ add_mlir_upstream_c_api_library(MLIRCAPITransformDialect MLIRTransformDialect ) +add_mlir_upstream_c_api_library(MLIRCAPITransformDialectTransforms + TransformInterpreter.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRTransformDialectTransforms +) + add_mlir_upstream_c_api_library(MLIRCAPIQuant Quant.cpp diff --git a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp new file mode 100644 index 000000000..6a2cfb235 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp @@ -0,0 +1,74 @@ +//===- TransformTransforms.cpp - C Interface for Transform dialect --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// C interface to transforms for the transform dialect. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/Transform/Interpreter.h" +#include "mlir-c/Support.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" +#include "mlir/CAPI/Wrap.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" + +using namespace mlir; + +DEFINE_C_API_PTR_METHODS(MlirTransformOptions, transform::TransformOptions) + +extern "C" { + +MlirTransformOptions mlirTransformOptionsCreate() { + return wrap(new transform::TransformOptions); +} + +void mlirTransformOptionsEnableExpensiveChecks( + MlirTransformOptions transformOptions, bool enable) { + unwrap(transformOptions)->enableExpensiveChecks(enable); +} + +bool mlirTransformOptionsGetExpensiveChecksEnabled( + MlirTransformOptions transformOptions) { + return unwrap(transformOptions)->getExpensiveChecksEnabled(); +} + +void mlirTransformOptionsEnforceSingleTopLevelTransformOp( + MlirTransformOptions transformOptions, bool enable) { + unwrap(transformOptions)->enableEnforceSingleToplevelTransformOp(enable); +} + +bool mlirTransformOptionsGetEnforceSingleTopLevelTransformOp( + MlirTransformOptions transformOptions) { + return unwrap(transformOptions)->getEnforceSingleToplevelTransformOp(); +} + +void mlirTransformOptionsDestroy(MlirTransformOptions transformOptions) { + delete unwrap(transformOptions); +} + +MlirLogicalResult mlirTransformApplyNamedSequence( + MlirOperation payload, MlirOperation transformRoot, + MlirOperation transformModule, MlirTransformOptions transformOptions) { + Operation *transformRootOp = unwrap(transformRoot); + Operation *transformModuleOp = unwrap(transformModule); + if (!isa(transformRootOp)) { + transformRootOp->emitError() + << "must implement TransformOpInterface to be used as transform root"; + return mlirLogicalResultFailure(); + } + if (!isa(transformModuleOp)) { + transformModuleOp->emitError() + << "must be a 
" << ModuleOp::getOperationName(); + return mlirLogicalResultFailure(); + } + return wrap(transform::applyTransformNamedSequence( + unwrap(payload), unwrap(transformRoot), + cast(unwrap(transformModule)), *unwrap(transformOptions))); +} +} diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index ed167afeb..563d035f1 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -181,6 +181,13 @@ declare_mlir_python_sources( SOURCES dialects/transform/extras/__init__.py) +declare_mlir_python_sources( + MLIRPythonSources.Dialects.transform.interpreter + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + SOURCES + dialects/transform/interpreter/__init__.py) + declare_mlir_dialect_extension_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" @@ -609,6 +616,18 @@ declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses MLIRCAPISparseTensor ) +declare_mlir_python_extension(MLIRPythonExtension.TransformInterpreter + MODULE_NAME _mlirTransformInterpreter + ADD_TO_PARENT MLIRPythonSources.Dialects.transform + ROOT_DIR "${PYTHON_SOURCE_DIR}" + SOURCES + TransformInterpreter.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPITransformDialectTransforms +) + # TODO: Figure out how to put this in the test tree. # This should not be included in the main Python extension. However, # putting it into MLIRPythonTestSources along with the dialect declaration diff --git a/mlir/python/mlir/dialects/transform/interpreter/__init__.py b/mlir/python/mlir/dialects/transform/interpreter/__init__.py new file mode 100644 index 000000000..6145b9922 --- /dev/null +++ b/mlir/python/mlir/dialects/transform/interpreter/__init__.py @@ -0,0 +1,33 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ....ir import Operation +from ...._mlir_libs import _mlirTransformInterpreter as _cextTransformInterpreter + + +TransformOptions = _cextTransformInterpreter.TransformOptions + + +def _unpack_operation(op): + if isinstance(op, Operation): + return op + return op.operation + + +def apply_named_sequence( + payload_root, transform_root, transform_module, transform_options=None +): + """Applies the transformation script starting at the given transform root + operation to the given payload operation. The module containing the + transform root as well as the transform options should be provided. + The transform operation must implement TransformOpInterface and the module + must be a ModuleOp.""" + + args = tuple( + map(_unpack_operation, (payload_root, transform_root, transform_module)) + ) + if transform_options is None: + _cextTransformInterpreter.apply_named_sequence(*args) + else: + _cextTransformInterpreter(*args, transform_options) From 108c94be387c9789be72a71ab97d349e81943af9 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev <185856+superbobry@users.noreply.github.com> Date: Wed, 21 Feb 2024 11:59:23 +0000 Subject: [PATCH 680/915] [MLIR][Python] Use ir.Value directly instead of _SubClassValueT (#82341) _SubClassValueT is only useful when it is has >1 usage in a signature. This was not true for the signatures produced by tblgen. For example def call(result, callee, operands_, *, loc=None, ip=None) -> _SubClassValueT: ... here a type checker does not have enough information to infer a type argument for _SubClassValueT, and thus effectively treats it as Any. 
--- mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi | 2 +- mlir/python/mlir/dialects/_ods_common.py | 7 ------- mlir/python/mlir/dialects/arith.py | 3 +-- 3 files changed, 2 insertions(+), 10 deletions(-) diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi index 3ed1872f1..93b978c75 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi @@ -10,4 +10,4 @@ class _Globals: def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ... def register_dialect(dialect_class: type) -> object: ... -def register_operation(dialect_class: type) -> object: ... +def register_operation(dialect_class: type, *, replace: bool = ...) -> object: ... diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 3af3b5ce7..1e7e8244e 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -8,7 +8,6 @@ Sequence as _Sequence, Tuple as _Tuple, Type as _Type, - TypeVar as _TypeVar, Union as _Union, ) @@ -143,12 +142,6 @@ def get_op_result_or_op_results( else op ) - -# This is the standard way to indicate subclass/inheritance relationship -# see the typing.Type doc string. 
-_U = _TypeVar("_U", bound=_cext.ir.Value) -SubClassValueT = _Type[_U] - ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value ResultValueT = _Union[ResultValueTypeTuple] VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]] diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py index 663a53660..61c691739 100644 --- a/mlir/python/mlir/dialects/arith.py +++ b/mlir/python/mlir/dialects/arith.py @@ -12,7 +12,6 @@ get_default_loc_context as _get_default_loc_context, _cext as _ods_cext, get_op_result_or_op_results as _get_op_result_or_op_results, - SubClassValueT as _SubClassValueT, ) from typing import Any, List, Union @@ -81,5 +80,5 @@ def literal_value(self) -> Union[int, float]: def constant( result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None -) -> _SubClassValueT: +) -> Value: return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip)) From cd9f03cc1dbdab2082ef5a42df476cb1e0b2b182 Mon Sep 17 00:00:00 2001 From: Peiming Liu <36770114+PeimingLiu@users.noreply.github.com> Date: Mon, 26 Feb 2024 18:08:28 -0600 Subject: [PATCH 681/915] [mlir][sparse] Introduce batch level format. 
(#83082) --- mlir/include/mlir-c/Dialect/SparseTensor.h | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index 898d2f127..52ca7ba8a 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -29,10 +29,11 @@ typedef uint64_t MlirSparseTensorLevelType; enum MlirSparseTensorLevelFormat { MLIR_SPARSE_TENSOR_LEVEL_DENSE = 0x000000010000, - MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000020000, - MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000040000, - MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000080000, - MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000100000, + MLIR_SPARSE_TENSOR_LEVEL_BATCH = 0x000000020000, + MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED = 0x000000040000, + MLIR_SPARSE_TENSOR_LEVEL_SINGLETON = 0x000000080000, + MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED = 0x000000100000, + MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M = 0x000000200000, }; enum MlirSparseTensorLevelPropertyNondefault { From e0ad519c1902764032d7ee28d209f8f1d026a6a9 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Sat, 2 Mar 2024 19:10:50 -0800 Subject: [PATCH 682/915] Split the llvm::ThreadPool into an abstract base class and an implementation (#82094) This decouples the public API used to enqueue tasks and wait for completion from the actual implementation, and opens up the possibility for clients to set their own thread pool implementation for the pool. 
https://discourse.llvm.org/t/construct-threadpool-from-vector-of-existing-threads/76883 --- mlir/include/mlir/CAPI/Support.h | 4 ++-- mlir/lib/CAPI/IR/IR.cpp | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/CAPI/Support.h b/mlir/include/mlir/CAPI/Support.h index 82aa05185..622745256 100644 --- a/mlir/include/mlir/CAPI/Support.h +++ b/mlir/include/mlir/CAPI/Support.h @@ -22,7 +22,7 @@ #include "llvm/ADT/StringRef.h" namespace llvm { -class ThreadPool; +class ThreadPoolInterface; } // namespace llvm /// Converts a StringRef into its MLIR C API equivalent. @@ -45,7 +45,7 @@ inline mlir::LogicalResult unwrap(MlirLogicalResult res) { return mlir::success(mlirLogicalResultIsSuccess(res)); } -DEFINE_C_API_PTR_METHODS(MlirLlvmThreadPool, llvm::ThreadPool) +DEFINE_C_API_PTR_METHODS(MlirLlvmThreadPool, llvm::ThreadPoolInterface) DEFINE_C_API_METHODS(MlirTypeID, mlir::TypeID) DEFINE_C_API_PTR_METHODS(MlirTypeIDAllocator, mlir::TypeIDAllocator) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index a97cfe5b0..cdb64f4ec 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -28,6 +28,7 @@ #include "mlir/IR/Visitors.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Parser/Parser.h" +#include "llvm/Support/ThreadPool.h" #include #include From d058698a29f725864a3f77de5c1a65c9151fbc26 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 5 Mar 2024 09:43:31 +0000 Subject: [PATCH 683/915] [mlir] Apply ClangTidy findings move constructors should be marked noexcept --- mlir/lib/Bindings/Python/TransformInterpreter.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp index 6517f8c39..3530f295e 100644 --- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp +++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp @@ -23,7 +23,7 @@ namespace py = pybind11; namespace { struct 
PyMlirTransformOptions { PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); }; - PyMlirTransformOptions(PyMlirTransformOptions &&other) { + PyMlirTransformOptions(PyMlirTransformOptions &&other) noexcept { options = other.options; other.options.ptr = nullptr; } From d5efab4525fcd109a1dc67c8d95261535bd935d1 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Tue, 5 Mar 2024 16:09:59 +0100 Subject: [PATCH 684/915] [mlir][py] better support for arith.constant construction (#83259) Arithmetic constants for vector types can be constructed from objects implementing Python buffer protocol such as `array.array`. Note that until Python 3.12, there is no typing support for buffer protocol implementers, so the annotations use array explicitly. --- mlir/python/mlir/dialects/arith.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py index 61c691739..83a50c7ef 100644 --- a/mlir/python/mlir/dialects/arith.py +++ b/mlir/python/mlir/dialects/arith.py @@ -5,6 +5,8 @@ from ._arith_ops_gen import * from ._arith_ops_gen import _Dialect from ._arith_enum_gen import * +from array import array as _array +from typing import overload try: from ..ir import * @@ -43,13 +45,30 @@ def _is_float_type(type: Type): class ConstantOp(ConstantOp): """Specialization for the constant op class.""" + @overload + def __init__(self, value: Attribute, *, loc=None, ip=None): + ... + + @overload def __init__( - self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None + self, result: Type, value: Union[int, float, _array], *, loc=None, ip=None ): + ... 
+ + def __init__(self, result, value, *, loc=None, ip=None): + if value is None: + assert isinstance(result, Attribute) + super().__init__(result, loc=loc, ip=ip) + return + if isinstance(value, int): super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip) elif isinstance(value, float): super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip) + elif isinstance(value, _array) and value.typecode in ["i", "l"]: + super().__init__(DenseIntElementsAttr.get(value, type=result)) + elif isinstance(value, _array) and value.typecode in ["f", "d"]: + super().__init__(DenseFPElementsAttr.get(value, type=result)) else: super().__init__(value, loc=loc, ip=ip) @@ -79,6 +98,6 @@ def literal_value(self) -> Union[int, float]: def constant( - result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None + result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None ) -> Value: return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip)) From 50db0a251e71e7ed1d49a5546da4fc40980a9617 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Tue, 5 Mar 2024 18:00:46 -0800 Subject: [PATCH 685/915] Rename llvm::ThreadPool -> llvm::DefaultThreadPool (NFC) (#83702) The base class llvm::ThreadPoolInterface will be renamed llvm::ThreadPool in a subsequent commit. This is a breaking change: clients who use to create a ThreadPool must now create a DefaultThreadPool instead. --- mlir/lib/CAPI/IR/Support.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/IR/Support.cpp b/mlir/lib/CAPI/IR/Support.cpp index 81c9fc771..3311131fc 100644 --- a/mlir/lib/CAPI/IR/Support.cpp +++ b/mlir/lib/CAPI/IR/Support.cpp @@ -25,7 +25,7 @@ bool mlirStringRefEqual(MlirStringRef string, MlirStringRef other) { // LLVM ThreadPool API. 
//===----------------------------------------------------------------------===// MlirLlvmThreadPool mlirLlvmThreadPoolCreate() { - return wrap(new llvm::ThreadPool()); + return wrap(new llvm::DefaultThreadPool()); } void mlirLlvmThreadPoolDestroy(MlirLlvmThreadPool threadPool) { From 307c5d801a4d9417a33c53c25b680ee4b3732eff Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Tue, 5 Mar 2024 18:57:45 -0800 Subject: [PATCH 686/915] Revert "[mlir][py] better support for arith.constant construction" (#84103) Reverts llvm/llvm-project#83259 This broke an integration test on Windows --- mlir/python/mlir/dialects/arith.py | 23 ++--------------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py index 83a50c7ef..61c691739 100644 --- a/mlir/python/mlir/dialects/arith.py +++ b/mlir/python/mlir/dialects/arith.py @@ -5,8 +5,6 @@ from ._arith_ops_gen import * from ._arith_ops_gen import _Dialect from ._arith_enum_gen import * -from array import array as _array -from typing import overload try: from ..ir import * @@ -45,30 +43,13 @@ def _is_float_type(type: Type): class ConstantOp(ConstantOp): """Specialization for the constant op class.""" - @overload - def __init__(self, value: Attribute, *, loc=None, ip=None): - ... - - @overload def __init__( - self, result: Type, value: Union[int, float, _array], *, loc=None, ip=None + self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None ): - ... 
- - def __init__(self, result, value, *, loc=None, ip=None): - if value is None: - assert isinstance(result, Attribute) - super().__init__(result, loc=loc, ip=ip) - return - if isinstance(value, int): super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip) elif isinstance(value, float): super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip) - elif isinstance(value, _array) and value.typecode in ["i", "l"]: - super().__init__(DenseIntElementsAttr.get(value, type=result)) - elif isinstance(value, _array) and value.typecode in ["f", "d"]: - super().__init__(DenseFPElementsAttr.get(value, type=result)) else: super().__init__(value, loc=loc, ip=ip) @@ -98,6 +79,6 @@ def literal_value(self) -> Union[int, float]: def constant( - result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None + result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None ) -> Value: return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip)) From e6d329565bf9a0c617c970ba0b647d7961202202 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 6 Mar 2024 09:31:44 +0000 Subject: [PATCH 687/915] [mlir] Remove noexcept again from move constructors. LLVM does not have the corresponding ClangTidy check enabled, so we should not be fixing such findings. 
--- mlir/lib/Bindings/Python/TransformInterpreter.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp index 3530f295e..6517f8c39 100644 --- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp +++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp @@ -23,7 +23,7 @@ namespace py = pybind11; namespace { struct PyMlirTransformOptions { PyMlirTransformOptions() { options = mlirTransformOptionsCreate(); }; - PyMlirTransformOptions(PyMlirTransformOptions &&other) noexcept { + PyMlirTransformOptions(PyMlirTransformOptions &&other) { options = other.options; other.options.ptr = nullptr; } From 553d30457b8149d2e0c932df28ec35e41e73d2cd Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Thu, 7 Mar 2024 17:14:08 +0100 Subject: [PATCH 688/915] Reapply "[mlir][py] better support for arith.constant construction" (#84142) Arithmetic constants for vector types can be constructed from objects implementing Python buffer protocol such as `array.array`. Note that until Python 3.12, there is no typing support for buffer protocol implementers, so the annotations use array explicitly. Reverts llvm/llvm-project#84103 --- mlir/python/mlir/dialects/arith.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py index 61c691739..92da5df9b 100644 --- a/mlir/python/mlir/dialects/arith.py +++ b/mlir/python/mlir/dialects/arith.py @@ -5,6 +5,8 @@ from ._arith_ops_gen import * from ._arith_ops_gen import _Dialect from ._arith_enum_gen import * +from array import array as _array +from typing import overload try: from ..ir import * @@ -43,13 +45,37 @@ def _is_float_type(type: Type): class ConstantOp(ConstantOp): """Specialization for the constant op class.""" + @overload + def __init__(self, value: Attribute, *, loc=None, ip=None): + ... 
+ + @overload def __init__( - self, result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None + self, result: Type, value: Union[int, float, _array], *, loc=None, ip=None ): + ... + + def __init__(self, result, value, *, loc=None, ip=None): + if value is None: + assert isinstance(result, Attribute) + super().__init__(result, loc=loc, ip=ip) + return + if isinstance(value, int): super().__init__(IntegerAttr.get(result, value), loc=loc, ip=ip) elif isinstance(value, float): super().__init__(FloatAttr.get(result, value), loc=loc, ip=ip) + elif isinstance(value, _array): + if 8 * value.itemsize != result.element_type.width: + raise ValueError( + f"Mismatching array element ({8 * value.itemsize}) and type ({result.element_type.width}) width." + ) + if value.typecode in ["i", "l", "q"]: + super().__init__(DenseIntElementsAttr.get(value, type=result)) + elif value.typecode in ["f", "d"]: + super().__init__(DenseFPElementsAttr.get(value, type=result)) + else: + raise ValueError(f'Unsupported typecode: "{value.typecode}".') else: super().__init__(value, loc=loc, ip=ip) @@ -79,6 +105,6 @@ def literal_value(self) -> Union[int, float]: def constant( - result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None + result: Type, value: Union[int, float, Attribute, _array], *, loc=None, ip=None ) -> Value: return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip)) From 1c98a4b836db5a5c10bb01f158aac66312bf5250 Mon Sep 17 00:00:00 2001 From: Edgar Date: Thu, 7 Mar 2024 18:10:46 +0100 Subject: [PATCH 689/915] [MLIR] Add llvm (debug) attributes to CAPI (#83992) This PR adds the following to the mlir c api: - The distinct mlir builtin attribute. 
- LLVM attributes (mostly debug related ones) --- mlir/include/mlir-c/BuiltinAttributes.h | 4 + mlir/include/mlir-c/Dialect/LLVM.h | 231 ++++++++++++++++++++++++ mlir/lib/CAPI/Dialect/LLVM.cpp | 207 +++++++++++++++++++++ mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 4 + 4 files changed, 446 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 01d1b6008..231eb83b5 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -266,6 +266,10 @@ mlirSymbolRefAttrGetNestedReference(MlirAttribute attr, intptr_t pos); /// Returns the typeID of an SymbolRef attribute. MLIR_CAPI_EXPORTED MlirTypeID mlirSymbolRefAttrGetTypeID(void); +/// Creates a DisctinctAttr with the referenced attribute. +MLIR_CAPI_EXPORTED MlirAttribute +mlirDisctinctAttrCreate(MlirAttribute referencedAttr); + //===----------------------------------------------------------------------===// // Flat SymbolRef attribute. //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index ac216b01f..d823afb65 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -11,6 +11,7 @@ #define MLIR_C_DIALECT_LLVM_H #include "mlir-c/IR.h" +#include "mlir-c/Support.h" #ifdef __cplusplus extern "C" { @@ -98,6 +99,236 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirLLVMStructTypeSetBody(MlirType structType, intptr_t nFieldTypes, MlirType const *fieldTypes, bool isPacked); +enum MlirLLVMCConv { + MlirLLVMCConvC = 0, + MlirLLVMCConvFast = 8, + MlirLLVMCConvCold = 9, + MlirLLVMCConvGHC = 10, + MlirLLVMCConvHiPE = 11, + MlirLLVMCConvAnyReg = 13, + MlirLLVMCConvPreserveMost = 14, + MlirLLVMCConvPreserveAll = 15, + MlirLLVMCConvSwift = 16, + MlirLLVMCConvCXX_FAST_TLS = 17, + MlirLLVMCConvTail = 18, + MlirLLVMCConvCFGuard_Check = 19, + MlirLLVMCConvSwiftTail = 20, + MlirLLVMCConvX86_StdCall = 64, 
+ MlirLLVMCConvX86_FastCall = 65, + MlirLLVMCConvARM_APCS = 66, + MlirLLVMCConvARM_AAPCS = 67, + MlirLLVMCConvARM_AAPCS_VFP = 68, + MlirLLVMCConvMSP430_INTR = 69, + MlirLLVMCConvX86_ThisCall = 70, + MlirLLVMCConvPTX_Kernel = 71, + MlirLLVMCConvPTX_Device = 72, + MlirLLVMCConvSPIR_FUNC = 75, + MlirLLVMCConvSPIR_KERNEL = 76, + MlirLLVMCConvIntel_OCL_BI = 77, + MlirLLVMCConvX86_64_SysV = 78, + MlirLLVMCConvWin64 = 79, + MlirLLVMCConvX86_VectorCall = 80, + MlirLLVMCConvDUMMY_HHVM = 81, + MlirLLVMCConvDUMMY_HHVM_C = 82, + MlirLLVMCConvX86_INTR = 83, + MlirLLVMCConvAVR_INTR = 84, + MlirLLVMCConvAVR_BUILTIN = 86, + MlirLLVMCConvAMDGPU_VS = 87, + MlirLLVMCConvAMDGPU_GS = 88, + MlirLLVMCConvAMDGPU_CS = 90, + MlirLLVMCConvAMDGPU_KERNEL = 91, + MlirLLVMCConvX86_RegCall = 92, + MlirLLVMCConvAMDGPU_HS = 93, + MlirLLVMCConvMSP430_BUILTIN = 94, + MlirLLVMCConvAMDGPU_LS = 95, + MlirLLVMCConvAMDGPU_ES = 96, + MlirLLVMCConvAArch64_VectorCall = 97, + MlirLLVMCConvAArch64_SVE_VectorCall = 98, + MlirLLVMCConvWASM_EmscriptenInvoke = 99, + MlirLLVMCConvAMDGPU_Gfx = 100, + MlirLLVMCConvM68k_INTR = 101, +}; +typedef enum MlirLLVMCConv MlirLLVMCConv; + +/// Creates a LLVM CConv attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMCConvAttrGet(MlirContext ctx, + MlirLLVMCConv cconv); + +enum MlirLLVMComdat { + MlirLLVMComdatAny = 0, + MlirLLVMComdatExactMatch = 1, + MlirLLVMComdatLargest = 2, + MlirLLVMComdatNoDeduplicate = 3, + MlirLLVMComdatSameSize = 4, +}; +typedef enum MlirLLVMComdat MlirLLVMComdat; + +/// Creates a LLVM Comdat attribute. 
+MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMComdatAttrGet(MlirContext ctx, + MlirLLVMComdat comdat); + +enum MlirLLVMLinkage { + MlirLLVMLinkagePrivate = 0, + MlirLLVMLinkageInternal = 1, + MlirLLVMLinkageAvailableExternally = 2, + MlirLLVMLinkageLinkonce = 3, + MlirLLVMLinkageWeak = 4, + MlirLLVMLinkageCommon = 5, + MlirLLVMLinkageAppending = 6, + MlirLLVMLinkageExternWeak = 7, + MlirLLVMLinkageLinkonceODR = 8, + MlirLLVMLinkageWeakODR = 9, + MlirLLVMLinkageExternal = 10, +}; +typedef enum MlirLLVMLinkage MlirLLVMLinkage; + +/// Creates a LLVM Linkage attribute. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMLinkageAttrGet(MlirContext ctx, MlirLLVMLinkage linkage); + +/// Creates a LLVM DINullType attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDINullTypeAttrGet(MlirContext ctx); + +/// Creates a LLVM DIExpressionElem attribute. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMDIExpressionElemAttrGet(MlirContext ctx, unsigned int opcode, + intptr_t nArguments, uint64_t const *arguments); + +/// Creates a LLVM DIExpression attribute. 
+MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIExpressionAttrGet( + MlirContext ctx, intptr_t nOperations, MlirAttribute const *operations); + +enum MlirLLVMTypeEncoding { + MlirLLVMTypeEncodingAddress = 0x1, + MlirLLVMTypeEncodingBoolean = 0x2, + MlirLLVMTypeEncodingComplexFloat = 0x31, + MlirLLVMTypeEncodingFloatT = 0x4, + MlirLLVMTypeEncodingSigned = 0x5, + MlirLLVMTypeEncodingSignedChar = 0x6, + MlirLLVMTypeEncodingUnsigned = 0x7, + MlirLLVMTypeEncodingUnsignedChar = 0x08, + MlirLLVMTypeEncodingImaginaryFloat = 0x09, + MlirLLVMTypeEncodingPackedDecimal = 0x0a, + MlirLLVMTypeEncodingNumericString = 0x0b, + MlirLLVMTypeEncodingEdited = 0x0c, + MlirLLVMTypeEncodingSignedFixed = 0x0d, + MlirLLVMTypeEncodingUnsignedFixed = 0x0e, + MlirLLVMTypeEncodingDecimalFloat = 0x0f, + MlirLLVMTypeEncodingUTF = 0x10, + MlirLLVMTypeEncodingUCS = 0x11, + MlirLLVMTypeEncodingASCII = 0x12, + MlirLLVMTypeEncodingLoUser = 0x80, + MlirLLVMTypeEncodingHiUser = 0xff, +}; +typedef enum MlirLLVMTypeEncoding MlirLLVMTypeEncoding; + +/// Creates a LLVM DIBasicType attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIBasicTypeAttrGet( + MlirContext ctx, unsigned int tag, MlirAttribute name, uint64_t sizeInBits, + MlirLLVMTypeEncoding encoding); + +/// Creates a LLVM DICompositeType attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDICompositeTypeAttrGet( + MlirContext ctx, unsigned int tag, MlirAttribute name, MlirAttribute file, + uint32_t line, MlirAttribute scope, MlirAttribute baseType, int64_t flags, + uint64_t sizeInBits, uint64_t alignInBits, intptr_t nElements, + MlirAttribute const *elements); + +/// Creates a LLVM DIDerivedType attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIDerivedTypeAttrGet( + MlirContext ctx, unsigned int tag, MlirAttribute name, + MlirAttribute baseType, uint64_t sizeInBits, uint32_t alignInBits, + uint64_t offsetInBits); + +/// Gets the base type from a LLVM DIDerivedType attribute. 
+MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMDIDerivedTypeAttrGetBaseType(MlirAttribute diDerivedType); + +/// Creates a LLVM DIFileAttr attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIFileAttrGet(MlirContext ctx, + MlirAttribute name, + MlirAttribute directory); + +enum MlirLLVMDIEmissionKind { + MlirLLVMDIEmissionKindNone = 0, + MlirLLVMDIEmissionKindFull = 1, + MlirLLVMDIEmissionKindLineTablesOnly = 2, + MlirLLVMDIEmissionKindDebugDirectivesOnly = 3, +}; +typedef enum MlirLLVMDIEmissionKind MlirLLVMDIEmissionKind; + +/// Creates a LLVM DICompileUnit attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDICompileUnitAttrGet( + MlirContext ctx, MlirAttribute id, unsigned int sourceLanguage, + MlirAttribute file, MlirAttribute producer, bool isOptimized, + MlirLLVMDIEmissionKind emissionKind); + +/// Creates a LLVM DIFlags attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIFlagsAttrGet(MlirContext ctx, + uint64_t value); + +/// Creates a LLVM DILexicalBlock attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDILexicalBlockAttrGet( + MlirContext ctx, MlirAttribute scope, MlirAttribute file, unsigned int line, + unsigned int column); + +/// Creates a LLVM DILexicalBlockFile attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDILexicalBlockFileAttrGet( + MlirContext ctx, MlirAttribute scope, MlirAttribute file, + unsigned int discriminator); + +/// Creates a LLVM DILocalVariableAttr attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDILocalVariableAttrGet( + MlirContext ctx, MlirAttribute scope, MlirAttribute name, + MlirAttribute diFile, unsigned int line, unsigned int arg, + unsigned int alignInBits, MlirAttribute diType); + +/// Creates a LLVM DISubprogramAttr attribute. 
+MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDISubprogramAttrGet( + MlirContext ctx, MlirAttribute id, MlirAttribute compileUnit, + MlirAttribute scope, MlirAttribute name, MlirAttribute linkageName, + MlirAttribute file, unsigned int line, unsigned int scopeLine, + uint64_t subprogramFlags, MlirAttribute type); + +/// Gets the scope from this DISubprogramAttr. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMDISubprogramAttrGetScope(MlirAttribute diSubprogram); + +/// Gets the line from this DISubprogramAttr. +MLIR_CAPI_EXPORTED unsigned int +mlirLLVMDISubprogramAttrGetLine(MlirAttribute diSubprogram); + +/// Gets the scope line from this DISubprogram. +MLIR_CAPI_EXPORTED unsigned int +mlirLLVMDISubprogramAttrGetScopeLine(MlirAttribute diSubprogram); + +/// Gets the compile unit from this DISubprogram. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMDISubprogramAttrGetCompileUnit(MlirAttribute diSubprogram); + +/// Gets the file from this DISubprogramAttr. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMDISubprogramAttrGetFile(MlirAttribute diSubprogram); + +/// Gets the type from this DISubprogramAttr. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMDISubprogramAttrGetType(MlirAttribute diSubprogram); + +/// Creates a LLVM DISubroutineTypeAttr attribute. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMDISubroutineTypeAttrGet(MlirContext ctx, unsigned int callingConvention, + intptr_t nTypes, MlirAttribute const *types); + +/// Creates a LLVM DIModuleAttr attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIModuleAttrGet( + MlirContext ctx, MlirAttribute file, MlirAttribute scope, + MlirAttribute name, MlirAttribute configMacros, MlirAttribute includePath, + MlirAttribute apinotes, unsigned int line, bool isDecl); + +/// Gets the scope of this DIModuleAttr. 
+MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMDIModuleAttrGetScope(MlirAttribute diModule); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index 642018a81..2d938ce5f 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -7,9 +7,16 @@ //===----------------------------------------------------------------------===// #include "mlir-c/Dialect/LLVM.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" #include "mlir/CAPI/Registration.h" +#include "mlir/CAPI/Wrap.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" +#include "llvm-c/Core.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" using namespace mlir; using namespace mlir::LLVM; @@ -110,3 +117,203 @@ MlirLogicalResult mlirLLVMStructTypeSetBody(MlirType structType, cast(unwrap(structType)) .setBody(unwrapList(nFieldTypes, fieldTypes, fields), isPacked)); } + +MlirAttribute mlirLLVMDIExpressionElemAttrGet(MlirContext ctx, + unsigned int opcode, + intptr_t nArguments, + uint64_t const *arguments) { + auto list = ArrayRef(arguments, nArguments); + return wrap(DIExpressionElemAttr::get(unwrap(ctx), opcode, list)); +} + +MlirAttribute mlirLLVMDIExpressionAttrGet(MlirContext ctx, intptr_t nOperations, + MlirAttribute const *operations) { + SmallVector attrStorage; + attrStorage.reserve(nOperations); + + return wrap(DIExpressionAttr::get( + unwrap(ctx), + llvm::map_to_vector( + unwrapList(nOperations, operations, attrStorage), + [](Attribute a) { return a.cast(); }))); +} + +MlirAttribute mlirLLVMDINullTypeAttrGet(MlirContext ctx) { + return wrap(DINullTypeAttr::get(unwrap(ctx))); +} + +MlirAttribute mlirLLVMDIBasicTypeAttrGet(MlirContext ctx, unsigned int tag, + MlirAttribute name, + uint64_t sizeInBits, + MlirLLVMTypeEncoding encoding) { + + return wrap(DIBasicTypeAttr::get( + unwrap(ctx), tag, cast(unwrap(name)), 
sizeInBits, encoding)); +} + +MlirAttribute mlirLLVMDICompositeTypeAttrGet( + MlirContext ctx, unsigned int tag, MlirAttribute name, MlirAttribute file, + uint32_t line, MlirAttribute scope, MlirAttribute baseType, int64_t flags, + uint64_t sizeInBits, uint64_t alignInBits, intptr_t nElements, + MlirAttribute const *elements) { + SmallVector elementsStorage; + elementsStorage.reserve(nElements); + + return wrap(DICompositeTypeAttr::get( + unwrap(ctx), tag, cast(unwrap(name)), + cast(unwrap(file)), line, cast(unwrap(scope)), + cast(unwrap(baseType)), DIFlags(flags), sizeInBits, + alignInBits, + llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage), + [](Attribute a) { return a.cast(); }))); +} + +MlirAttribute mlirLLVMDIDerivedTypeAttrGet(MlirContext ctx, unsigned int tag, + MlirAttribute name, + MlirAttribute baseType, + uint64_t sizeInBits, + uint32_t alignInBits, + uint64_t offsetInBits) { + return wrap(DIDerivedTypeAttr::get(unwrap(ctx), tag, + cast(unwrap(name)), + cast(unwrap(baseType)), + sizeInBits, alignInBits, offsetInBits)); +} + +MlirAttribute +mlirLLVMDIDerivedTypeAttrGetBaseType(MlirAttribute diDerivedType) { + return wrap(cast(unwrap(diDerivedType)).getBaseType()); +} + +MlirAttribute mlirLLVMCConvAttrGet(MlirContext ctx, MlirLLVMCConv cconv) { + return wrap(CConvAttr::get(unwrap(ctx), CConv(cconv))); +} + +MlirAttribute mlirLLVMComdatAttrGet(MlirContext ctx, MlirLLVMComdat comdat) { + return wrap(ComdatAttr::get(unwrap(ctx), comdat::Comdat(comdat))); +} + +MlirAttribute mlirLLVMLinkageAttrGet(MlirContext ctx, MlirLLVMLinkage linkage) { + return wrap(LinkageAttr::get(unwrap(ctx), linkage::Linkage(linkage))); +} + +MlirAttribute mlirLLVMDIFileAttrGet(MlirContext ctx, MlirAttribute name, + MlirAttribute directory) { + return wrap(DIFileAttr::get(unwrap(ctx), cast(unwrap(name)), + cast(unwrap(directory)))); +} + +MlirAttribute +mlirLLVMDICompileUnitAttrGet(MlirContext ctx, MlirAttribute id, + unsigned int sourceLanguage, MlirAttribute 
file, + MlirAttribute producer, bool isOptimized, + MlirLLVMDIEmissionKind emissionKind) { + return wrap(DICompileUnitAttr::get( + unwrap(ctx), cast(unwrap(id)), sourceLanguage, + cast(unwrap(file)), cast(unwrap(producer)), + isOptimized, DIEmissionKind(emissionKind))); +} + +MlirAttribute mlirLLVMDIFlagsAttrGet(MlirContext ctx, uint64_t value) { + return wrap(DIFlagsAttr::get(unwrap(ctx), DIFlags(value))); +} + +MlirAttribute mlirLLVMDILexicalBlockAttrGet(MlirContext ctx, + MlirAttribute scope, + MlirAttribute file, + unsigned int line, + unsigned int column) { + return wrap( + DILexicalBlockAttr::get(unwrap(ctx), cast(unwrap(scope)), + cast(unwrap(file)), line, column)); +} + +MlirAttribute mlirLLVMDILexicalBlockFileAttrGet(MlirContext ctx, + MlirAttribute scope, + MlirAttribute file, + unsigned int discriminator) { + return wrap(DILexicalBlockFileAttr::get( + unwrap(ctx), cast(unwrap(scope)), + cast(unwrap(file)), discriminator)); +} + +MlirAttribute +mlirLLVMDILocalVariableAttrGet(MlirContext ctx, MlirAttribute scope, + MlirAttribute name, MlirAttribute diFile, + unsigned int line, unsigned int arg, + unsigned int alignInBits, MlirAttribute diType) { + return wrap(DILocalVariableAttr::get( + unwrap(ctx), cast(unwrap(scope)), + cast(unwrap(name)), cast(unwrap(diFile)), line, + arg, alignInBits, cast(unwrap(diType)))); +} + +MlirAttribute mlirLLVMDISubroutineTypeAttrGet(MlirContext ctx, + unsigned int callingConvention, + intptr_t nTypes, + MlirAttribute const *types) { + SmallVector attrStorage; + attrStorage.reserve(nTypes); + + return wrap(DISubroutineTypeAttr::get( + unwrap(ctx), callingConvention, + llvm::map_to_vector(unwrapList(nTypes, types, attrStorage), + [](Attribute a) { return a.cast(); }))); +} + +MlirAttribute mlirLLVMDISubprogramAttrGet( + MlirContext ctx, MlirAttribute id, MlirAttribute compileUnit, + MlirAttribute scope, MlirAttribute name, MlirAttribute linkageName, + MlirAttribute file, unsigned int line, unsigned int scopeLine, + uint64_t 
subprogramFlags, MlirAttribute type) { + return wrap(DISubprogramAttr::get( + unwrap(ctx), cast(unwrap(id)), + cast(unwrap(compileUnit)), + cast(unwrap(scope)), cast(unwrap(name)), + cast(unwrap(linkageName)), cast(unwrap(file)), + line, scopeLine, DISubprogramFlags(subprogramFlags), + cast(unwrap(type)))); +} + +MlirAttribute mlirLLVMDISubprogramAttrGetScope(MlirAttribute diSubprogram) { + return wrap(cast(unwrap(diSubprogram)).getScope()); +} + +unsigned int mlirLLVMDISubprogramAttrGetLine(MlirAttribute diSubprogram) { + return cast(unwrap(diSubprogram)).getLine(); +} + +unsigned int mlirLLVMDISubprogramAttrGetScopeLine(MlirAttribute diSubprogram) { + return cast(unwrap(diSubprogram)).getScopeLine(); +} + +MlirAttribute +mlirLLVMDISubprogramAttrGetCompileUnit(MlirAttribute diSubprogram) { + return wrap(cast(unwrap(diSubprogram)).getCompileUnit()); +} + +MlirAttribute mlirLLVMDISubprogramAttrGetFile(MlirAttribute diSubprogram) { + return wrap(cast(unwrap(diSubprogram)).getFile()); +} + +MlirAttribute mlirLLVMDISubprogramAttrGetType(MlirAttribute diSubprogram) { + return wrap(cast(unwrap(diSubprogram)).getType()); +} + +MlirAttribute mlirLLVMDIModuleAttrGet(MlirContext ctx, MlirAttribute file, + MlirAttribute scope, MlirAttribute name, + MlirAttribute configMacros, + MlirAttribute includePath, + MlirAttribute apinotes, unsigned int line, + bool isDecl) { + return wrap(DIModuleAttr::get( + unwrap(ctx), cast(unwrap(file)), + cast(unwrap(scope)), cast(unwrap(name)), + cast(unwrap(configMacros)), + cast(unwrap(includePath)), cast(unwrap(apinotes)), + line, isDecl)); +} + +MlirAttribute mlirLLVMDIModuleAttrGetScope(MlirAttribute diModule) { + return wrap(cast(unwrap(diModule)).getScope()); +} diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index b3066ee0c..726af8846 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -289,6 +289,10 @@ MlirTypeID mlirSymbolRefAttrGetTypeID(void) 
{ return wrap(SymbolRefAttr::getTypeID()); } +MlirAttribute mlirDisctinctAttrCreate(MlirAttribute referencedAttr) { + return wrap(mlir::DistinctAttr::create(unwrap(referencedAttr))); +} + //===----------------------------------------------------------------------===// // Flat SymbolRef attribute. //===----------------------------------------------------------------------===// From a05f4fc6d3d9f7df535c84bfd185174cd5d6ff48 Mon Sep 17 00:00:00 2001 From: Billy Zhu Date: Fri, 15 Mar 2024 09:58:25 -0700 Subject: [PATCH 690/915] [MLIR][LLVM] Support Recursive DITypes (#80251) Following the discussion from [this thread](https://discourse.llvm.org/t/handling-cyclic-dependencies-in-debug-info/67526/11), this PR adds support for recursive DITypes. This PR adds: 1. DIRecursiveTypeAttrInterface: An interface that DITypeAttrs can implement to indicate that it supports recursion. See full description in code. 2. Importer & exporter support (The only DITypeAttr that implements the interface is DICompositeTypeAttr, so the exporter is only implemented for composites too. There will be two methods that each llvm DI type that supports mutation needs to implement since there's nothing general). --------- Co-authored-by: Tobias Gysi --- mlir/include/mlir-c/Dialect/LLVM.h | 8 ++++---- mlir/lib/CAPI/Dialect/LLVM.cpp | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index d823afb65..b3d7a788c 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -229,10 +229,10 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIBasicTypeAttrGet( /// Creates a LLVM DICompositeType attribute. 
MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDICompositeTypeAttrGet( - MlirContext ctx, unsigned int tag, MlirAttribute name, MlirAttribute file, - uint32_t line, MlirAttribute scope, MlirAttribute baseType, int64_t flags, - uint64_t sizeInBits, uint64_t alignInBits, intptr_t nElements, - MlirAttribute const *elements); + MlirContext ctx, unsigned int tag, MlirAttribute recId, MlirAttribute name, + MlirAttribute file, uint32_t line, MlirAttribute scope, + MlirAttribute baseType, int64_t flags, uint64_t sizeInBits, + uint64_t alignInBits, intptr_t nElements, MlirAttribute const *elements); /// Creates a LLVM DIDerivedType attribute. MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIDerivedTypeAttrGet( diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index 2d938ce5f..d0fd5ceec 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -152,18 +152,18 @@ MlirAttribute mlirLLVMDIBasicTypeAttrGet(MlirContext ctx, unsigned int tag, } MlirAttribute mlirLLVMDICompositeTypeAttrGet( - MlirContext ctx, unsigned int tag, MlirAttribute name, MlirAttribute file, - uint32_t line, MlirAttribute scope, MlirAttribute baseType, int64_t flags, - uint64_t sizeInBits, uint64_t alignInBits, intptr_t nElements, - MlirAttribute const *elements) { + MlirContext ctx, unsigned int tag, MlirAttribute recId, MlirAttribute name, + MlirAttribute file, uint32_t line, MlirAttribute scope, + MlirAttribute baseType, int64_t flags, uint64_t sizeInBits, + uint64_t alignInBits, intptr_t nElements, MlirAttribute const *elements) { SmallVector elementsStorage; elementsStorage.reserve(nElements); return wrap(DICompositeTypeAttr::get( - unwrap(ctx), tag, cast(unwrap(name)), - cast(unwrap(file)), line, cast(unwrap(scope)), - cast(unwrap(baseType)), DIFlags(flags), sizeInBits, - alignInBits, + unwrap(ctx), tag, cast(unwrap(recId)), + cast(unwrap(name)), cast(unwrap(file)), line, + cast(unwrap(scope)), cast(unwrap(baseType)), + DIFlags(flags), sizeInBits, alignInBits, 
llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage), [](Attribute a) { return a.cast(); }))); } From 8e8066cc42d65ee960c2ad8cdbab986777c752f1 Mon Sep 17 00:00:00 2001 From: Christian Ulmann Date: Wed, 20 Mar 2024 16:08:38 +0100 Subject: [PATCH 691/915] [MLIR][LLVM] Add extraData field to the DIDerivedType attribute (#85935) This commit extends the DIDerivedTypeAttr with the `extraData` field. For now, the type of it is limited to be a `DINodeAttr`, as extending the debug metadata handling to support arbitrary metadata nodes does not seem to be necessary so far. --- mlir/include/mlir-c/Dialect/LLVM.h | 2 +- mlir/lib/CAPI/Dialect/LLVM.cpp | 19 +++++++++---------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index b3d7a788c..4f1d646f5 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -238,7 +238,7 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDICompositeTypeAttrGet( MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIDerivedTypeAttrGet( MlirContext ctx, unsigned int tag, MlirAttribute name, MlirAttribute baseType, uint64_t sizeInBits, uint32_t alignInBits, - uint64_t offsetInBits); + uint64_t offsetInBits, MlirAttribute extraData); /// Gets the base type from a LLVM DIDerivedType attribute. 
MLIR_CAPI_EXPORTED MlirAttribute diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index d0fd5ceec..71f2b73dd 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -168,16 +168,15 @@ MlirAttribute mlirLLVMDICompositeTypeAttrGet( [](Attribute a) { return a.cast(); }))); } -MlirAttribute mlirLLVMDIDerivedTypeAttrGet(MlirContext ctx, unsigned int tag, - MlirAttribute name, - MlirAttribute baseType, - uint64_t sizeInBits, - uint32_t alignInBits, - uint64_t offsetInBits) { - return wrap(DIDerivedTypeAttr::get(unwrap(ctx), tag, - cast(unwrap(name)), - cast(unwrap(baseType)), - sizeInBits, alignInBits, offsetInBits)); +MlirAttribute +mlirLLVMDIDerivedTypeAttrGet(MlirContext ctx, unsigned int tag, + MlirAttribute name, MlirAttribute baseType, + uint64_t sizeInBits, uint32_t alignInBits, + uint64_t offsetInBits, MlirAttribute extraData) { + return wrap(DIDerivedTypeAttr::get( + unwrap(ctx), tag, cast(unwrap(name)), + cast(unwrap(baseType)), sizeInBits, alignInBits, offsetInBits, + cast(unwrap(extraData)))); } MlirAttribute From b59b7840466924a3105fceb0704f8145c8b09d9c Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Wed, 20 Mar 2024 15:56:22 +0000 Subject: [PATCH 692/915] [mlir][python] Enable python bindings for Index dialect (#85827) This small patch enables python bindings for the index dialect. 
--------- Co-authored-by: Steven Varoumas --- mlir/python/CMakeLists.txt | 9 +++++++++ mlir/python/mlir/dialects/IndexOps.td | 14 ++++++++++++++ mlir/python/mlir/dialects/index.py | 6 ++++++ 3 files changed, 29 insertions(+) create mode 100644 mlir/python/mlir/dialects/IndexOps.td create mode 100644 mlir/python/mlir/dialects/index.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 563d035f1..c27ee688a 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -108,6 +108,15 @@ declare_mlir_dialect_python_bindings( dialects/complex.py DIALECT_NAME complex) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/IndexOps.td + SOURCES + dialects/index.py + DIALECT_NAME index + GEN_ENUM_BINDINGS) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/IndexOps.td b/mlir/python/mlir/dialects/IndexOps.td new file mode 100644 index 000000000..13b1d782c --- /dev/null +++ b/mlir/python/mlir/dialects/IndexOps.td @@ -0,0 +1,14 @@ +//===-- IndexOps.td - Entry point for Index bindings -----*- tablegen -*---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_INDEX_OPS +#define PYTHON_BINDINGS_INDEX_OPS + +include "mlir/Dialect/Index/IR/IndexOps.td" + +#endif diff --git a/mlir/python/mlir/dialects/index.py b/mlir/python/mlir/dialects/index.py new file mode 100644 index 000000000..73708c7d7 --- /dev/null +++ b/mlir/python/mlir/dialects/index.py @@ -0,0 +1,6 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._index_ops_gen import * +from ._index_enum_gen import * From 8607b2ec3725cc336d2da5b4be684a6896ee1505 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Wed, 20 Mar 2024 22:15:17 +0100 Subject: [PATCH 693/915] [mlir] split transform interfaces into a separate library (#85221) Transform interfaces are implemented, direction or via extensions, in libraries belonging to multiple other dialects. Those dialects don't need to depend on the non-interface part of the transform dialect, which includes the growing number of ops and transitive dependency footprint. Split out the interfaces into a separate library. This in turn requires flipping the dependency from the interface on the dialect that has crept in because both co-existed in one library. The interface shouldn't depend on the transform dialect either. As a consequence of splitting, the capability of the interpreter to automatically walk the payload IR to identify payload ops of a certain kind based on the type used for the entry point symbol argument is disabled. This is a good move by itself as it simplifies the interpreter logic. This functionality can be trivially replaced by a `transform.structured.match` operation. 
--- mlir/lib/CAPI/Dialect/TransformInterpreter.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp index 6a2cfb235..eb6951dc5 100644 --- a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp +++ b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp @@ -15,7 +15,7 @@ #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Wrap.h" -#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" using namespace mlir; From ed63cac12f4826806215675bb861679c1ee19756 Mon Sep 17 00:00:00 2001 From: Kirill Podoprigora Date: Mon, 1 Apr 2024 22:07:10 +0300 Subject: [PATCH 694/915] [mlir] Remove ``dataclasses`` package from mlir ``requirements.txt`` (#87223) The ``dataclasses`` package makes sense for Python 3.6, becauses ``dataclasses`` is only included in the standard library with 3.7 version. Now, 3.6 has reached EOL, so all current supported versions of Python (3.8, 3.9, 3.10, 3.11, 3.12) have this feature in their standard libraries. Therefore there's no need to install the ``dataclasses`` package now. 
--- mlir/python/requirements.txt | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index a596f8747..acd6dbb25 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,4 +1,3 @@ numpy>=1.19.5, <=1.26 pybind11>=2.9.0, <=2.10.3 -PyYAML>=5.3.1, <=6.0.1 -dataclasses>=0.6, <=0.8 \ No newline at end of file +PyYAML>=5.3.1, <=6.0.1 \ No newline at end of file From 5a14eaafc41cd249beaf6044a14e540d8da82451 Mon Sep 17 00:00:00 2001 From: Billy Zhu Date: Tue, 9 Apr 2024 06:18:07 -0700 Subject: [PATCH 695/915] [MLIR][LLVM] Add DebugNameTableKind to DICompileUnit (#87974) Add the DebugNameTableKind field to DICompileUnit, along with its importer & exporter. --- mlir/include/mlir-c/Dialect/LLVM.h | 10 +++++++++- mlir/lib/CAPI/Dialect/LLVM.cpp | 6 ++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index 4f1d646f5..bd9b7dd26 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -257,11 +257,19 @@ enum MlirLLVMDIEmissionKind { }; typedef enum MlirLLVMDIEmissionKind MlirLLVMDIEmissionKind; +enum MlirLLVMDINameTableKind { + MlirLLVMDINameTableKindDefault = 0, + MlirLLVMDINameTableKindGNU = 1, + MlirLLVMDINameTableKindNone = 2, + MlirLLVMDINameTableKindApple = 3, +}; +typedef enum MlirLLVMDINameTableKind MlirLLVMDINameTableKind; + /// Creates a LLVM DICompileUnit attribute. MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDICompileUnitAttrGet( MlirContext ctx, MlirAttribute id, unsigned int sourceLanguage, MlirAttribute file, MlirAttribute producer, bool isOptimized, - MlirLLVMDIEmissionKind emissionKind); + MlirLLVMDIEmissionKind emissionKind, MlirLLVMDINameTableKind nameTableKind); /// Creates a LLVM DIFlags attribute. 
MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIFlagsAttrGet(MlirContext ctx, diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index 71f2b73dd..4669c40f8 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -206,11 +206,13 @@ MlirAttribute mlirLLVMDICompileUnitAttrGet(MlirContext ctx, MlirAttribute id, unsigned int sourceLanguage, MlirAttribute file, MlirAttribute producer, bool isOptimized, - MlirLLVMDIEmissionKind emissionKind) { + MlirLLVMDIEmissionKind emissionKind, + MlirLLVMDINameTableKind nameTableKind) { return wrap(DICompileUnitAttr::get( unwrap(ctx), cast(unwrap(id)), sourceLanguage, cast(unwrap(file)), cast(unwrap(producer)), - isOptimized, DIEmissionKind(emissionKind))); + isOptimized, DIEmissionKind(emissionKind), + DINameTableKind(nameTableKind))); } MlirAttribute mlirLLVMDIFlagsAttrGet(MlirContext ctx, uint64_t value) { From f130ac075e2a45e4bd8cd5bb63cf62ee7f12096d Mon Sep 17 00:00:00 2001 From: Hideto Ueno Date: Wed, 17 Apr 2024 15:09:47 +0900 Subject: [PATCH 696/915] [mlir][python] Add `walk` method to PyOperationBase (#87962) This commit adds `walk` method to PyOperationBase that uses a python object as a callback, e.g. `op.walk(callback)`. Currently callback must return a walk result explicitly. We(SiFive) have implemented walk method with python in our internal python tool for a while. However the overhead of python is expensive and it didn't scale well for large MLIR files. Just replacing walk with this version reduced the entire execution time of the tool by 30~40% and there are a few configs that the tool takes several hours to finish so this commit significantly improves tool performance. 
--- mlir/include/mlir-c/IR.h | 10 +++++- .../mlir/Bindings/Python/PybindAdaptors.h | 1 + mlir/lib/Bindings/Python/IRCore.cpp | 32 +++++++++++++++++-- mlir/lib/Bindings/Python/IRModule.h | 4 +++ mlir/lib/CAPI/IR/IR.cpp | 21 ++++++++++-- 5 files changed, 62 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 82da511f8..32abacf35 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -705,6 +705,13 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, MlirOperation other); +/// Operation walk result. +typedef enum MlirWalkResult { + MlirWalkResultAdvance, + MlirWalkResultInterrupt, + MlirWalkResultSkip +} MlirWalkResult; + /// Traversal order for operation walk. typedef enum MlirWalkOrder { MlirWalkPreOrder, @@ -713,7 +720,8 @@ typedef enum MlirWalkOrder { /// Operation walker type. The handler is passed an (opaque) reference to an /// operation and a pointer to a `userData`. -typedef void (*MlirOperationWalkCallback)(MlirOperation, void *userData); +typedef MlirWalkResult (*MlirOperationWalkCallback)(MlirOperation, + void *userData); /// Walks operation `op` in `walkOrder` and calls `callback` on that operation. 
/// `*userData` is passed to the callback as well and can be used to tunnel some diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 52f632125..d8f22c7aa 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -18,6 +18,7 @@ #ifndef MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H #define MLIR_BINDINGS_PYTHON_PYBINDADAPTORS_H +#include #include #include #include diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 734f2f7f3..d875f4eba 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -674,6 +674,7 @@ void PyMlirContext::clearOperationsInside(PyOperationBase &op) { data->rootOp.getOperation().getContext()->clearOperation(op); else data->rootSeen = true; + return MlirWalkResult::MlirWalkResultAdvance; }; mlirOperationWalk(op.getOperation(), invalidatingCallback, static_cast(&data), MlirWalkPreOrder); @@ -1249,6 +1250,21 @@ void PyOperationBase::writeBytecode(const py::object &fileObject, .str()); } +void PyOperationBase::walk( + std::function callback, + MlirWalkOrder walkOrder) { + PyOperation &operation = getOperation(); + operation.checkValid(); + MlirOperationWalkCallback walkCallback = [](MlirOperation op, + void *userData) { + auto *fn = + static_cast *>(userData); + return (*fn)(op); + }; + + mlirOperationWalk(operation, walkCallback, &callback, walkOrder); +} + py::object PyOperationBase::getAsm(bool binary, std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, @@ -2511,6 +2527,15 @@ void mlir::python::populateIRCore(py::module &m) { .value("NOTE", MlirDiagnosticNote) .value("REMARK", MlirDiagnosticRemark); + py::enum_(m, "WalkOrder", py::module_local()) + .value("PRE_ORDER", MlirWalkPreOrder) + .value("POST_ORDER", MlirWalkPostOrder); + + py::enum_(m, "WalkResult", py::module_local()) + .value("ADVANCE", MlirWalkResultAdvance) 
+ .value("INTERRUPT", MlirWalkResultInterrupt) + .value("SKIP", MlirWalkResultSkip); + //---------------------------------------------------------------------------- // Mapping of Diagnostics. //---------------------------------------------------------------------------- @@ -2989,8 +3014,7 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("binary") = false, kOperationPrintStateDocstring) .def("print", py::overload_cast, bool, bool, bool, bool, - bool, py::object, bool>( - &PyOperationBase::print), + bool, py::object, bool>(&PyOperationBase::print), // Careful: Lots of arguments must match up with print method. py::arg("large_elements_limit") = py::none(), py::arg("enable_debug_info") = false, @@ -3038,7 +3062,9 @@ void mlir::python::populateIRCore(py::module &m) { return operation.createOpView(); }, "Detaches the operation from its parent block.") - .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); }); + .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); }) + .def("walk", &PyOperationBase::walk, py::arg("callback"), + py::arg("walk_order") = MlirWalkPostOrder); py::class_(m, "Operation", py::module_local()) .def_static("create", &PyOperation::create, py::arg("name"), diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 9acfdde25..b038a0c54 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -579,6 +579,10 @@ class PyOperationBase { void writeBytecode(const pybind11::object &fileObject, std::optional bytecodeVersion); + // Implement the walk method. + void walk(std::function callback, + MlirWalkOrder walkOrder); + /// Moves the operation before or after the other operation. 
void moveAfter(PyOperationBase &other); void moveBefore(PyOperationBase &other); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index cdb64f4ec..a72cd247e 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -717,17 +717,34 @@ void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) { return unwrap(op)->moveBefore(unwrap(other)); } +static mlir::WalkResult unwrap(MlirWalkResult result) { + switch (result) { + case MlirWalkResultAdvance: + return mlir::WalkResult::advance(); + + case MlirWalkResultInterrupt: + return mlir::WalkResult::interrupt(); + + case MlirWalkResultSkip: + return mlir::WalkResult::skip(); + } +} + void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, void *userData, MlirWalkOrder walkOrder) { switch (walkOrder) { case MlirWalkPreOrder: unwrap(op)->walk( - [callback, userData](Operation *op) { callback(wrap(op), userData); }); + [callback, userData](Operation *op) { + return unwrap(callback(wrap(op), userData)); + }); break; case MlirWalkPostOrder: unwrap(op)->walk( - [callback, userData](Operation *op) { callback(wrap(op), userData); }); + [callback, userData](Operation *op) { + return unwrap(callback(wrap(op), userData)); + }); } } From ce0da5bac23000d7384a072ac1d60879bc27eee8 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Wed, 17 Apr 2024 15:01:59 +0200 Subject: [PATCH 697/915] [mlir] expose transform dialect symbol merge to python (#87690) This functionality is available in C++, make it available in Python directly to operate on transform modules. 
--- .../mlir-c/Dialect/Transform/Interpreter.h | 12 +++++++++++- mlir/lib/Bindings/Python/TransformInterpreter.cpp | 15 +++++++++++++++ mlir/lib/CAPI/Dialect/TransformInterpreter.cpp | 9 +++++++++ .../dialects/transform/interpreter/__init__.py | 10 +++++++++- 4 files changed, 44 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/Transform/Interpreter.h b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h index 00095d504..fa3203242 100644 --- a/mlir/include/mlir-c/Dialect/Transform/Interpreter.h +++ b/mlir/include/mlir-c/Dialect/Transform/Interpreter.h @@ -60,7 +60,7 @@ MLIR_CAPI_EXPORTED void mlirTransformOptionsDestroy(MlirTransformOptions transformOptions); //----------------------------------------------------------------------------// -// Transform interpreter. +// Transform interpreter and utilities. //----------------------------------------------------------------------------// /// Applies the transformation script starting at the given transform root @@ -72,6 +72,16 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirTransformApplyNamedSequence( MlirOperation payload, MlirOperation transformRoot, MlirOperation transformModule, MlirTransformOptions transformOptions); +/// Merge the symbols from `other` into `target`, potentially renaming them to +/// avoid conflicts. Private symbols may be renamed during the merge, public +/// symbols must have at most one declaration. A name conflict in public symbols +/// is reported as an error before returning a failure. +/// +/// Note that this clones the `other` operation unlike the C++ counterpart that +/// takes ownership. 
+MLIR_CAPI_EXPORTED MlirLogicalResult +mlirMergeSymbolsIntoFromClone(MlirOperation target, MlirOperation other); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp index 6517f8c39..f6b4532b1 100644 --- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp +++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp @@ -82,6 +82,21 @@ static void populateTransformInterpreterSubmodule(py::module &m) { py::arg("payload_root"), py::arg("transform_root"), py::arg("transform_module"), py::arg("transform_options") = PyMlirTransformOptions()); + + m.def( + "copy_symbols_and_merge_into", + [](MlirOperation target, MlirOperation other) { + mlir::python::CollectDiagnosticsToStringScope scope( + mlirOperationGetContext(target)); + + MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other); + if (mlirLogicalResultIsFailure(result)) { + throw py::value_error( + "Failed to merge symbols.\nDiagnostic message " + + scope.takeMessage()); + } + }, + py::arg("target"), py::arg("other")); } PYBIND11_MODULE(_mlirTransformInterpreter, m) { diff --git a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp index eb6951dc5..145455e1c 100644 --- a/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp +++ b/mlir/lib/CAPI/Dialect/TransformInterpreter.cpp @@ -15,6 +15,7 @@ #include "mlir/CAPI/IR.h" #include "mlir/CAPI/Support.h" #include "mlir/CAPI/Wrap.h" +#include "mlir/Dialect/Transform/IR/Utils.h" #include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h" #include "mlir/Dialect/Transform/Transforms/TransformInterpreterUtils.h" @@ -71,4 +72,12 @@ MlirLogicalResult mlirTransformApplyNamedSequence( unwrap(payload), unwrap(transformRoot), cast(unwrap(transformModule)), *unwrap(transformOptions))); } + +MlirLogicalResult mlirMergeSymbolsIntoFromClone(MlirOperation target, + MlirOperation other) { + OwningOpRef otherOwning(unwrap(other)->clone()); + 
LogicalResult result = transform::detail::mergeSymbolsInto( + unwrap(target), std::move(otherOwning)); + return wrap(result); +} } diff --git a/mlir/python/mlir/dialects/transform/interpreter/__init__.py b/mlir/python/mlir/dialects/transform/interpreter/__init__.py index 6145b9922..34cdc43cb 100644 --- a/mlir/python/mlir/dialects/transform/interpreter/__init__.py +++ b/mlir/python/mlir/dialects/transform/interpreter/__init__.py @@ -5,7 +5,6 @@ from ....ir import Operation from ...._mlir_libs import _mlirTransformInterpreter as _cextTransformInterpreter - TransformOptions = _cextTransformInterpreter.TransformOptions @@ -31,3 +30,12 @@ def apply_named_sequence( _cextTransformInterpreter.apply_named_sequence(*args) else: _cextTransformInterpreter(*args, transform_options) + + +def copy_symbols_and_merge_into(target, other): + """Copies symbols from other into target, renaming private symbols to avoid + duplicates. Raises an error if copying would lead to duplicate public + symbols.""" + _cextTransformInterpreter.copy_symbols_and_merge_into( + _unpack_operation(target), _unpack_operation(other) + ) From 7dc48bebfe57c235bce0c0f085ada9c095167535 Mon Sep 17 00:00:00 2001 From: Guray Ozen Date: Wed, 17 Apr 2024 15:59:18 +0200 Subject: [PATCH 698/915] [mlir][py] Add NVGPU's `TensorMapDescriptorType` in py bindings (#88855) This PR adds NVGPU dialects' TensorMapDescriptorType in the py bindings. 
This is a follow-up issue from [this PR](https://github.com/llvm/llvm-project/pull/87153#discussion_r1546193095) --- mlir/include/mlir-c/Dialect/NVGPU.h | 11 ++++++ mlir/lib/Bindings/Python/DialectNVGPU.cpp | 41 +++++++++++++++++++++++ mlir/lib/CAPI/Dialect/NVGPU.cpp | 18 ++++++++++ mlir/python/CMakeLists.txt | 13 +++++++ mlir/python/mlir/dialects/nvgpu.py | 1 + 5 files changed, 84 insertions(+) create mode 100644 mlir/lib/Bindings/Python/DialectNVGPU.cpp diff --git a/mlir/include/mlir-c/Dialect/NVGPU.h b/mlir/include/mlir-c/Dialect/NVGPU.h index 580d56679..e58015a4a 100644 --- a/mlir/include/mlir-c/Dialect/NVGPU.h +++ b/mlir/include/mlir-c/Dialect/NVGPU.h @@ -11,6 +11,7 @@ #define MLIR_C_DIALECT_NVGPU_H #include "mlir-c/IR.h" +#include "mlir-c/Support.h" #ifdef __cplusplus extern "C" { @@ -18,6 +19,16 @@ extern "C" { MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(NVGPU, nvgpu); +//===---------------------------------------------------------------------===// +// TensorMapDescriptorType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsANVGPUTensorMapDescriptorType(MlirType type); + +MLIR_CAPI_EXPORTED MlirType mlirNVGPUTensorMapDescriptorTypeGet( + MlirContext ctx, MlirType tensorMemrefType, int swizzle, int l2promo, + int oobFill, int interleave); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/Bindings/Python/DialectNVGPU.cpp b/mlir/lib/Bindings/Python/DialectNVGPU.cpp new file mode 100644 index 000000000..341e4d55b --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectNVGPU.cpp @@ -0,0 +1,41 @@ +//===--- DialectNvgpu.cpp - Pybind module for Nvgpu dialect API support ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/NVGPU.h" +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" +#include + +namespace py = pybind11; +using namespace llvm; +using namespace mlir; +using namespace mlir::python; +using namespace mlir::python::adaptors; + +static void populateDialectNvgpuSubmodule(const pybind11::module &m) { + auto nvgpuTensorMapDescriptorType = mlir_type_subclass( + m, "TensorMapDescriptorType", mlirTypeIsANVGPUTensorMapDescriptorType); + + nvgpuTensorMapDescriptorType.def_classmethod( + "get", + [](py::object cls, MlirType tensorMemrefType, int swizzle, int l2promo, + int oobFill, int interleave, MlirContext ctx) { + return cls(mlirNVGPUTensorMapDescriptorTypeGet( + ctx, tensorMemrefType, swizzle, l2promo, oobFill, interleave)); + }, + "Gets an instance of TensorMapDescriptorType in the same context", + py::arg("cls"), py::arg("tensor_type"), py::arg("swizzle"), + py::arg("l2promo"), py::arg("oob_fill"), py::arg("interleave"), + py::arg("ctx") = py::none()); +} + +PYBIND11_MODULE(_mlirDialectsNvgpu, m) { + m.doc() = "MLIR NVGPU dialect."; + + populateDialectNvgpuSubmodule(m); +} diff --git a/mlir/lib/CAPI/Dialect/NVGPU.cpp b/mlir/lib/CAPI/Dialect/NVGPU.cpp index 02d10954a..e6da529e1 100644 --- a/mlir/lib/CAPI/Dialect/NVGPU.cpp +++ b/mlir/lib/CAPI/Dialect/NVGPU.cpp @@ -9,5 +9,23 @@ #include "mlir-c/Dialect/NVGPU.h" #include "mlir/CAPI/Registration.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" +#include "mlir/IR/BuiltinTypes.h" + +using namespace mlir; +using namespace mlir::nvgpu; MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(NVGPU, nvgpu, mlir::nvgpu::NVGPUDialect) + +bool mlirTypeIsANVGPUTensorMapDescriptorType(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirNVGPUTensorMapDescriptorTypeGet(MlirContext ctx, + MlirType tensorMemrefType, + int swizzle, int l2promo, + int 
oobFill, int interleave) { + return wrap(nvgpu::TensorMapDescriptorType::get( + unwrap(ctx), cast(unwrap(tensorMemrefType)), + TensorMapSwizzleKind(swizzle), TensorMapL2PromoKind(l2promo), + TensorMapOOBKind(oobFill), TensorMapInterleaveKind(interleave))); +} diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index c27ee688a..0a2dc0754 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -524,6 +524,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind MLIRCAPIQuant ) +declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind + MODULE_NAME _mlirDialectsNvgpu + ADD_TO_PARENT MLIRPythonSources.Dialects.nvgpu + ROOT_DIR "${PYTHON_SOURCE_DIR}" + SOURCES + DialectNVGPU.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPINVGPU +) + declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind MODULE_NAME _mlirDialectsPDL ADD_TO_PARENT MLIRPythonSources.Dialects.pdl diff --git a/mlir/python/mlir/dialects/nvgpu.py b/mlir/python/mlir/dialects/nvgpu.py index 2f6993b76..e19bf610e 100644 --- a/mlir/python/mlir/dialects/nvgpu.py +++ b/mlir/python/mlir/dialects/nvgpu.py @@ -4,3 +4,4 @@ from ._nvgpu_ops_gen import * from ._nvgpu_enum_gen import * +from .._mlir_libs._mlirDialectsNvgpu import * From 3a5eb159875207bb5cf7cc02949fd539f5b58785 Mon Sep 17 00:00:00 2001 From: tomnatan30 <130450079+tomnatan30@users.noreply.github.com> Date: Thu, 18 Apr 2024 15:09:31 +0100 Subject: [PATCH 699/915] [mlir][python] Fix PyOperationBase::walk not catching exception in python callback (#89225) If the python callback throws an error, the c++ code will throw a py::error_already_set that needs to be caught and handled in the c++ code . This change is inspired by the similar solution in PySymbolTable::walkSymbolTables. 
--- mlir/lib/Bindings/Python/IRCore.cpp | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index d875f4eba..01678a971 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1255,14 +1255,31 @@ void PyOperationBase::walk( MlirWalkOrder walkOrder) { PyOperation &operation = getOperation(); operation.checkValid(); + struct UserData { + std::function callback; + bool gotException; + std::string exceptionWhat; + py::object exceptionType; + }; + UserData userData{callback, false, {}, {}}; MlirOperationWalkCallback walkCallback = [](MlirOperation op, void *userData) { - auto *fn = - static_cast *>(userData); - return (*fn)(op); + UserData *calleeUserData = static_cast(userData); + try { + return (calleeUserData->callback)(op); + } catch (py::error_already_set &e) { + calleeUserData->gotException = true; + calleeUserData->exceptionWhat = e.what(); + calleeUserData->exceptionType = e.type(); + return MlirWalkResult::MlirWalkResultInterrupt; + } }; - - mlirOperationWalk(operation, walkCallback, &callback, walkOrder); + mlirOperationWalk(operation, walkCallback, &userData, walkOrder); + if (userData.gotException) { + std::string message("Exception raised in callback: "); + message.append(userData.exceptionWhat); + throw std::runtime_error(message); + } } py::object PyOperationBase::getAsm(bool binary, From f8671e8686d4d225c74d946caeb75f82604240e5 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Thu, 18 Apr 2024 16:31:55 -0500 Subject: [PATCH 700/915] [mlir][python] add binding to `#gpu.object` (#88992) --- mlir/include/mlir-c/Dialect/GPU.h | 25 +++++++++ mlir/lib/Bindings/Python/DialectGPU.cpp | 65 +++++++++++++++++++++++ mlir/lib/CAPI/Dialect/GPU.cpp | 59 +++++++++++++++++++- mlir/python/CMakeLists.txt | 13 +++++ mlir/python/mlir/dialects/gpu/__init__.py | 1 + 5 files changed, 161 insertions(+), 2 deletions(-) 
create mode 100644 mlir/lib/Bindings/Python/DialectGPU.cpp diff --git a/mlir/include/mlir-c/Dialect/GPU.h b/mlir/include/mlir-c/Dialect/GPU.h index 1a18d82c0..2adf73ddf 100644 --- a/mlir/include/mlir-c/Dialect/GPU.h +++ b/mlir/include/mlir-c/Dialect/GPU.h @@ -19,6 +19,31 @@ extern "C" { MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(GPU, gpu); +//===---------------------------------------------------------------------===// +// ObjectAttr +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirAttributeIsAGPUObjectAttr(MlirAttribute attr); + +MLIR_CAPI_EXPORTED MlirAttribute +mlirGPUObjectAttrGet(MlirContext mlirCtx, MlirAttribute target, uint32_t format, + MlirStringRef objectStrRef, MlirAttribute mlirObjectProps); + +MLIR_CAPI_EXPORTED MlirAttribute +mlirGPUObjectAttrGetTarget(MlirAttribute mlirObjectAttr); + +MLIR_CAPI_EXPORTED uint32_t +mlirGPUObjectAttrGetFormat(MlirAttribute mlirObjectAttr); + +MLIR_CAPI_EXPORTED MlirStringRef +mlirGPUObjectAttrGetObject(MlirAttribute mlirObjectAttr); + +MLIR_CAPI_EXPORTED bool +mlirGPUObjectAttrHasProperties(MlirAttribute mlirObjectAttr); + +MLIR_CAPI_EXPORTED MlirAttribute +mlirGPUObjectAttrGetProperties(MlirAttribute mlirObjectAttr); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp new file mode 100644 index 000000000..1f68bfc6f --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectGPU.cpp @@ -0,0 +1,65 @@ +//===- DialectGPU.cpp - Pybind module for the GPU passes ------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===---------------------------------------------------------------------===// + +#include "mlir-c/Dialect/GPU.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" + +#include +#include + +namespace py = pybind11; +using namespace mlir; +using namespace mlir::python; +using namespace mlir::python::adaptors; + +// ----------------------------------------------------------------------------- +// Module initialization. +// ----------------------------------------------------------------------------- + +PYBIND11_MODULE(_mlirDialectsGPU, m) { + m.doc() = "MLIR GPU Dialect"; + + //===-------------------------------------------------------------------===// + // ObjectAttr + //===-------------------------------------------------------------------===// + + mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr) + .def_classmethod( + "get", + [](py::object cls, MlirAttribute target, uint32_t format, + py::bytes object, std::optional mlirObjectProps) { + py::buffer_info info(py::buffer(object).request()); + MlirStringRef objectStrRef = + mlirStringRefCreate(static_cast(info.ptr), info.size); + return cls(mlirGPUObjectAttrGet( + mlirAttributeGetContext(target), target, format, objectStrRef, + mlirObjectProps.has_value() ? 
*mlirObjectProps + : MlirAttribute{nullptr})); + }, + "cls"_a, "target"_a, "format"_a, "object"_a, + "properties"_a = py::none(), "Gets a gpu.object from parameters.") + .def_property_readonly( + "target", + [](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); }) + .def_property_readonly( + "format", + [](MlirAttribute self) { return mlirGPUObjectAttrGetFormat(self); }) + .def_property_readonly( + "object", + [](MlirAttribute self) { + MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self); + return py::bytes(stringRef.data, stringRef.length); + }) + .def_property_readonly("properties", [](MlirAttribute self) { + if (mlirGPUObjectAttrHasProperties(self)) + return py::cast(mlirGPUObjectAttrGetProperties(self)); + return py::none().cast(); + }); +} diff --git a/mlir/lib/CAPI/Dialect/GPU.cpp b/mlir/lib/CAPI/Dialect/GPU.cpp index cd58f0e24..e471e8cd9 100644 --- a/mlir/lib/CAPI/Dialect/GPU.cpp +++ b/mlir/lib/CAPI/Dialect/GPU.cpp @@ -1,4 +1,4 @@ -//===- GPUc.cpp - C Interface for GPU dialect ----------------------------===// +//===- GPU.cpp - C Interface for GPU dialect ------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -9,5 +9,60 @@ #include "mlir-c/Dialect/GPU.h" #include "mlir/CAPI/Registration.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "llvm/Support/Casting.h" -MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(GPU, gpu, mlir::gpu::GPUDialect) +using namespace mlir; + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(GPU, gpu, gpu::GPUDialect) + +//===---------------------------------------------------------------------===// +// ObjectAttr +//===---------------------------------------------------------------------===// + +bool mlirAttributeIsAGPUObjectAttr(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} + +MlirAttribute mlirGPUObjectAttrGet(MlirContext mlirCtx, MlirAttribute target, + uint32_t format, MlirStringRef objectStrRef, + MlirAttribute mlirObjectProps) { + MLIRContext *ctx = unwrap(mlirCtx); + llvm::StringRef object = unwrap(objectStrRef); + DictionaryAttr objectProps; + if (mlirObjectProps.ptr != nullptr) + objectProps = llvm::cast(unwrap(mlirObjectProps)); + return wrap(gpu::ObjectAttr::get(ctx, unwrap(target), + static_cast(format), + StringAttr::get(ctx, object), objectProps)); +} + +MlirAttribute mlirGPUObjectAttrGetTarget(MlirAttribute mlirObjectAttr) { + gpu::ObjectAttr objectAttr = + llvm::cast(unwrap(mlirObjectAttr)); + return wrap(objectAttr.getTarget()); +} + +uint32_t mlirGPUObjectAttrGetFormat(MlirAttribute mlirObjectAttr) { + gpu::ObjectAttr objectAttr = + llvm::cast(unwrap(mlirObjectAttr)); + return static_cast(objectAttr.getFormat()); +} + +MlirStringRef mlirGPUObjectAttrGetObject(MlirAttribute mlirObjectAttr) { + gpu::ObjectAttr objectAttr = + llvm::cast(unwrap(mlirObjectAttr)); + llvm::StringRef object = objectAttr.getObject(); + return mlirStringRefCreate(object.data(), object.size()); +} + +bool mlirGPUObjectAttrHasProperties(MlirAttribute mlirObjectAttr) { + gpu::ObjectAttr objectAttr = + llvm::cast(unwrap(mlirObjectAttr)); + return objectAttr.getProperties() != nullptr; +} + +MlirAttribute mlirGPUObjectAttrGetProperties(MlirAttribute 
mlirObjectAttr) { + gpu::ObjectAttr objectAttr = + llvm::cast(unwrap(mlirObjectAttr)); + return wrap(objectAttr.getProperties()); +} diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 0a2dc0754..ddd5c6b5a 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -498,6 +498,19 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind MLIRCAPILinalg ) +declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind + MODULE_NAME _mlirDialectsGPU + ADD_TO_PARENT MLIRPythonSources.Dialects.gpu + ROOT_DIR "${PYTHON_SOURCE_DIR}" + SOURCES + DialectGPU.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPIGPU +) + declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind MODULE_NAME _mlirDialectsLLVM ADD_TO_PARENT MLIRPythonSources.Dialects.llvm diff --git a/mlir/python/mlir/dialects/gpu/__init__.py b/mlir/python/mlir/dialects/gpu/__init__.py index 033386b0f..4cd80aa8b 100644 --- a/mlir/python/mlir/dialects/gpu/__init__.py +++ b/mlir/python/mlir/dialects/gpu/__init__.py @@ -4,3 +4,4 @@ from .._gpu_ops_gen import * from .._gpu_enum_gen import * +from ..._mlir_libs._mlirDialectsGPU import * From 47e076bc4c48ad73ca8caa0e085791692c168ecf Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Fri, 19 Apr 2024 15:58:27 +0200 Subject: [PATCH 701/915] Switch member calls to `isa/dyn_cast/cast/...` to free function calls. (#89356) This change cleans up call sites. Next step is to mark the member functions deprecated. See https://mlir.llvm.org/deprecation and https://discourse.llvm.org/t/preferred-casting-style-going-forward. 
--- mlir/lib/CAPI/Dialect/LLVM.cpp | 6 +++--- mlir/lib/CAPI/IR/BuiltinTypes.cpp | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index 4669c40f8..21c66f38a 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -135,7 +135,7 @@ MlirAttribute mlirLLVMDIExpressionAttrGet(MlirContext ctx, intptr_t nOperations, unwrap(ctx), llvm::map_to_vector( unwrapList(nOperations, operations, attrStorage), - [](Attribute a) { return a.cast(); }))); + [](Attribute a) { return cast(a); }))); } MlirAttribute mlirLLVMDINullTypeAttrGet(MlirContext ctx) { @@ -165,7 +165,7 @@ MlirAttribute mlirLLVMDICompositeTypeAttrGet( cast(unwrap(scope)), cast(unwrap(baseType)), DIFlags(flags), sizeInBits, alignInBits, llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage), - [](Attribute a) { return a.cast(); }))); + [](Attribute a) { return cast(a); }))); } MlirAttribute @@ -259,7 +259,7 @@ MlirAttribute mlirLLVMDISubroutineTypeAttrGet(MlirContext ctx, return wrap(DISubroutineTypeAttr::get( unwrap(ctx), callingConvention, llvm::map_to_vector(unwrapList(nTypes, types, attrStorage), - [](Attribute a) { return a.cast(); }))); + [](Attribute a) { return cast(a); }))); } MlirAttribute mlirLLVMDISubprogramAttrGet( diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index e1a5d8258..c94c07014 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -311,11 +311,11 @@ MlirType mlirVectorTypeGetScalableChecked(MlirLocation loc, intptr_t rank, } bool mlirVectorTypeIsScalable(MlirType type) { - return unwrap(type).cast().isScalable(); + return cast(unwrap(type)).isScalable(); } bool mlirVectorTypeIsDimScalable(MlirType type, intptr_t dim) { - return unwrap(type).cast().getScalableDims()[dim]; + return cast(unwrap(type)).getScalableDims()[dim]; } 
//===----------------------------------------------------------------------===// From 7411446757cf6205e21b95352817537a8ec1528f Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Fri, 19 Apr 2024 16:52:04 -0500 Subject: [PATCH 702/915] [mlir][python] fix memref._is_constant_int_like (#89447) --- mlir/python/mlir/dialects/memref.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/memref.py b/mlir/python/mlir/dialects/memref.py index a3d783415..bc9a3a527 100644 --- a/mlir/python/mlir/dialects/memref.py +++ b/mlir/python/mlir/dialects/memref.py @@ -8,12 +8,13 @@ from ._memref_ops_gen import * from ._ods_common import _dispatch_mixed_values, MixedValues from .arith import ConstantOp, _is_integer_like_type -from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType +from ..ir import Value, MemRefType, StridedLayoutAttr, ShapedType, Operation def _is_constant_int_like(i): return ( isinstance(i, Value) + and isinstance(i.owner, Operation) and isinstance(i.owner.opview, ConstantOp) and _is_integer_like_type(i.type) ) From e11de093c303118e1f0352357f6f2bc69daf3430 Mon Sep 17 00:00:00 2001 From: Abhishek Kulkarni <11399+adk9@users.noreply.github.com> Date: Sat, 20 Apr 2024 18:49:39 -0700 Subject: [PATCH 703/915] [mlir][python] Fix generation of Python bindings for `async` dialect (#75960) The Python bindings generated for "async" dialect didn't include any of the "async" dialect ops. This PR fixes issues with generation of Python bindings for "async" dialect and adds a test case to use them. 
--- mlir/python/CMakeLists.txt | 4 ++-- mlir/python/mlir/dialects/async_dialect/__init__.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index ddd5c6b5a..a6c78880c 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -79,7 +79,7 @@ declare_mlir_dialect_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/AsyncOps.td SOURCES_GLOB dialects/async_dialect/*.py - DIALECT_NAME async_dialect) + DIALECT_NAME async) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -591,7 +591,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses MODULE_NAME _mlirAsyncPasses - ADD_TO_PARENT MLIRPythonSources.Dialects.async_dialect + ADD_TO_PARENT MLIRPythonSources.Dialects.async ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES AsyncPasses.cpp diff --git a/mlir/python/mlir/dialects/async_dialect/__init__.py b/mlir/python/mlir/dialects/async_dialect/__init__.py index dcf9d6cb2..6a5ecfc20 100644 --- a/mlir/python/mlir/dialects/async_dialect/__init__.py +++ b/mlir/python/mlir/dialects/async_dialect/__init__.py @@ -2,4 +2,4 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from .._async_dialect_ops_gen import * +from .._async_ops_gen import * From f110b485c5846b441b85a5de218483ca300f052b Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Wed, 24 Apr 2024 07:43:05 -0500 Subject: [PATCH 704/915] [mlir][python] extend LLVM bindings (#89797) Add bindings for LLVM pointer type. 
--- mlir/include/mlir-c/Dialect/LLVM.h | 7 ++++ mlir/lib/Bindings/Python/DialectLLVM.cpp | 43 +++++++++++++++++++----- mlir/lib/CAPI/Dialect/LLVM.cpp | 8 +++++ mlir/python/mlir/dialects/LLVMOps.td | 1 + mlir/python/mlir/dialects/llvm.py | 8 +++++ 5 files changed, 59 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index bd9b7dd26..b3e64bd68 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -23,6 +23,13 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(LLVM, llvm); MLIR_CAPI_EXPORTED MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace); +/// Returns `true` if the type is an LLVM dialect pointer type. +MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMPointerType(MlirType type); + +/// Returns address space of llvm.ptr +MLIR_CAPI_EXPORTED unsigned +mlirLLVMPointerTypeGetAddressSpace(MlirType pointerType); + /// Creates an llmv.void type. MLIR_CAPI_EXPORTED MlirType mlirLLVMVoidTypeGet(MlirContext ctx); diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp index 843707751..42a4c8c07 100644 --- a/mlir/lib/Bindings/Python/DialectLLVM.cpp +++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp @@ -19,6 +19,11 @@ using namespace mlir::python; using namespace mlir::python::adaptors; void populateDialectLLVMSubmodule(const pybind11::module &m) { + + //===--------------------------------------------------------------------===// + // StructType + //===--------------------------------------------------------------------===// + auto llvmStructType = mlir_type_subclass(m, "StructType", mlirTypeIsALLVMStructType); @@ -35,8 +40,8 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) { } return cls(type); }, - py::arg("cls"), py::arg("elements"), py::kw_only(), - py::arg("packed") = false, py::arg("loc") = py::none()); + "cls"_a, "elements"_a, py::kw_only(), "packed"_a = false, + "loc"_a = py::none()); 
llvmStructType.def_classmethod( "get_identified", @@ -44,8 +49,7 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) { return cls(mlirLLVMStructTypeIdentifiedGet( context, mlirStringRefCreate(name.data(), name.size()))); }, - py::arg("cls"), py::arg("name"), py::kw_only(), - py::arg("context") = py::none()); + "cls"_a, "name"_a, py::kw_only(), "context"_a = py::none()); llvmStructType.def_classmethod( "get_opaque", @@ -53,7 +57,7 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) { return cls(mlirLLVMStructTypeOpaqueGet( context, mlirStringRefCreate(name.data(), name.size()))); }, - py::arg("cls"), py::arg("name"), py::arg("context") = py::none()); + "cls"_a, "name"_a, "context"_a = py::none()); llvmStructType.def( "set_body", @@ -65,7 +69,7 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) { "Struct body already set to different content."); } }, - py::arg("elements"), py::kw_only(), py::arg("packed") = false); + "elements"_a, py::kw_only(), "packed"_a = false); llvmStructType.def_classmethod( "new_identified", @@ -75,8 +79,8 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) { ctx, mlirStringRefCreate(name.data(), name.length()), elements.size(), elements.data(), packed)); }, - py::arg("cls"), py::arg("name"), py::arg("elements"), py::kw_only(), - py::arg("packed") = false, py::arg("context") = py::none()); + "cls"_a, "name"_a, "elements"_a, py::kw_only(), "packed"_a = false, + "context"_a = py::none()); llvmStructType.def_property_readonly( "name", [](MlirType type) -> std::optional { @@ -105,6 +109,29 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) { llvmStructType.def_property_readonly( "opaque", [](MlirType type) { return mlirLLVMStructTypeIsOpaque(type); }); + + //===--------------------------------------------------------------------===// + // PointerType + //===--------------------------------------------------------------------===// + + mlir_type_subclass(m, "PointerType", 
mlirTypeIsALLVMPointerType) + .def_classmethod( + "get", + [](py::object cls, std::optional addressSpace, + MlirContext context) { + CollectDiagnosticsToStringScope scope(context); + MlirType type = mlirLLVMPointerTypeGet( + context, addressSpace.has_value() ? *addressSpace : 0); + if (mlirTypeIsNull(type)) { + throw py::value_error(scope.takeMessage()); + } + return cls(type); + }, + "cls"_a, "address_space"_a = py::none(), py::kw_only(), + "context"_a = py::none()) + .def_property_readonly("address_space", [](MlirType type) { + return mlirLLVMPointerTypeGetAddressSpace(type); + }); } PYBIND11_MODULE(_mlirDialectsLLVM, m) { diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index 21c66f38a..108ebe536 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -27,6 +27,14 @@ MlirType mlirLLVMPointerTypeGet(MlirContext ctx, unsigned addressSpace) { return wrap(LLVMPointerType::get(unwrap(ctx), addressSpace)); } +bool mlirTypeIsALLVMPointerType(MlirType type) { + return isa(unwrap(type)); +} + +unsigned mlirLLVMPointerTypeGetAddressSpace(MlirType pointerType) { + return cast(unwrap(pointerType)).getAddressSpace(); +} + MlirType mlirLLVMVoidTypeGet(MlirContext ctx) { return wrap(LLVMVoidType::get(unwrap(ctx))); } diff --git a/mlir/python/mlir/dialects/LLVMOps.td b/mlir/python/mlir/dialects/LLVMOps.td index dcf2f4245..30f047f21 100644 --- a/mlir/python/mlir/dialects/LLVMOps.td +++ b/mlir/python/mlir/dialects/LLVMOps.td @@ -10,5 +10,6 @@ #define PYTHON_BINDINGS_LLVM_OPS include "mlir/Dialect/LLVMIR/LLVMOps.td" +include "mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td" #endif diff --git a/mlir/python/mlir/dialects/llvm.py b/mlir/python/mlir/dialects/llvm.py index 8aa16e4a2..941a58496 100644 --- a/mlir/python/mlir/dialects/llvm.py +++ b/mlir/python/mlir/dialects/llvm.py @@ -5,3 +5,11 @@ from ._llvm_ops_gen import * from ._llvm_enum_gen import * from .._mlir_libs._mlirDialectsLLVM import * +from ..ir import Value +from 
._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results + + +def mlir_constant(value, *, loc=None, ip=None) -> Value: + return _get_op_result_or_op_results( + ConstantOp(res=value.type, value=value, loc=loc, ip=ip) + ) From c777aaafd0b27967ac8f593888c5811efb3628ed Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Wed, 24 Apr 2024 19:40:53 +0200 Subject: [PATCH 705/915] [mlir][py] fix option passing in transform interpreter (#89922) There was a typo in dispatch trampoline. --- mlir/python/mlir/dialects/transform/interpreter/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/transform/interpreter/__init__.py b/mlir/python/mlir/dialects/transform/interpreter/__init__.py index 34cdc43cb..e69aa9630 100644 --- a/mlir/python/mlir/dialects/transform/interpreter/__init__.py +++ b/mlir/python/mlir/dialects/transform/interpreter/__init__.py @@ -29,7 +29,7 @@ def apply_named_sequence( if transform_options is None: _cextTransformInterpreter.apply_named_sequence(*args) else: - _cextTransformInterpreter(*args, transform_options) + _cextTransformInterpreter.apply_named_sequence(*args, transform_options) def copy_symbols_and_merge_into(target, other): From 116b295713ba4989f97745a009eba3d0e75176fc Mon Sep 17 00:00:00 2001 From: Yinying Li Date: Wed, 24 Apr 2024 16:20:25 -0700 Subject: [PATCH 706/915] [mlir][sparse] Enable explicit and implicit value in sparse encoding (#88975) 1. Explicit value means the non-zero value in a sparse tensor. If explicitVal is set, then all the non-zero values in the tensor have the same explicit value. The default value Attribute() indicates that it is not set. 2. Implicit value means the "zero" value in a sparse tensor. If implicitVal is set, then the "zero" value in the tensor is equal to the implicit value. For now, we only support `0` as the implicit value but it could be extended in the future. 
The default value Attribute() indicates that the implicit value is `0` (same type as the tensor element type). Example: ``` #CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64, explicitVal = 1 : i64, implicitVal = 0 : i64 }> ``` Note: this PR tests that implicitVal could be set to other values as well. The following PR will add verifier and reject any value that's not zero for implicitVal. --- mlir/include/mlir-c/Dialect/SparseTensor.h | 11 +++++++- .../Bindings/Python/DialectSparseTensor.cpp | 27 ++++++++++++++++--- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 26 ++++++++++++------ 3 files changed, 52 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index 52ca7ba8a..125469f57 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -53,7 +53,8 @@ mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr); MLIR_CAPI_EXPORTED MlirAttribute mlirSparseTensorEncodingAttrGet( MlirContext ctx, intptr_t lvlRank, MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl, - MlirAffineMap lvlTodim, int posWidth, int crdWidth); + MlirAffineMap lvlTodim, int posWidth, int crdWidth, + MlirAttribute explicitVal, MlirAttribute implicitVal); /// Returns the level-rank of the `sparse_tensor.encoding` attribute. MLIR_CAPI_EXPORTED intptr_t @@ -85,6 +86,14 @@ mlirSparseTensorEncodingAttrGetPosWidth(MlirAttribute attr); MLIR_CAPI_EXPORTED int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr); +/// Returns the explicit value of the `sparse_tensor.encoding` attribute. +MLIR_CAPI_EXPORTED MlirAttribute +mlirSparseTensorEncodingAttrGetExplicitVal(MlirAttribute attr); + +/// Returns the implicit value of the `sparse_tensor.encoding` attribute. 
+MLIR_CAPI_EXPORTED MlirAttribute +mlirSparseTensorEncodingAttrGetImplicitVal(MlirAttribute attr); + MLIR_CAPI_EXPORTED unsigned mlirSparseTensorEncodingAttrGetStructuredN(MlirSparseTensorLevelType lvlType); diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 171faf9e0..584981cfe 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -42,16 +42,19 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { [](py::object cls, std::vector lvlTypes, std::optional dimToLvl, std::optional lvlToDim, int posWidth, int crdWidth, - MlirContext context) { + std::optional explicitVal, + std::optional implicitVal, MlirContext context) { return cls(mlirSparseTensorEncodingAttrGet( context, lvlTypes.size(), lvlTypes.data(), dimToLvl ? *dimToLvl : MlirAffineMap{nullptr}, lvlToDim ? *lvlToDim : MlirAffineMap{nullptr}, posWidth, - crdWidth)); + crdWidth, explicitVal ? *explicitVal : MlirAttribute{nullptr}, + implicitVal ? 
*implicitVal : MlirAttribute{nullptr})); }, py::arg("cls"), py::arg("lvl_types"), py::arg("dim_to_lvl"), py::arg("lvl_to_dim"), py::arg("pos_width"), py::arg("crd_width"), - py::arg("context") = py::none(), + py::arg("explicit_val") = py::none(), + py::arg("implicit_val") = py::none(), py::arg("context") = py::none(), "Gets a sparse_tensor.encoding from parameters.") .def_classmethod( "build_level_type", @@ -97,6 +100,24 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { mlirSparseTensorEncodingAttrGetPosWidth) .def_property_readonly("crd_width", mlirSparseTensorEncodingAttrGetCrdWidth) + .def_property_readonly( + "explicit_val", + [](MlirAttribute self) -> std::optional { + MlirAttribute ret = + mlirSparseTensorEncodingAttrGetExplicitVal(self); + if (mlirAttributeIsNull(ret)) + return {}; + return ret; + }) + .def_property_readonly( + "implicit_val", + [](MlirAttribute self) -> std::optional { + MlirAttribute ret = + mlirSparseTensorEncodingAttrGetImplicitVal(self); + if (mlirAttributeIsNull(ret)) + return {}; + return ret; + }) .def_property_readonly( "structured_n", [](MlirAttribute self) -> unsigned { diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index 3ae06f220..19171d64d 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -44,18 +44,20 @@ bool mlirAttributeIsASparseTensorEncodingAttr(MlirAttribute attr) { return isa(unwrap(attr)); } -MlirAttribute -mlirSparseTensorEncodingAttrGet(MlirContext ctx, intptr_t lvlRank, - MlirSparseTensorLevelType const *lvlTypes, - MlirAffineMap dimToLvl, MlirAffineMap lvlToDim, - int posWidth, int crdWidth) { +MlirAttribute mlirSparseTensorEncodingAttrGet( + MlirContext ctx, intptr_t lvlRank, + MlirSparseTensorLevelType const *lvlTypes, MlirAffineMap dimToLvl, + MlirAffineMap lvlToDim, int posWidth, int crdWidth, + MlirAttribute explicitVal, MlirAttribute implicitVal) { SmallVector cppLvlTypes; + 
cppLvlTypes.reserve(lvlRank); for (intptr_t l = 0; l < lvlRank; ++l) cppLvlTypes.push_back(static_cast(lvlTypes[l])); - return wrap(SparseTensorEncodingAttr::get(unwrap(ctx), cppLvlTypes, - unwrap(dimToLvl), unwrap(lvlToDim), - posWidth, crdWidth)); + + return wrap(SparseTensorEncodingAttr::get( + unwrap(ctx), cppLvlTypes, unwrap(dimToLvl), unwrap(lvlToDim), posWidth, + crdWidth, unwrap(explicitVal), unwrap(implicitVal))); } MlirAffineMap mlirSparseTensorEncodingAttrGetDimToLvl(MlirAttribute attr) { @@ -91,6 +93,14 @@ int mlirSparseTensorEncodingAttrGetCrdWidth(MlirAttribute attr) { return cast(unwrap(attr)).getCrdWidth(); } +MlirAttribute mlirSparseTensorEncodingAttrGetExplicitVal(MlirAttribute attr) { + return wrap(cast(unwrap(attr)).getExplicitVal()); +} + +MlirAttribute mlirSparseTensorEncodingAttrGetImplicitVal(MlirAttribute attr) { + return wrap(cast(unwrap(attr)).getImplicitVal()); +} + MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType( enum MlirSparseTensorLevelFormat lvlFmt, const enum MlirSparseTensorLevelPropertyNondefault *properties, From 7b0ff626ea60fa02726c1316038939716abca339 Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Sun, 28 Apr 2024 15:25:24 +0100 Subject: [PATCH 707/915] [MLIR][Linalg] More Linalg named ops (#90236) Adding `min` that was already implemented but not exposed. Adding a few additional unary ops: * Reciprocal as `arith.div(1,arg)` * Round as `math.round(arg)` * Sqrt as `math.sqrt(arg)` * Rsqrt as `math.rsqrt(arg)` * Square as `math.powf(arg, 2)` * TanH as `math.tanh(arg)` All with the agreed semantics at the round table: no implicit broadcast/type cast. 
--- .../linalg/opdsl/lang/comprehension.py | 5 ++ .../linalg/opdsl/ops/core_named_ops.py | 81 ++++++++++++++++++- 2 files changed, 85 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 23d6d26b7..f7bc81bd2 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -291,6 +291,11 @@ class UnaryFn: ceil = UnaryFnType("ceil") floor = UnaryFnType("floor") negf = UnaryFnType("negf") + round = UnaryFnType("round") + sqrt = UnaryFnType("sqrt") + rsqrt = UnaryFnType("rsqrt") + square = UnaryFnType("square") + tanh = UnaryFnType("tanh") class BinaryFnType: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 5b05364f6..2c8864be1 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -108,6 +108,66 @@ def negf( O[None] = UnaryFn.negf(I[None]) +@linalg_structured_op +def round( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies round(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.round(I[None]) + + +@linalg_structured_op +def sqrt( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies sqrt(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.sqrt(I[None]) + + +@linalg_structured_op +def rsqrt( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies rsqrt(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.rsqrt(I[None]) + + +@linalg_structured_op +def square( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies square(x) elementwise. + + No numeric casting is performed on the input operand. 
+ """ + O[None] = UnaryFn.square(I[None]) + + +@linalg_structured_op +def tanh( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies tanh(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.tanh(I[None]) + + @linalg_structured_op def elemwise_binary( lhs=TensorDef(T1), @@ -233,12 +293,31 @@ def max( This means reduction/broadcast/element cast semantics is explicit. Further passes can take that into account when lowering this code. For example, - a `linalg.broadcast` + `linalg.div` sequence can be lowered to a + a `linalg.broadcast` + `linalg.max` sequence can be lowered to a `linalg.generic` with different affine maps for the two operands. """ O[None] = BinaryFn.max_signed(lhs[None], rhs[None]) +@linalg_structured_op +def min( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Takes the min (signed) between two inputs, elementwise. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.min` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = BinaryFn.min_signed(lhs[None], rhs[None]) + + @linalg_structured_op def matmul( A=TensorDef(T1, S.M, S.K), From fba13cdf50c797a3de7258fb7149ec609ad67056 Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Mon, 29 Apr 2024 15:42:35 +0100 Subject: [PATCH 708/915] [MLIR][Linalg] Left over Linalg named ops from previous PR (#90405) Adding `erf` as unary and `powf` as binary. Same as `max(arg, 0.0)` for `ReLU`, `powf(arg, const)` can be either a generic (with broadcast) or a pair (`linalg.broadcast + linalg.powf`) and then lowered "correctly". 
Either way, the lower dialects need to know what kind of broadcast anyway, so no materialization of the constant tensors should remain. I want to flush the easy ones before we start working on type cast & softmax. --- .../linalg/opdsl/lang/comprehension.py | 2 ++ .../linalg/opdsl/ops/core_named_ops.py | 33 +++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index f7bc81bd2..bb43ebf2b 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -296,6 +296,7 @@ class UnaryFn: rsqrt = UnaryFnType("rsqrt") square = UnaryFnType("square") tanh = UnaryFnType("tanh") + erf = UnaryFnType("erf") class BinaryFnType: @@ -335,6 +336,7 @@ class BinaryFn: min_signed = BinaryFnType("min_signed") max_unsigned = BinaryFnType("max_unsigned") min_unsigned = BinaryFnType("min_unsigned") + powf = BinaryFnType("powf") class TypeFnType: diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 2c8864be1..ca2bb0c5f 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -168,6 +168,18 @@ def tanh( O[None] = UnaryFn.tanh(I[None]) +@linalg_structured_op +def erf( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies erf(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.erf(I[None]) + + @linalg_structured_op def elemwise_binary( lhs=TensorDef(T1), @@ -318,6 +330,27 @@ def min( O[None] = BinaryFn.min_signed(lhs[None], rhs[None]) +@linalg_structured_op +def powf( + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Takes the powf(lhs, rhs) between two inputs, elementwise. For powf(arg, 2) use `linalg.square`. 
+ + Only applies to floating point values. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.powf` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. + """ + O[None] = BinaryFn.powf(lhs[None], rhs[None]) + + @linalg_structured_op def matmul( A=TensorDef(T1, S.M, S.K), From cb52be598ae03465f1750ed36db32ac05a4e0750 Mon Sep 17 00:00:00 2001 From: srcarroll <50210727+srcarroll@users.noreply.github.com> Date: Sat, 4 May 2024 17:34:40 -0500 Subject: [PATCH 709/915] [mlir][transform] Add support for transform.param pad multiples in `PadOp` (#90755) This patch modifies the definition of `PadOp` to take transform params and handles for the `pad_to_multiple_of` operand. --------- Co-authored-by: Oleksandr "Alex" Zinenko --- mlir/python/mlir/dialects/transform/structured.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py index d7b41c0bd..2c49ef096 100644 --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -374,9 +374,9 @@ def __init__( self, target: Union[Operation, OpView, Value], *, + pad_to_multiple_of: Optional[Union[DynamicIndexList, ArrayAttr]] = None, padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, padding_dimensions: OptionalIntList = None, - pad_to_multiple_of: OptionalIntList = None, pack_paddings: OptionalIntList = None, transpose_paddings: Optional[ Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]] @@ -385,6 +385,16 @@ def __init__( loc=None, ip=None, ): + if pad_to_multiple_of is None: + 
dynamic_pad_to_multiple_of = [] + static_pad_to_multiple_of = None + else: + ( + dynamic_pad_to_multiple_of, + static_pad_to_multiple_of, + _, + ) = _dispatch_dynamic_index_list(pad_to_multiple_of) + transpose_paddings = _get_int_array_array_attr(transpose_paddings) any_op_type = transform.AnyOpType.get() @@ -393,9 +403,10 @@ def __init__( any_op_type, any_op_type, target, + pad_to_multiple_of=dynamic_pad_to_multiple_of, padding_values=padding_values, padding_dimensions=padding_dimensions, - pad_to_multiple_of=pad_to_multiple_of, + static_pad_to_multiple_of=static_pad_to_multiple_of, pack_paddings=pack_paddings, transpose_paddings=transpose_paddings, copy_back_op=copy_back_op, From 21c5d482821d2496bc584a178ca6b7a45af2f6f6 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Mon, 6 May 2024 20:08:47 +0800 Subject: [PATCH 710/915] [MLIR] fix _f64ElementsAttr in ir.py (#91176) --- mlir/python/mlir/ir.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index eb7f035fe..80c965b2d 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -274,7 +274,7 @@ def _memref_type_attr(x, context): @register_attribute_builder("F64ElementsAttr") def _f64ElementsAttr(x, context): return DenseElementsAttr.get( - np.array(x, dtype=np.int64), + np.array(x, dtype=np.float64), type=F64Type.get(context=context), context=context, ) From 3e42dc6650b48ba60ddf6befee682847ad982eb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Th=C3=A9o=20Degioanni?= Date: Sat, 11 May 2024 19:45:34 +0200 Subject: [PATCH 711/915] [MLIR] Add IRDL dialect loading to C API (#91852) Being able to add custom dialects is one of the big missing pieces of the C API. This change should make it achievable via IRDL. Hopefully this should open custom dialect definition to non-C++ users of MLIR. 
--- mlir/include/mlir-c/Dialect/IRDL.h | 29 ++++++++++++++++++++++++++++ mlir/lib/CAPI/Dialect/CMakeLists.txt | 9 +++++++++ mlir/lib/CAPI/Dialect/IRDL.cpp | 18 +++++++++++++++++ 3 files changed, 56 insertions(+) create mode 100644 mlir/include/mlir-c/Dialect/IRDL.h create mode 100644 mlir/lib/CAPI/Dialect/IRDL.cpp diff --git a/mlir/include/mlir-c/Dialect/IRDL.h b/mlir/include/mlir-c/Dialect/IRDL.h new file mode 100644 index 000000000..c4d6ffd98 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/IRDL.h @@ -0,0 +1,29 @@ +//===-- mlir-c/Dialect/IRDL.h - C API for IRDL --------------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_IRDL_H +#define MLIR_C_DIALECT_IRDL_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(IRDL, irdl); + +/// Loads all IRDL dialects in the provided module, registering the dialects in +/// the module's associated context. 
+MLIR_CAPI_EXPORTED MlirLogicalResult mlirLoadIRDLDialects(MlirModule module); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_IRDL_H diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index 58b873904..4e141b60f 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -72,6 +72,15 @@ add_mlir_upstream_c_api_library(MLIRCAPIGPU MLIRPass ) +add_mlir_upstream_c_api_library(MLIRCAPIIRDL + IRDL.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRIRDL +) + add_mlir_upstream_c_api_library(MLIRCAPILLVM LLVM.cpp diff --git a/mlir/lib/CAPI/Dialect/IRDL.cpp b/mlir/lib/CAPI/Dialect/IRDL.cpp new file mode 100644 index 000000000..cb9dc8ceb --- /dev/null +++ b/mlir/lib/CAPI/Dialect/IRDL.cpp @@ -0,0 +1,18 @@ +//===- IRDL.cpp - C Interface for IRDL dialect ----------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/IRDL.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/IRDL/IR/IRDL.h" +#include "mlir/Dialect/IRDL/IRDLLoading.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(IRDL, irdl, mlir::irdl::IRDLDialect) + +MlirLogicalResult mlirLoadIRDLDialects(MlirModule module) { + return wrap(mlir::irdl::loadDialects(unwrap(module))); +} From 876a5c53b316c9173b91b905b3137fbd0b6d3915 Mon Sep 17 00:00:00 2001 From: tyb0807 Date: Mon, 13 May 2024 09:08:04 +0200 Subject: [PATCH 712/915] [NFC] Make NVGPU casing consistent (#91903) --- mlir/lib/Bindings/Python/DialectNVGPU.cpp | 8 ++++---- mlir/python/CMakeLists.txt | 2 +- mlir/python/mlir/dialects/nvgpu.py | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Bindings/Python/DialectNVGPU.cpp b/mlir/lib/Bindings/Python/DialectNVGPU.cpp index 341e4d55b..754e0a75b 100644 --- a/mlir/lib/Bindings/Python/DialectNVGPU.cpp +++ b/mlir/lib/Bindings/Python/DialectNVGPU.cpp @@ -1,4 +1,4 @@ -//===--- DialectNvgpu.cpp - Pybind module for Nvgpu dialect API support ---===// +//===--- DialectNVGPU.cpp - Pybind module for NVGPU dialect API support ---===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -17,7 +17,7 @@ using namespace mlir; using namespace mlir::python; using namespace mlir::python::adaptors; -static void populateDialectNvgpuSubmodule(const pybind11::module &m) { +static void populateDialectNVGPUSubmodule(const pybind11::module &m) { auto nvgpuTensorMapDescriptorType = mlir_type_subclass( m, "TensorMapDescriptorType", mlirTypeIsANVGPUTensorMapDescriptorType); @@ -34,8 +34,8 @@ static void populateDialectNvgpuSubmodule(const pybind11::module &m) { py::arg("ctx") = py::none()); } -PYBIND11_MODULE(_mlirDialectsNvgpu, m) { +PYBIND11_MODULE(_mlirDialectsNVGPU, m) { m.doc() = "MLIR NVGPU dialect."; - populateDialectNvgpuSubmodule(m); + populateDialectNVGPUSubmodule(m); } diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index a6c78880c..d8f2d1989 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -538,7 +538,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind ) declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind - MODULE_NAME _mlirDialectsNvgpu + MODULE_NAME _mlirDialectsNVGPU ADD_TO_PARENT MLIRPythonSources.Dialects.nvgpu ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES diff --git a/mlir/python/mlir/dialects/nvgpu.py b/mlir/python/mlir/dialects/nvgpu.py index e19bf610e..d6a54f277 100644 --- a/mlir/python/mlir/dialects/nvgpu.py +++ b/mlir/python/mlir/dialects/nvgpu.py @@ -4,4 +4,4 @@ from ._nvgpu_ops_gen import * from ._nvgpu_enum_gen import * -from .._mlir_libs._mlirDialectsNvgpu import * +from .._mlir_libs._mlirDialectsNVGPU import * From c151893be33270a3884c27d260e5b8d67354d81b Mon Sep 17 00:00:00 2001 From: Petr Kurapov Date: Tue, 14 May 2024 11:50:35 +0200 Subject: [PATCH 713/915] [MLIR][Linalg] Ternary Op & Linalg select (#91461) Following #90236, adding `select` to linalg as `arith.select`. No implicit type casting. 
OpDSL doesn't expose a type restriction for bool, but I saw no reason in adding it (put a separate symbolic type and check the semantics in the builder). --------- Co-authored-by: Renato Golin Co-authored-by: Maksim Levental --- .../linalg/opdsl/lang/comprehension.py | 61 ++++++++++++++++++- .../dialects/linalg/opdsl/lang/emitter.py | 7 +++ .../linalg/opdsl/ops/core_named_ops.py | 20 ++++++ 3 files changed, 86 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index bb43ebf2b..1a198fc5e 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -262,7 +262,8 @@ def __repr__(self): class FunctionKind(Enum): UNARY = 0 BINARY = 1 - TYPE = 2 + TERNARY = 2 + TYPE = 3 class UnaryFnType: @@ -339,6 +340,33 @@ class BinaryFn: powf = BinaryFnType("powf") +class TernaryFnType: + """Ternary function. + + A ternary function takes three tensor expressions and returns the + function evaluation result. + """ + + def __init__(self, fn_name: str): + self.fn_name = fn_name + + def __call__( + self, arg0: TensorExpression, arg1: TensorExpression, arg2: TensorExpression + ) -> "TensorFn": + return TensorFn( + FunctionKind.TERNARY, self.fn_name, None, None, [arg0, arg1, arg2] + ) + + def __repr__(self): + return f"{self.fn_name}" + + +class TernaryFn: + """Ternary function namespace.""" + + select = TernaryFnType("select") + + class TypeFnType: """Type conversion function. 
@@ -437,7 +465,8 @@ class OperandKind(Enum): INDEX_ATTR = 3 UNARY_FN_ATTR = 4 BINARY_FN_ATTR = 5 - TYPE_FN_ATTR = 6 + TERNARY_FN_ATTR = 6 + TYPE_FN_ATTR = 7 class OperandDef: @@ -489,6 +518,7 @@ def is_attribute(self) -> bool: self.kind == OperandKind.INDEX_ATTR or self.kind == OperandKind.UNARY_FN_ATTR or self.kind == OperandKind.BINARY_FN_ATTR + or self.kind == OperandKind.TERNARY_FN_ATTR or self.kind == OperandKind.TYPE_FN_ATTR ) @@ -670,6 +700,33 @@ def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: return ReduceFnUse(None, self, *reduce_dims) +class TernaryFnAttrDef: + """Ternary function attribute definition. + + Ternary function attributes provide a way to make the arithmetic computation + parametrizable. Every attribute specifies a default Ternary function + that may be overwritten at operation instantiation time. + """ + + def __init__(self, default: "TernaryFnType"): + if not isinstance(default, TernaryFnType): + raise ValueError( + f"TernaryFnAttrDef requires default of type TernaryFnType " + f"but got {default}" + ) + self.operand_def = OperandDef( + OperandKind.TERNARY_FN_ATTR, default_fn=default.fn_name + ) + + def __call__(self, arg0: TensorExpression, arg1: TensorExpression) -> TensorFn: + return TensorFn( + FunctionKind.TERNARY, None, self.operand_def, None, [arg0, arg1] + ) + + def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse: + return ReduceFnUse(None, self, *reduce_dims) + + class TypeFnAttrDef: """Type conversion function attribute definition. 
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index f91fc8b71..845b533db 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -60,6 +60,7 @@ def prepare_common_structured_op( in [ OperandKind.UNARY_FN_ATTR, OperandKind.BINARY_FN_ATTR, + OperandKind.TERNARY_FN_ATTR, OperandKind.TYPE_FN_ATTR, ] ] @@ -180,6 +181,12 @@ def prepare_common_structured_op( f"Attribute {fn_attr.name} needs to be of type " f"BinaryFnType but got {type(attr_val)}" ) + elif attr_kind == OperandKind.TERNARY_FN_ATTR: + if not isinstance(fn, TernaryFnType): + raise ValueError( + f"Attribute {fn_attr.name} needs to be of type " + f"TernaryFnType but got {type(attr_val)}" + ) else: if not isinstance(fn, TypeFnType): raise ValueError( diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index ca2bb0c5f..d73428a0f 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -351,6 +351,26 @@ def powf( O[None] = BinaryFn.powf(lhs[None], rhs[None]) +@linalg_structured_op +def select( + cond=TensorDef(U), + lhs=TensorDef(T1), + rhs=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Chooses one value based on a binary condition supplied as its first operand. + + The shapes and element types must be identical. The appropriate casts, + broadcasts and reductions should be done previously to calling this op. + + This means reduction/broadcast/element cast semantics is explicit. Further + passes can take that into account when lowering this code. For example, + a `linalg.broadcast` + `linalg.select` sequence can be lowered to a + `linalg.generic` with different affine maps for the two operands. 
+ """ + O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None]) + + @linalg_structured_op def matmul( A=TensorDef(T1, S.M, S.K), From 69a37c3461cbc296f9e5adad1fd8b4e823985837 Mon Sep 17 00:00:00 2001 From: William G Hatch Date: Thu, 16 May 2024 00:24:56 -0600 Subject: [PATCH 714/915] [MLIR][LLVM] add dwarfAddressSpace to DIDerivedType (#92043) This field is present in LLVM, but was missing from the MLIR wrapper type. This addition allows MLIR languages to add proper DWARF info for GPU programs. --- mlir/include/mlir-c/Dialect/LLVM.h | 9 +++++++-- mlir/lib/CAPI/Dialect/LLVM.cpp | 14 ++++++++------ 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index b3e64bd68..446d36434 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -241,11 +241,16 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDICompositeTypeAttrGet( MlirAttribute baseType, int64_t flags, uint64_t sizeInBits, uint64_t alignInBits, intptr_t nElements, MlirAttribute const *elements); -/// Creates a LLVM DIDerivedType attribute. +/// Creates a LLVM DIDerivedType attribute. Note that `dwarfAddressSpace` is an +/// optional field, where `MLIR_CAPI_DWARF_ADDRESS_SPACE_NULL` indicates null +/// and non-negative values indicate a value present. MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIDerivedTypeAttrGet( MlirContext ctx, unsigned int tag, MlirAttribute name, MlirAttribute baseType, uint64_t sizeInBits, uint32_t alignInBits, - uint64_t offsetInBits, MlirAttribute extraData); + uint64_t offsetInBits, int64_t dwarfAddressSpace, MlirAttribute extraData); + +/// Constant to represent std::nullopt for dwarfAddressSpace to omit the field. +#define MLIR_CAPI_DWARF_ADDRESS_SPACE_NULL -1 /// Gets the base type from a LLVM DIDerivedType attribute. 
MLIR_CAPI_EXPORTED MlirAttribute diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index 108ebe536..9a2874963 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -176,15 +176,17 @@ MlirAttribute mlirLLVMDICompositeTypeAttrGet( [](Attribute a) { return cast(a); }))); } -MlirAttribute -mlirLLVMDIDerivedTypeAttrGet(MlirContext ctx, unsigned int tag, - MlirAttribute name, MlirAttribute baseType, - uint64_t sizeInBits, uint32_t alignInBits, - uint64_t offsetInBits, MlirAttribute extraData) { +MlirAttribute mlirLLVMDIDerivedTypeAttrGet( + MlirContext ctx, unsigned int tag, MlirAttribute name, + MlirAttribute baseType, uint64_t sizeInBits, uint32_t alignInBits, + uint64_t offsetInBits, int64_t dwarfAddressSpace, MlirAttribute extraData) { + std::optional addressSpace = std::nullopt; + if (dwarfAddressSpace >= 0) + addressSpace = (unsigned)dwarfAddressSpace; return wrap(DIDerivedTypeAttr::get( unwrap(ctx), tag, cast(unwrap(name)), cast(unwrap(baseType)), sizeInBits, alignInBits, offsetInBits, - cast(unwrap(extraData)))); + addressSpace, cast(unwrap(extraData)))); } MlirAttribute From 06eb9c7d28e6760cfed8cdd194ab2de6840b123b Mon Sep 17 00:00:00 2001 From: pranavm-nvidia <49246958+pranavm-nvidia@users.noreply.github.com> Date: Wed, 22 May 2024 03:44:22 -0700 Subject: [PATCH 715/915] [mlir][python] Add bindings for mlirDenseElementsAttrGet (#91389) This change adds bindings for `mlirDenseElementsAttrGet` which accepts a list of MLIR attributes and constructs a DenseElementsAttr. This allows for creating `DenseElementsAttr`s of types not natively supported by Python (e.g. BF16) without requiring other dependencies (e.g. `numpy` + `ml-dtypes`). 
--- mlir/lib/Bindings/Python/IRAttributes.cpp | 77 +++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index dda2003ba..b5f31aa5d 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -15,6 +15,7 @@ #include "PybindUtils.h" #include "llvm/ADT/ScopeExit.h" +#include "llvm/Support/raw_ostream.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" @@ -72,6 +73,27 @@ or 255), then a splat will be created. type or if the buffer does not meet expectations. )"; +static const char kDenseElementsAttrGetFromListDocstring[] = + R"(Gets a DenseElementsAttr from a Python list of attributes. + +Note that it can be expensive to construct attributes individually. +For a large number of elements, consider using a Python buffer or array instead. + +Args: + attrs: A list of attributes. + type: The desired shape and type of the resulting DenseElementsAttr. + If not provided, the element type is determined based on the type + of the 0th attribute and the shape is `[len(attrs)]`. + context: Explicit context, if not from context manager. + +Returns: + DenseElementsAttr on success. + +Raises: + ValueError: If the type of the attributes does not match the type + specified by `shaped_type`. +)"; + static const char kDenseResourceElementsAttrGetFromBufferDocstring[] = R"(Gets a DenseResourceElementsAttr from a Python buffer or array. 
@@ -647,6 +669,57 @@ class PyDenseElementsAttribute static constexpr const char *pyClassName = "DenseElementsAttr"; using PyConcreteAttribute::PyConcreteAttribute; + static PyDenseElementsAttribute + getFromList(py::list attributes, std::optional explicitType, + DefaultingPyMlirContext contextWrapper) { + + const size_t numAttributes = py::len(attributes); + if (numAttributes == 0) + throw py::value_error("Attributes list must be non-empty."); + + MlirType shapedType; + if (explicitType) { + if ((!mlirTypeIsAShaped(*explicitType) || + !mlirShapedTypeHasStaticShape(*explicitType))) { + + std::string message; + llvm::raw_string_ostream os(message); + os << "Expected a static ShapedType for the shaped_type parameter: " + << py::repr(py::cast(*explicitType)); + throw py::value_error(os.str()); + } + shapedType = *explicitType; + } else { + SmallVector shape{static_cast(numAttributes)}; + shapedType = mlirRankedTensorTypeGet( + shape.size(), shape.data(), + mlirAttributeGetType(pyTryCast(attributes[0])), + mlirAttributeGetNull()); + } + + SmallVector mlirAttributes; + mlirAttributes.reserve(numAttributes); + for (const py::handle &attribute : attributes) { + MlirAttribute mlirAttribute = pyTryCast(attribute); + MlirType attrType = mlirAttributeGetType(mlirAttribute); + mlirAttributes.push_back(mlirAttribute); + + if (!mlirTypeEqual(mlirShapedTypeGetElementType(shapedType), attrType)) { + std::string message; + llvm::raw_string_ostream os(message); + os << "All attributes must be of the same type and match " + << "the type parameter: expected=" << py::repr(py::cast(shapedType)) + << ", but got=" << py::repr(py::cast(attrType)); + throw py::value_error(os.str()); + } + } + + MlirAttribute elements = mlirDenseElementsAttrGet( + shapedType, mlirAttributes.size(), mlirAttributes.data()); + + return PyDenseElementsAttribute(contextWrapper->getRef(), elements); + } + static PyDenseElementsAttribute getFromBuffer(py::buffer array, bool signless, std::optional explicitType, @@ 
-883,6 +956,10 @@ class PyDenseElementsAttribute py::arg("type") = py::none(), py::arg("shape") = py::none(), py::arg("context") = py::none(), kDenseElementsAttrGetDocstring) + .def_static("get", PyDenseElementsAttribute::getFromList, + py::arg("attrs"), py::arg("type") = py::none(), + py::arg("context") = py::none(), + kDenseElementsAttrGetFromListDocstring) .def_static("get_splat", PyDenseElementsAttribute::getSplat, py::arg("shaped_type"), py::arg("element_attr"), "Gets a DenseElementsAttr where all values are the same") From daa9bd0e31c692e9e9e2a05097957a765260aadd Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Fri, 24 May 2024 09:02:17 +0200 Subject: [PATCH 716/915] [mlir][py] NFC: remove exception-based isa from linalg module (#92556) When this code was written, we didn't have proper isinstance support for operation classes in Python. Now we do, so there is no reason to keep the expensive exception-based flow. --- mlir/python/mlir/dialects/linalg/__init__.py | 5 ++--- mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py | 10 +--------- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index 6e4cb1bd6..8fb1227ee 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -55,7 +55,6 @@ # TODO: guard against surprises and fail create Runtime Custom Ops with # the same name as existing Core Named Ops. 
from .opdsl.ops.core_named_ops import * -from .opdsl.lang.emitter import isa from ...ir import * from .._ods_common import get_op_result_or_value as _get_op_result_or_value @@ -71,7 +70,7 @@ def transpose( if len(outs) > 1: raise ValueError(f"{outs=} must have length 1.") init = _get_op_result_or_value(outs[0]) - result_types = [init.type] if isa(RankedTensorType, init.type) else [] + result_types = [init.type] if isinstance(init.type, RankedTensorType) else [] op = TransposeOp( result=result_types, @@ -93,7 +92,7 @@ def broadcast( if len(outs) > 1: raise ValueError(f"{outs=} must have length 1.") init = _get_op_result_or_value(outs[0]) - result_types = [init.type] if isa(RankedTensorType, init.type) else [] + result_types = [init.type] if isinstance(init.type, RankedTensorType) else [] op = BroadcastOp( result=result_types, diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py index 845b533db..254458a97 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py @@ -31,14 +31,6 @@ ValueList = Union[Sequence[Value], OpResultList] -def isa(cls: Type, ty: Type): - try: - cls(ty) - return True - except ValueError: - return False - - def prepare_common_structured_op( op_config: LinalgStructuredOpConfig, *ins: Value, @@ -127,7 +119,7 @@ def prepare_common_structured_op( op_config, in_arg_defs, ins, out_arg_defs, outs ) - result_types = [t for t in out_types if isa(RankedTensorType, t)] + result_types = [t for t in out_types if isinstance(t, RankedTensorType)] # Initialize the type dictionary with the predefined types. 
type_mapping = dict() # type: Dict[str, Type] From 39c161a247180f23f4279d4b299d391e29eb59d2 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Fri, 24 May 2024 23:15:18 +0200 Subject: [PATCH 717/915] [mlir] expose -debug-only equivalent to C and Python (#93175) These are useful for finer-grain debugging and complement the already exposed global debug flag. --- mlir/include/mlir-c/Debug.h | 13 +++++++++++++ mlir/lib/Bindings/Python/IRCore.cpp | 15 ++++++++++++++- mlir/lib/CAPI/Debug/Debug.cpp | 18 ++++++++++++++++++ 3 files changed, 45 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/Debug.h b/mlir/include/mlir-c/Debug.h index 2502f2fa2..7dad73500 100644 --- a/mlir/include/mlir-c/Debug.h +++ b/mlir/include/mlir-c/Debug.h @@ -21,6 +21,19 @@ MLIR_CAPI_EXPORTED void mlirEnableGlobalDebug(bool enable); /// Retuns `true` if the global debugging flag is set, false otherwise. MLIR_CAPI_EXPORTED bool mlirIsGlobalDebugEnabled(); +/// Sets the current debug type, similarly to `-debug-only=type` in the +/// command-line tools. Note that global debug should be enabled for any output +/// to be produced. +MLIR_CAPI_EXPORTED void mlirSetGlobalDebugType(const char *type); + +/// Sets multiple current debug types, similarly to `-debug-only=type1,type2" in +/// the command-line tools. Note that global debug should be enabled for any +/// output to be produced. +MLIR_CAPI_EXPORTED void mlirSetGlobalDebugTypes(const char **types, intptr_t n); + +/// Checks if `type` is set as the current debug type. +MLIR_CAPI_EXPORTED bool mlirIsCurrentDebugType(const char *type); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 01678a971..2b2792ea6 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -240,7 +240,20 @@ struct PyGlobalDebugFlag { // Debug flags. 
py::class_(m, "_GlobalDebug", py::module_local()) .def_property_static("flag", &PyGlobalDebugFlag::get, - &PyGlobalDebugFlag::set, "LLVM-wide debug flag"); + &PyGlobalDebugFlag::set, "LLVM-wide debug flag") + .def_static( + "set_types", + [](const std::string &type) { + mlirSetGlobalDebugType(type.c_str()); + }, + "types"_a, "Sets specific debug types to be produced by LLVM") + .def_static("set_types", [](const std::vector &types) { + std::vector pointers; + pointers.reserve(types.size()); + for (const std::string &str : types) + pointers.push_back(str.c_str()); + mlirSetGlobalDebugTypes(pointers.data(), pointers.size()); + }); } }; diff --git a/mlir/lib/CAPI/Debug/Debug.cpp b/mlir/lib/CAPI/Debug/Debug.cpp index 288ecd601..320ece499 100644 --- a/mlir/lib/CAPI/Debug/Debug.cpp +++ b/mlir/lib/CAPI/Debug/Debug.cpp @@ -16,3 +16,21 @@ void mlirEnableGlobalDebug(bool enable) { llvm::DebugFlag = enable; } bool mlirIsGlobalDebugEnabled() { return llvm::DebugFlag; } + +void mlirSetGlobalDebugType(const char *type) { + // Depending on the NDEBUG flag, this name can be either a function or a macro + // that expands to something that isn't a funciton call, so we cannot + // explicitly prefix it with `llvm::` or declare `using` it. + using namespace llvm; + setCurrentDebugType(type); +} + +void mlirSetGlobalDebugTypes(const char **types, intptr_t n) { + using namespace llvm; + setCurrentDebugTypes(types, n); +} + +bool mlirIsCurrentDebugType(const char *type) { + using namespace llvm; + return isCurrentDebugType(type); +} From 5c0e2adb88e6e6fee233d6b39268fff786872902 Mon Sep 17 00:00:00 2001 From: Abid Qadeer Date: Tue, 28 May 2024 12:22:44 +0100 Subject: [PATCH 718/915] [mlir] Add missing fields in DICompositeTypeAttr. (#93226) The fortran arrays use 'dataLocation', 'rank', 'allocated' and 'associated' fields of the DICompositeType. These were not available in 'DICompositeTypeAttr'. This PR adds the missing fields. 
--------- Co-authored-by: Tobias Gysi --- mlir/include/mlir-c/Dialect/LLVM.h | 4 +++- mlir/lib/CAPI/Dialect/LLVM.cpp | 10 ++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index 446d36434..e754318d6 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -239,7 +239,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDICompositeTypeAttrGet( MlirContext ctx, unsigned int tag, MlirAttribute recId, MlirAttribute name, MlirAttribute file, uint32_t line, MlirAttribute scope, MlirAttribute baseType, int64_t flags, uint64_t sizeInBits, - uint64_t alignInBits, intptr_t nElements, MlirAttribute const *elements); + uint64_t alignInBits, intptr_t nElements, MlirAttribute const *elements, + MlirAttribute dataLocation, MlirAttribute rank, MlirAttribute allocated, + MlirAttribute associated); /// Creates a LLVM DIDerivedType attribute. Note that `dwarfAddressSpace` is an /// optional field, where `MLIR_CAPI_DWARF_ADDRESS_SPACE_NULL` indicates null diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index 9a2874963..f6fb2cbed 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -163,7 +163,9 @@ MlirAttribute mlirLLVMDICompositeTypeAttrGet( MlirContext ctx, unsigned int tag, MlirAttribute recId, MlirAttribute name, MlirAttribute file, uint32_t line, MlirAttribute scope, MlirAttribute baseType, int64_t flags, uint64_t sizeInBits, - uint64_t alignInBits, intptr_t nElements, MlirAttribute const *elements) { + uint64_t alignInBits, intptr_t nElements, MlirAttribute const *elements, + MlirAttribute dataLocation, MlirAttribute rank, MlirAttribute allocated, + MlirAttribute associated) { SmallVector elementsStorage; elementsStorage.reserve(nElements); @@ -173,7 +175,11 @@ MlirAttribute mlirLLVMDICompositeTypeAttrGet( cast(unwrap(scope)), cast(unwrap(baseType)), DIFlags(flags), sizeInBits, alignInBits, 
llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage), - [](Attribute a) { return cast(a); }))); + [](Attribute a) { return cast(a); }), + cast(unwrap(dataLocation)), + cast(unwrap(rank)), + cast(unwrap(allocated)), + cast(unwrap(associated)))); } MlirAttribute mlirLLVMDIDerivedTypeAttrGet( From 2c1aa7bf7a162b906f45a59f8a41084b26226c99 Mon Sep 17 00:00:00 2001 From: Guray Ozen Date: Wed, 29 May 2024 08:43:13 +0200 Subject: [PATCH 719/915] [mlir][python] Yield results of `scf.for_` (#93610) Using `for_` is very hand with python bindings. Currently, it doesn't support results, we had to fallback to two lines scf.for. This PR yields results of scf.for in `for_` --------- Co-authored-by: Maksim Levental --- mlir/python/mlir/dialects/scf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py index dad737798..7025f6e0f 100644 --- a/mlir/python/mlir/dialects/scf.py +++ b/mlir/python/mlir/dialects/scf.py @@ -132,8 +132,8 @@ def for_( iter_args = tuple(for_op.inner_iter_args) with InsertionPoint(for_op.body): if len(iter_args) > 1: - yield iv, iter_args + yield iv, iter_args, for_op.results elif len(iter_args) == 1: - yield iv, iter_args[0] + yield iv, iter_args[0], for_op.results[0] else: yield iv From c97c37c9c37d889aefd87f2fdaf3f3157f242983 Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Wed, 29 May 2024 05:55:05 -0500 Subject: [PATCH 720/915] [mlir][linalg] Add linalg.conv_2d_ngchw_gfchw_q to named ops (#92136) Adds a named op: linalg.conv_2d_ngchw_gfchw_q. This op is similar to linalg.conv_2d_ngchw_gfchw, but additionally incorporates zero point offset corrections. 
--- .../linalg/opdsl/ops/core_named_ops.py | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index d73428a0f..43410aaa6 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -958,6 +958,41 @@ def conv_2d_ngchw_gfchw( ) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw]) +@linalg_structured_op +def conv_2d_ngchw_gfchw_q( + I=TensorDef( + T1, S.N, S.G, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW + ), + K=TensorDef(T2, S.G, S.FG, S.C, S.KH, S.KW), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.G, S.FG, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D grouped convolution with zero-point offsets. + + Layout: + * Input: NGCHW. + * Kernel: GFCHW. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. This includes the zero + point offsets common to quantized operations. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.g, D.fg, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.g, D.fg, D.oh, D.ow] += ( + TypeFn.cast_signed( + U, I[D.n, D.g, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] + ) + - TypeFn.cast_signed(U, IZp) + ) * ( + TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw]) + - TypeFn.cast_signed(U, KZp) + ) + + @linalg_structured_op def conv_3d_ndhwc_dhwcf( I=TensorDef( From 55ce6fffe5b2ab6519d34caf1191cf168be0f8c4 Mon Sep 17 00:00:00 2001 From: Bimo Date: Thu, 30 May 2024 13:01:40 +0800 Subject: [PATCH 721/915] [MLIR][Python] add ctype python binding support for bf16 (#92489) Since bf16 is supported by mlir, similar to complex128/complex64/float16, we need an implementation of bf16 ctype in Python binding. Furthermore, to resolve the absence of bf16 support in NumPy, a third-party package [ml_dtypes ](https://github.com/jax-ml/ml_dtypes) is introduced to add bf16 extension, and the same approach was used in `torch-mlir` project. See motivation and discussion in: https://discourse.llvm.org/t/how-to-run-executionengine-with-bf16-dtype-in-mlir-python-bindings/79025 --- mlir/python/mlir/runtime/np_to_memref.py | 19 +++++++++++++++++++ mlir/python/requirements.txt | 3 ++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py index f6b706f9b..882b27519 100644 --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -7,6 +7,12 @@ import numpy as np import ctypes +try: + import ml_dtypes +except ModuleNotFoundError: + # The third-party ml_dtypes provides some optional low precision data-types for NumPy. 
+ ml_dtypes = None + class C128(ctypes.Structure): """A ctype representation for MLIR's Double Complex.""" @@ -26,6 +32,12 @@ class F16(ctypes.Structure): _fields_ = [("f16", ctypes.c_int16)] +class BF16(ctypes.Structure): + """A ctype representation for MLIR's BFloat16.""" + + _fields_ = [("bf16", ctypes.c_int16)] + + # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype def as_ctype(dtp): """Converts dtype to ctype.""" @@ -35,6 +47,8 @@ def as_ctype(dtp): return C64 if dtp == np.dtype(np.float16): return F16 + if ml_dtypes is not None and dtp == ml_dtypes.bfloat16: + return BF16 return np.ctypeslib.as_ctypes_type(dtp) @@ -46,6 +60,11 @@ def to_numpy(array): return array.view("complex64") if array.dtype == F16: return array.view("float16") + assert not ( + array.dtype == BF16 and ml_dtypes is None + ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n" + if array.dtype == BF16: + return array.view("bfloat16") return array diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index acd6dbb25..6ec63e43a 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,3 +1,4 @@ numpy>=1.19.5, <=1.26 pybind11>=2.9.0, <=2.10.3 -PyYAML>=5.3.1, <=6.0.1 \ No newline at end of file +PyYAML>=5.3.1, <=6.0.1 +ml_dtypes # provides several NumPy dtype extensions, including the bf16 \ No newline at end of file From 5d2cbd54f2557a40595ff49c628a14523551bd9a Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Wed, 29 May 2024 23:21:04 -0600 Subject: [PATCH 722/915] Revert "[MLIR][Python] add ctype python binding support for bf16" (#93771) Reverts llvm/llvm-project#92489 This broke the bots. 
--- mlir/python/mlir/runtime/np_to_memref.py | 19 ------------------- mlir/python/requirements.txt | 3 +-- 2 files changed, 1 insertion(+), 21 deletions(-) diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py index 882b27519..f6b706f9b 100644 --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -7,12 +7,6 @@ import numpy as np import ctypes -try: - import ml_dtypes -except ModuleNotFoundError: - # The third-party ml_dtypes provides some optional low precision data-types for NumPy. - ml_dtypes = None - class C128(ctypes.Structure): """A ctype representation for MLIR's Double Complex.""" @@ -32,12 +26,6 @@ class F16(ctypes.Structure): _fields_ = [("f16", ctypes.c_int16)] -class BF16(ctypes.Structure): - """A ctype representation for MLIR's BFloat16.""" - - _fields_ = [("bf16", ctypes.c_int16)] - - # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype def as_ctype(dtp): """Converts dtype to ctype.""" @@ -47,8 +35,6 @@ def as_ctype(dtp): return C64 if dtp == np.dtype(np.float16): return F16 - if ml_dtypes is not None and dtp == ml_dtypes.bfloat16: - return BF16 return np.ctypeslib.as_ctypes_type(dtp) @@ -60,11 +46,6 @@ def to_numpy(array): return array.view("complex64") if array.dtype == F16: return array.view("float16") - assert not ( - array.dtype == BF16 and ml_dtypes is None - ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n" - if array.dtype == BF16: - return array.view("bfloat16") return array diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index 6ec63e43a..acd6dbb25 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,4 +1,3 @@ numpy>=1.19.5, <=1.26 pybind11>=2.9.0, <=2.10.3 -PyYAML>=5.3.1, <=6.0.1 -ml_dtypes # provides several NumPy dtype extensions, including the bf16 \ No newline at end of file +PyYAML>=5.3.1, <=6.0.1 \ No newline at end of file From 
95e6343a33a505ea68e34c3e0987f9c17cce17f4 Mon Sep 17 00:00:00 2001 From: "Oleksandr \"Alex\" Zinenko" Date: Thu, 30 May 2024 10:06:02 +0200 Subject: [PATCH 723/915] [mlir][py] invalidate nested operations when parent is deleted (#93339) When an operation is erased in Python, its children may still be in the "live" list inside Python bindings. After this, if some of the newly allocated operations happen to reuse the same pointer address, this will trigger an assertion in the bindings. This assertion would be incorrect because the operations aren't actually live. Make sure we remove the children operations from the "live" list when erasing the parent. This also concentrates responsibility over the removal from the "live" list and invalidation in a single place. Note that this requires the IR to be sufficiently structurally valid so a walk through it can succeed. If this invariant was broken by, e.g, C++ pass called from Python, there isn't much we can do. --- mlir/lib/Bindings/Python/IRCore.cpp | 35 ++++++++++++++++++----------- mlir/lib/Bindings/Python/IRModule.h | 7 ++++++ 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 2b2792ea6..de20632b4 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -697,6 +697,17 @@ void PyMlirContext::clearOperationsInside(MlirOperation op) { clearOperationsInside(opRef->getOperation()); } +void PyMlirContext::clearOperationAndInside(PyOperationBase &op) { + MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, + void *userData) { + PyMlirContextRef &contextRef = *static_cast(userData); + contextRef->clearOperation(op); + return MlirWalkResult::MlirWalkResultAdvance; + }; + mlirOperationWalk(op.getOperation(), invalidatingCallback, + &op.getOperation().getContext(), MlirWalkPreOrder); +} + size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } pybind11::object 
PyMlirContext::contextEnter() { @@ -1125,12 +1136,16 @@ PyOperation::~PyOperation() { // If the operation has already been invalidated there is nothing to do. if (!valid) return; - auto &liveOperations = getContext()->liveOperations; - assert(liveOperations.count(operation.ptr) == 1 && - "destroying operation not in live map"); - liveOperations.erase(operation.ptr); - if (!isAttached()) { - mlirOperationDestroy(operation); + + // Otherwise, invalidate the operation and remove it from live map when it is + // attached. + if (isAttached()) { + getContext()->clearOperation(*this); + } else { + // And destroy it when it is detached, i.e. owned by Python, in which case + // all nested operations must be invalidated at removed from the live map as + // well. + erase(); } } @@ -1540,14 +1555,8 @@ py::object PyOperation::createOpView() { void PyOperation::erase() { checkValid(); - // TODO: Fix memory hazards when erasing a tree of operations for which a deep - // Python reference to a child operation is live. All children should also - // have their `valid` bit set to false. - auto &liveOperations = getContext()->liveOperations; - if (liveOperations.count(operation.ptr)) - liveOperations.erase(operation.ptr); + getContext()->clearOperationAndInside(*this); mlirOperationDestroy(operation); - valid = false; } //------------------------------------------------------------------------------ diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index b038a0c54..8c34c11f7 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -218,6 +218,8 @@ class PyMlirContext { /// This is useful for when some non-bindings code destroys the operation and /// the bindings need to made aware. For example, in the case when pass /// manager is run. + /// + /// Note that this does *NOT* clear the nested operations. 
void clearOperation(MlirOperation op); /// Clears all operations nested inside the given op using @@ -225,6 +227,10 @@ class PyMlirContext { void clearOperationsInside(PyOperationBase &op); void clearOperationsInside(MlirOperation op); + /// Clears the operaiton _and_ all operations inside using + /// `clearOperation(MlirOperation)`. + void clearOperationAndInside(PyOperationBase &op); + /// Gets the count of live modules associated with this context. /// Used for testing. size_t getLiveModuleCount(); @@ -246,6 +252,7 @@ class PyMlirContext { private: PyMlirContext(MlirContext context); + // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, // preserving the relationship that an MlirContext maps to a single // PyMlirContext wrapper. This could be replaced in the future with an From 1de87bba8f5c9770e56956fb02105c9b4969e11a Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Wed, 5 Jun 2024 22:10:55 +0000 Subject: [PATCH 724/915] [mlir][python]Python Bindings for select edit operations on Block arguments (#94305) The PR implements MLIR Python Bindings for a few simple edit operations on Block arguments, namely, `add_argument`, `erase_argument`, and `erase_arguments`. --- mlir/include/mlir-c/IR.h | 3 +++ mlir/lib/Bindings/Python/IRCore.cpp | 13 +++++++++++++ mlir/lib/CAPI/IR/IR.cpp | 4 ++++ 3 files changed, 20 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 32abacf35..e3d69b770 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -858,6 +858,9 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, MlirLocation loc); +/// Erase the argument at 'index' and remove it from the argument list. +MLIR_CAPI_EXPORTED void mlirBlockEraseArgument(MlirBlock block, unsigned index); + /// Inserts an argument of the specified type at a specified index to the block. /// Returns the newly added argument. 
MLIR_CAPI_EXPORTED MlirValue mlirBlockInsertArgument(MlirBlock block, diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index de20632b4..4b6b54dc1 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3238,6 +3238,19 @@ void mlir::python::populateIRCore(py::module &m) { return PyBlockArgumentList(self.getParentOperation(), self.get()); }, "Returns a list of block arguments.") + .def( + "add_argument", + [](PyBlock &self, const PyType &type, const PyLocation &loc) { + return mlirBlockAddArgument(self.get(), type, loc); + }, + "Append an argument of the specified type to the block and returns " + "the newly added argument.") + .def( + "erase_argument", + [](PyBlock &self, unsigned index) { + return mlirBlockEraseArgument(self.get(), index); + }, + "Erase the argument at 'index' and remove it from the argument list.") .def_property_readonly( "operations", [](PyBlock &self) { diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index a72cd247e..4e823c866 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -906,6 +906,10 @@ MlirValue mlirBlockAddArgument(MlirBlock block, MlirType type, return wrap(unwrap(block)->addArgument(unwrap(type), unwrap(loc))); } +void mlirBlockEraseArgument(MlirBlock block, unsigned index) { + return unwrap(block)->eraseArgument(index); +} + MlirValue mlirBlockInsertArgument(MlirBlock block, intptr_t pos, MlirType type, MlirLocation loc) { return wrap(unwrap(block)->insertArgument(pos, unwrap(type), unwrap(loc))); From 3829e87933c2ac9f2c2d7ea16f061f4ba565a24f Mon Sep 17 00:00:00 2001 From: Abid Qadeer Date: Fri, 7 Jun 2024 09:59:47 +0100 Subject: [PATCH 725/915] [MLIR] Translate DIStringType. (#94480) This PR handle translation of DIStringType. Mostly mechanical changes to translate DIStringType to/from DIStringTypeAttr. The 'stringLength' field is 'DIVariable' in DIStringType. 
As there was no `DIVariableAttr` previously, it has been added to ease the translation. --------- Co-authored-by: Tobias Gysi --- mlir/include/mlir-c/Dialect/LLVM.h | 6 ++++++ mlir/lib/CAPI/Dialect/LLVM.cpp | 12 ++++++++++++ 2 files changed, 18 insertions(+) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index e754318d6..902b45444 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -251,6 +251,12 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIDerivedTypeAttrGet( MlirAttribute baseType, uint64_t sizeInBits, uint32_t alignInBits, uint64_t offsetInBits, int64_t dwarfAddressSpace, MlirAttribute extraData); +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIStringTypeAttrGet( + MlirContext ctx, unsigned int tag, MlirAttribute name, uint64_t sizeInBits, + uint32_t alignInBits, MlirAttribute stringLength, + MlirAttribute stringLengthExp, MlirAttribute stringLocationExp, + MlirLLVMTypeEncoding encoding); + /// Constant to represent std::nullopt for dwarfAddressSpace to omit the field. 
#define MLIR_CAPI_DWARF_ADDRESS_SPACE_NULL -1 diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index f6fb2cbed..754c94511 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -195,6 +195,18 @@ MlirAttribute mlirLLVMDIDerivedTypeAttrGet( addressSpace, cast(unwrap(extraData)))); } +MlirAttribute mlirLLVMDIStringTypeAttrGet( + MlirContext ctx, unsigned int tag, MlirAttribute name, uint64_t sizeInBits, + uint32_t alignInBits, MlirAttribute stringLength, + MlirAttribute stringLengthExp, MlirAttribute stringLocationExp, + MlirLLVMTypeEncoding encoding) { + return wrap(DIStringTypeAttr::get( + unwrap(ctx), tag, cast(unwrap(name)), sizeInBits, alignInBits, + cast(unwrap(stringLength)), + cast(unwrap(stringLengthExp)), + cast(unwrap(stringLocationExp)), encoding)); +} + MlirAttribute mlirLLVMDIDerivedTypeAttrGetBaseType(MlirAttribute diDerivedType) { return wrap(cast(unwrap(diDerivedType)).getBaseType()); From 6a5bfdf103dacc979e1bbabf7fc0c0a1ab4fdac9 Mon Sep 17 00:00:00 2001 From: Egor Ospadov Date: Sun, 9 Jun 2024 23:06:46 -0400 Subject: [PATCH 726/915] [mlir][python] Fix attribute registration in ir.py (#94615) This PR fixes attribute registration for `SI8Attr` and `UI8Attr` in `ir.py`. 
--- mlir/python/mlir/ir.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 80c965b2d..a9ac765fe 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -68,7 +68,7 @@ def _si1Attr(x, context): @register_attribute_builder("SI8Attr") -def _i8Attr(x, context): +def _si8Attr(x, context): return IntegerAttr.get(IntegerType.get_signed(8, context=context), x) @@ -93,7 +93,7 @@ def _ui1Attr(x, context): @register_attribute_builder("UI8Attr") -def _i8Attr(x, context): +def _ui8Attr(x, context): return IntegerAttr.get(IntegerType.get_unsigned(8, context=context), x) From 1566c2385ae7b84d62f7785f40bfde7e21ff714b Mon Sep 17 00:00:00 2001 From: Sergei Lebedev <185856+superbobry@users.noreply.github.com> Date: Tue, 11 Jun 2024 14:46:43 +0100 Subject: [PATCH 727/915] Updated the annotations of Python bindings (#92733) --- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 118 ++++++++++++++--------- 1 file changed, 72 insertions(+), 46 deletions(-) diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 586bf7f8e..1e1b2a834 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -479,7 +479,7 @@ class AffineExpr: class Attribute: @staticmethod - def parse(asm: str, context: Optional[Context] = None) -> Attribute: + def parse(asm: str | bytes, context: Optional[Context] = None) -> Attribute: """ Parses an attribute from an assembly form. Raises an MLIRError on failure. """ @@ -520,7 +520,7 @@ class Attribute: class Type: @staticmethod - def parse(asm: str, context: Optional[Context] = None) -> Type: + def parse(asm: str | bytes, context: Optional[Context] = None) -> Type: """ Parses the assembly form of a type. @@ -741,7 +741,7 @@ class AffineMap: def results(self) -> "AffineMapExprList": ... 
class AffineMapAttr(Attribute): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(affine_map: AffineMap) -> AffineMapAttr: """ @@ -779,7 +779,7 @@ class AffineSymbolExpr(AffineExpr): def position(self) -> int: ... class ArrayAttr(Attribute): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(attributes: List, context: Optional[Context] = None) -> ArrayAttr: """ @@ -823,7 +823,7 @@ class AttrBuilder: """ class BF16Type(Type): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> BF16Type: """ @@ -909,6 +909,11 @@ class BlockArgument(Value): def owner(self) -> Block: ... class BlockArgumentList: + @overload + def __getitem__(self, arg0: int) -> BlockArgument: ... + @overload + def __getitem__(self, arg0: slice) -> BlockArgumentList: ... + def __len__(self) -> int: ... def __add__(self, arg0: BlockArgumentList) -> List[BlockArgument]: ... @property def types(self) -> List[Type]: ... 
@@ -955,7 +960,7 @@ class BoolAttr(Attribute): """ class ComplexType(Type): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(arg0: Type) -> ComplexType: """ @@ -1016,7 +1021,7 @@ class Context: class DenseBoolArrayAttr(Attribute): @staticmethod def get( - values: List[bool], context: Optional[Context] = None + values: Sequence[bool], context: Optional[Context] = None ) -> DenseBoolArrayAttr: """ Gets a uniqued dense array attribute @@ -1113,7 +1118,7 @@ class DenseElementsAttr(Attribute): class DenseF32ArrayAttr(Attribute): @staticmethod def get( - values: List[float], context: Optional[Context] = None + values: Sequence[float], context: Optional[Context] = None ) -> DenseF32ArrayAttr: """ Gets a uniqued dense array attribute @@ -1141,7 +1146,7 @@ class DenseF32ArrayIterator: class DenseF64ArrayAttr(Attribute): @staticmethod def get( - values: List[float], context: Optional[Context] = None + values: Sequence[float], context: Optional[Context] = None ) -> DenseF64ArrayAttr: """ Gets a uniqued dense array attribute @@ -1167,6 +1172,14 @@ class DenseF64ArrayIterator: def __next__(self) -> float: ... class DenseFPElementsAttr(DenseElementsAttr): + @staticmethod + def get( + array: Buffer, + signless: bool = True, + type: Optional[Type] = None, + shape: Optional[List[int]] = None, + context: Optional[Context] = None, + ) -> DenseFPElementsAttr: ... @staticmethod def isinstance(other: Attribute) -> bool: ... def __getitem__(self, arg0: int) -> float: ... 
@@ -1180,7 +1193,7 @@ class DenseFPElementsAttr(DenseElementsAttr): class DenseI16ArrayAttr(Attribute): @staticmethod - def get(values: List[int], context: Optional[Context] = None) -> DenseI16ArrayAttr: + def get(values: Sequence[int], context: Optional[Context] = None) -> DenseI16ArrayAttr: """ Gets a uniqued dense array attribute """ @@ -1206,7 +1219,7 @@ class DenseI16ArrayIterator: class DenseI32ArrayAttr(Attribute): @staticmethod - def get(values: List[int], context: Optional[Context] = None) -> DenseI32ArrayAttr: + def get(values: Sequence[int], context: Optional[Context] = None) -> DenseI32ArrayAttr: """ Gets a uniqued dense array attribute """ @@ -1232,7 +1245,7 @@ class DenseI32ArrayIterator: class DenseI64ArrayAttr(Attribute): @staticmethod - def get(values: List[int], context: Optional[Context] = None) -> DenseI64ArrayAttr: + def get(values: Sequence[int], context: Optional[Context] = None) -> DenseI64ArrayAttr: """ Gets a uniqued dense array attribute """ @@ -1258,7 +1271,7 @@ class DenseI64ArrayIterator: class DenseI8ArrayAttr(Attribute): @staticmethod - def get(values: List[int], context: Optional[Context] = None) -> DenseI8ArrayAttr: + def get(values: Sequence[int], context: Optional[Context] = None) -> DenseI8ArrayAttr: """ Gets a uniqued dense array attribute """ @@ -1283,6 +1296,14 @@ class DenseI8ArrayIterator: def __next__(self) -> int: ... class DenseIntElementsAttr(DenseElementsAttr): + @staticmethod + def get( + array: Buffer, + signless: bool = True, + type: Optional[Type] = None, + shape: Optional[List[int]] = None, + context: Optional[Context] = None, + ) -> DenseIntElementsAttr: ... @staticmethod def isinstance(other: Attribute) -> bool: ... def __getitem__(self, arg0: int) -> int: ... @@ -1422,7 +1443,7 @@ class Dialects: def __getitem__(self, arg0: str) -> Dialect: ... 
class DictAttr(Attribute): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(value: Dict = {}, context: Optional[Context] = None) -> DictAttr: """ @@ -1453,7 +1474,7 @@ class FloatType(Type): """ class F16Type(FloatType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> F16Type: """ @@ -1466,7 +1487,7 @@ class F16Type(FloatType): def typeid(self) -> TypeID: ... class F32Type(FloatType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> F32Type: """ @@ -1479,7 +1500,7 @@ class F32Type(FloatType): def typeid(self) -> TypeID: ... class F64Type(FloatType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> F64Type: """ @@ -1513,7 +1534,7 @@ class FlatSymbolRefAttr(Attribute): """ class Float8E4M3B11FNUZType(FloatType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> Float8E4M3B11FNUZType: """ @@ -1526,7 +1547,7 @@ class Float8E4M3B11FNUZType(FloatType): def typeid(self) -> TypeID: ... class Float8E4M3FNType(FloatType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> Float8E4M3FNType: """ @@ -1539,7 +1560,7 @@ class Float8E4M3FNType(FloatType): def typeid(self) -> TypeID: ... class Float8E4M3FNUZType(FloatType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> Float8E4M3FNUZType: """ @@ -1552,7 +1573,7 @@ class Float8E4M3FNUZType(FloatType): def typeid(self) -> TypeID: ... 
class Float8E5M2FNUZType(FloatType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> Float8E5M2FNUZType: """ @@ -1565,7 +1586,7 @@ class Float8E5M2FNUZType(FloatType): def typeid(self) -> TypeID: ... class Float8E5M2Type(FloatType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> Float8E5M2Type: """ @@ -1578,7 +1599,7 @@ class Float8E5M2Type(FloatType): def typeid(self) -> TypeID: ... class FloatAttr(Attribute): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(type: Type, value: float, loc: Optional[Location] = None) -> FloatAttr: """ @@ -1612,7 +1633,7 @@ class FloatAttr(Attribute): """ class FloatTF32Type(FloatType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> FloatTF32Type: """ @@ -1625,7 +1646,7 @@ class FloatTF32Type(FloatType): def typeid(self) -> TypeID: ... class FunctionType(Type): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get( inputs: List[Type], results: List[Type], context: Optional[Context] = None @@ -1650,7 +1671,7 @@ class FunctionType(Type): def typeid(self) -> TypeID: ... class IndexType(Type): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> IndexType: """ @@ -1766,7 +1787,7 @@ class InsertionPoint: """ class IntegerAttr(Attribute): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(type: Type, value: int) -> IntegerAttr: """ @@ -1855,7 +1876,7 @@ class IntegerSetConstraintList: def __len__(self) -> int: ... 
class IntegerType(Type): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get_signed(width: int, context: Optional[Context] = None) -> IntegerType: """ @@ -1967,7 +1988,7 @@ class Location: """ class MemRefType(ShapedType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get( shape: List[int], @@ -2007,7 +2028,7 @@ class Module: Creates an empty module """ @staticmethod - def parse(asm: str, context: Optional[Context] = None) -> Module: + def parse(asm: str | bytes, context: Optional[Context] = None) -> Module: """ Parses a module's assembly format from a string. @@ -2064,7 +2085,7 @@ class NamedAttribute: """ class NoneType(Type): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> NoneType: """ @@ -2130,7 +2151,12 @@ class OpResultList: class OpSuccessors: def __add__(self, arg0: OpSuccessors) -> List[Block]: ... + @overload + def __getitem__(self, arg0: int) -> Block: ... + @overload + def __getitem__(self, arg0: slice) -> OpSuccessors: ... def __setitem__(self, arg0: int, arg1: Block) -> None: ... + def __len__(self) -> int: ... class OpView(_OperationBase): _ODS_OPERAND_SEGMENTS: ClassVar[None] = ... @@ -2154,7 +2180,7 @@ class OpView(_OperationBase): @classmethod def parse( cls: _Type[_TOperation], - source: str, + source: str | bytes, *, source_name: str = "", context: Optional[Context] = None, @@ -2174,7 +2200,7 @@ class OpView(_OperationBase): """ class OpaqueAttr(Attribute): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get( dialect_namespace: str, @@ -2204,7 +2230,7 @@ class OpaqueAttr(Attribute): def typeid(self) -> TypeID: ... 
class OpaqueType(Type): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get( dialect_namespace: str, buffer: str, context: Optional[Context] = None @@ -2262,7 +2288,7 @@ class Operation(_OperationBase): """ @staticmethod def parse( - source: str, *, source_name: str = "", context: Optional[Context] = None + source: str | bytes, *, source_name: str = "", context: Optional[Context] = None ) -> Operation: """ Parses an operation. Supports both text assembly format and binary bytecode format. @@ -2290,7 +2316,7 @@ class OperationList: def __len__(self) -> int: ... class RankedTensorType(ShapedType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get( shape: List[int], @@ -2443,7 +2469,7 @@ class ShapedTypeComponents: """ class StridedLayoutAttr(Attribute): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get( offset: int, strides: List[int], context: Optional[Context] = None @@ -2477,9 +2503,9 @@ class StridedLayoutAttr(Attribute): def typeid(self) -> TypeID: ... class StringAttr(Attribute): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod - def get(value: str, context: Optional[Context] = None) -> StringAttr: + def get(value: str | bytes, context: Optional[Context] = None) -> StringAttr: """ Gets a uniqued string attribute """ @@ -2554,9 +2580,9 @@ class SymbolTable: def insert(self, operation: _OperationBase) -> Attribute: ... class TupleType(Type): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod - def get_Tuple(elements: List[Type], context: Optional[Context] = None) -> TupleType: + def get_tuple(elements: List[Type], context: Optional[Context] = None) -> TupleType: """ Create a Tuple type """ @@ -2576,7 +2602,7 @@ class TupleType(Type): def typeid(self) -> TypeID: ... 
class TypeAttr(Attribute): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(value: Type, context: Optional[Context] = None) -> TypeAttr: """ @@ -2603,7 +2629,7 @@ class TypeID: def _CAPIPtr(self) -> object: ... class UnitAttr(Attribute): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(context: Optional[Context] = None) -> UnitAttr: """ @@ -2618,7 +2644,7 @@ class UnitAttr(Attribute): def typeid(self) -> TypeID: ... class UnrankedMemRefType(ShapedType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get( element_type: Type, memory_space: Attribute, loc: Optional[Location] = None @@ -2638,7 +2664,7 @@ class UnrankedMemRefType(ShapedType): def typeid(self) -> TypeID: ... class UnrankedTensorType(ShapedType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get(element_type: Type, loc: Optional[Location] = None) -> UnrankedTensorType: """ @@ -2651,7 +2677,7 @@ class UnrankedTensorType(ShapedType): def typeid(self) -> TypeID: ... class VectorType(ShapedType): - static_typeid: ClassVar[TypeID] # value = + static_typeid: ClassVar[TypeID] @staticmethod def get( shape: List[int], From 40fcae0a2e9e1a02734f3d243a6fd5a615cd19a5 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 11 Jun 2024 07:45:12 -0700 Subject: [PATCH 728/915] [mlir] Add PDL C & Python usage (#94714) Following a rather direct approach to expose PDL usage from C and then Python. This doesn't yes plumb through adding support for custom matchers through this interface, so constrained to basics initially. This also exposes greedy rewrite driver. Only way currently to define patterns is via PDL (just to keep small). The creation of the PDL pattern module could be improved to avoid folks potentially accessing the module used to construct it post construction. No ergonomic work done yet. 
--------- Signed-off-by: Jacques Pienaar --- mlir/include/mlir-c/Bindings/Python/Interop.h | 21 ++++ mlir/include/mlir-c/Rewrite.h | 60 ++++++++++ .../mlir/Bindings/Python/PybindAdaptors.h | 21 ++++ mlir/lib/Bindings/Python/IRModule.h | 1 + mlir/lib/Bindings/Python/MainModule.cpp | 4 + mlir/lib/Bindings/Python/Rewrite.cpp | 110 ++++++++++++++++++ mlir/lib/Bindings/Python/Rewrite.h | 22 ++++ mlir/lib/CAPI/Transforms/CMakeLists.txt | 3 + mlir/lib/CAPI/Transforms/Rewrite.cpp | 83 +++++++++++++ mlir/python/CMakeLists.txt | 2 + mlir/python/mlir/dialects/pdl.py | 8 +- mlir/python/mlir/rewrite.py | 5 + 12 files changed, 339 insertions(+), 1 deletion(-) create mode 100644 mlir/include/mlir-c/Rewrite.h create mode 100644 mlir/lib/Bindings/Python/Rewrite.cpp create mode 100644 mlir/lib/Bindings/Python/Rewrite.h create mode 100644 mlir/lib/CAPI/Transforms/Rewrite.cpp create mode 100644 mlir/python/mlir/rewrite.py diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h index 0a36e97c2..a33190c38 100644 --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -39,6 +39,7 @@ #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" #include "mlir-c/Pass.h" +#include "mlir-c/Rewrite.h" // The 'mlir' Python package is relocatable and supports co-existing in multiple // projects. Each project must define its outer package prefix with this define @@ -284,6 +285,26 @@ static inline MlirModule mlirPythonCapsuleToModule(PyObject *capsule) { return module; } +/** Creates a capsule object encapsulating the raw C-API + * MlirFrozenRewritePatternSet. + * The returned capsule does not extend or affect ownership of any Python + * objects that reference the module in any way. 
*/ +static inline PyObject * +mlirPythonFrozenRewritePatternSetToCapsule(MlirFrozenRewritePatternSet pm) { + return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(pm), + MLIR_PYTHON_CAPSULE_PASS_MANAGER, NULL); +} + +/** Extracts an MlirFrozenRewritePatternSet from a capsule as produced from + * mlirPythonFrozenRewritePatternSetToCapsule. If the capsule is not of the + * right type, then a null module is returned. */ +static inline MlirFrozenRewritePatternSet +mlirPythonCapsuleToFrozenRewritePatternSet(PyObject *capsule) { + void *ptr = PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_PASS_MANAGER); + MlirFrozenRewritePatternSet pm = {ptr}; + return pm; +} + /** Creates a capsule object encapsulating the raw C-API MlirPassManager. * The returned capsule does not extend or affect ownership of any Python * objects that reference the module in any way. */ diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h new file mode 100644 index 000000000..45218a1cd --- /dev/null +++ b/mlir/include/mlir-c/Rewrite.h @@ -0,0 +1,60 @@ +//===-- mlir-c/Rewrite.h - Helpers for C API to Rewrites ----------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the registration and creation method for +// rewrite patterns. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_REWRITE_H +#define MLIR_C_REWRITE_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir/Config/mlir-config.h" + +//===----------------------------------------------------------------------===// +/// Opaque type declarations (see mlir-c/IR.h for more details). 
+//===----------------------------------------------------------------------===// + +#define DEFINE_C_API_STRUCT(name, storage) \ + struct name { \ + storage *ptr; \ + }; \ + typedef struct name name + +DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void); +DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void); +DEFINE_C_API_STRUCT(MlirRewritePatternSet, void); + +MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet +mlirFreezeRewritePattern(MlirRewritePatternSet op); + +MLIR_CAPI_EXPORTED void +mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op); + +MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily( + MlirModule op, MlirFrozenRewritePatternSet patterns, + MlirGreedyRewriteDriverConfig); + +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH +DEFINE_C_API_STRUCT(MlirPDLPatternModule, void); + +MLIR_CAPI_EXPORTED MlirPDLPatternModule +mlirPDLPatternModuleFromModule(MlirModule op); + +MLIR_CAPI_EXPORTED void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op); + +MLIR_CAPI_EXPORTED MlirRewritePatternSet +mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op); +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH + +#undef DEFINE_C_API_STRUCT + +#endif // MLIR_C_REWRITE_H diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index d8f22c7aa..ebf50109f 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -198,6 +198,27 @@ struct type_caster { }; }; +/// Casts object <-> MlirFrozenRewritePatternSet. 
+template <> +struct type_caster { + PYBIND11_TYPE_CASTER(MlirFrozenRewritePatternSet, + _("MlirFrozenRewritePatternSet")); + bool load(handle src, bool) { + py::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); + return value.ptr != nullptr; + } + static handle cast(MlirFrozenRewritePatternSet v, return_value_policy, + handle) { + py::object capsule = py::reinterpret_steal( + mlirPythonFrozenRewritePatternSetToCapsule(v)); + return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("rewrite")) + .attr("FrozenRewritePatternSet") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + /// Casts object <-> MlirOperation. template <> struct type_caster { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 8c34c11f7..f49efcd50 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -22,6 +22,7 @@ #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" +#include "mlir-c/Transforms.h" #include "mlir/Bindings/Python/PybindAdaptors.h" #include "llvm/ADT/DenseMap.h" diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 17272472c..8da1ab16a 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -11,6 +11,7 @@ #include "Globals.h" #include "IRModule.h" #include "Pass.h" +#include "Rewrite.h" namespace py = pybind11; using namespace mlir; @@ -116,6 +117,9 @@ PYBIND11_MODULE(_mlir, m) { populateIRInterfaces(irModule); populateIRTypes(irModule); + auto rewriteModule = m.def_submodule("rewrite", "MLIR Rewrite Bindings"); + populateRewriteSubmodule(rewriteModule); + // Define and populate PassManager submodule. 
auto passModule = m.def_submodule("passmanager", "MLIR Pass Management Bindings"); diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp new file mode 100644 index 000000000..1d8128be9 --- /dev/null +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -0,0 +1,110 @@ +//===- Rewrite.cpp - Rewrite ----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "Rewrite.h" + +#include "IRModule.h" +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/Rewrite.h" +#include "mlir/Config/mlir-config.h" + +namespace py = pybind11; +using namespace mlir; +using namespace py::literals; +using namespace mlir::python; + +namespace { + +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH +/// Owning Wrapper around a PDLPatternModule. +class PyPDLPatternModule { +public: + PyPDLPatternModule(MlirPDLPatternModule module) : module(module) {} + PyPDLPatternModule(PyPDLPatternModule &&other) noexcept + : module(other.module) { + other.module.ptr = nullptr; + } + ~PyPDLPatternModule() { + if (module.ptr != nullptr) + mlirPDLPatternModuleDestroy(module); + } + MlirPDLPatternModule get() { return module; } + +private: + MlirPDLPatternModule module; +}; +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH + +/// Owning Wrapper around a FrozenRewritePatternSet. 
+class PyFrozenRewritePatternSet {
+public:
+  PyFrozenRewritePatternSet(MlirFrozenRewritePatternSet set) : set(set) {}
+  PyFrozenRewritePatternSet(PyFrozenRewritePatternSet &&other) noexcept
+      : set(other.set) {
+    other.set.ptr = nullptr;
+  }
+  ~PyFrozenRewritePatternSet() {
+    if (set.ptr != nullptr)
+      mlirFrozenRewritePatternSetDestroy(set);
+  }
+  MlirFrozenRewritePatternSet get() { return set; }
+
+  pybind11::object getCapsule() {
+    return py::reinterpret_steal(
+        mlirPythonFrozenRewritePatternSetToCapsule(get()));
+  }
+
+  static pybind11::object createFromCapsule(pybind11::object capsule) {
+    MlirFrozenRewritePatternSet rawPm =
+        mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr());
+    if (rawPm.ptr == nullptr)
+      throw py::error_already_set();
+    return py::cast(PyFrozenRewritePatternSet(rawPm),
+                    py::return_value_policy::move);
+  }
+
+private:
+  MlirFrozenRewritePatternSet set;
+};
+
+} // namespace
+
+/// Create the `mlir.rewrite` here.
+void mlir::python::populateRewriteSubmodule(py::module &m) {
+  //----------------------------------------------------------------------------
+  // Mapping of the PDL module and frozen rewrite pattern set classes
+  //----------------------------------------------------------------------------
+#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
+  py::class_(m, "PDLModule", py::module_local())
+      .def(py::init<>([](MlirModule module) {
+             return mlirPDLPatternModuleFromModule(module);
+           }),
+           "module"_a, "Create a PDL module from the given module.")
+      .def("freeze", [](PyPDLPatternModule &self) {
+        return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern(
+            mlirRewritePatternSetFromPDLPatternModule(self.get())));
+      });
+#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
+  py::class_(m, "FrozenRewritePatternSet",
+             py::module_local())
+      .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
+                             &PyFrozenRewritePatternSet::getCapsule)
+      .def(MLIR_PYTHON_CAPI_FACTORY_ATTR,
+           &PyFrozenRewritePatternSet::createFromCapsule);
+  m.def(
+      "apply_patterns_and_fold_greedily",
+      [](MlirModule 
module, MlirFrozenRewritePatternSet set) { + auto status = mlirApplyPatternsAndFoldGreedily(module, set, {}); + if (mlirLogicalResultIsFailure(status)) + // FIXME: Not sure this is the right error to throw here. + throw py::value_error("pattern application failed to converge"); + }, + "module"_a, "set"_a, + "Applys the given patterns to the given module greedily while folding " + "results."); +} diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h new file mode 100644 index 000000000..997b80add --- /dev/null +++ b/mlir/lib/Bindings/Python/Rewrite.h @@ -0,0 +1,22 @@ +//===- Rewrite.h - Rewrite Submodules of pybind module --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_REWRITE_H +#define MLIR_BINDINGS_PYTHON_REWRITE_H + +#include "PybindUtils.h" + +namespace mlir { +namespace python { + +void populateRewriteSubmodule(pybind11::module &m); + +} // namespace python +} // namespace mlir + +#endif // MLIR_BINDINGS_PYTHON_REWRITE_H diff --git a/mlir/lib/CAPI/Transforms/CMakeLists.txt b/mlir/lib/CAPI/Transforms/CMakeLists.txt index 2638025a8..6c67aa09f 100644 --- a/mlir/lib/CAPI/Transforms/CMakeLists.txt +++ b/mlir/lib/CAPI/Transforms/CMakeLists.txt @@ -1,6 +1,9 @@ add_mlir_upstream_c_api_library(MLIRCAPITransforms Passes.cpp + Rewrite.cpp LINK_LIBS PUBLIC + MLIRIR MLIRTransforms + MLIRTransformUtils ) diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp new file mode 100644 index 000000000..0de195839 --- /dev/null +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -0,0 +1,83 @@ +//===- Rewrite.cpp - C API for Rewrite Patterns ---------------------------===// +// +// Part of the LLVM Project, under the Apache 
License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Rewrite.h" +#include "mlir-c/Transforms.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; + +inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) { + assert(module.ptr && "unexpected null module"); + return *(static_cast(module.ptr)); +} + +inline MlirRewritePatternSet wrap(mlir::RewritePatternSet *module) { + return {module}; +} + +inline mlir::FrozenRewritePatternSet * +unwrap(MlirFrozenRewritePatternSet module) { + assert(module.ptr && "unexpected null module"); + return static_cast(module.ptr); +} + +inline MlirFrozenRewritePatternSet wrap(mlir::FrozenRewritePatternSet *module) { + return {module}; +} + +MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op) { + auto *m = new mlir::FrozenRewritePatternSet(std::move(unwrap(op))); + op.ptr = nullptr; + return wrap(m); +} + +void mlirFrozenRewritePatternSetDestroy(MlirFrozenRewritePatternSet op) { + delete unwrap(op); + op.ptr = nullptr; +} + +MlirLogicalResult +mlirApplyPatternsAndFoldGreedily(MlirModule op, + MlirFrozenRewritePatternSet patterns, + MlirGreedyRewriteDriverConfig) { + return wrap( + mlir::applyPatternsAndFoldGreedily(unwrap(op), *unwrap(patterns))); +} + +#if MLIR_ENABLE_PDL_IN_PATTERNMATCH +inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) { + assert(module.ptr && "unexpected null module"); + return static_cast(module.ptr); +} + +inline MlirPDLPatternModule wrap(mlir::PDLPatternModule *module) { + return {module}; +} + +MlirPDLPatternModule mlirPDLPatternModuleFromModule(MlirModule op) { + return wrap(new 
mlir::PDLPatternModule( + mlir::OwningOpRef(unwrap(op)))); +} + +void mlirPDLPatternModuleDestroy(MlirPDLPatternModule op) { + delete unwrap(op); + op.ptr = nullptr; +} + +MlirRewritePatternSet +mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op) { + auto *m = new mlir::RewritePatternSet(std::move(*unwrap(op))); + op.ptr = nullptr; + return wrap(m); +} +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index d8f2d1989..d03036e17 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -21,6 +21,7 @@ declare_mlir_python_sources(MLIRPythonSources.Core.Python _mlir_libs/__init__.py ir.py passmanager.py + rewrite.py dialects/_ods_common.py # The main _mlir module has submodules: include stubs from each. @@ -448,6 +449,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core IRModule.cpp IRTypes.cpp Pass.cpp + Rewrite.cpp # Headers must be included explicitly so they are installed. Globals.h diff --git a/mlir/python/mlir/dialects/pdl.py b/mlir/python/mlir/dialects/pdl.py index db07dc50a..b7b8430ce 100644 --- a/mlir/python/mlir/dialects/pdl.py +++ b/mlir/python/mlir/dialects/pdl.py @@ -6,7 +6,7 @@ from ._pdl_ops_gen import _Dialect from .._mlir_libs._mlirDialectsPDL import * from .._mlir_libs._mlirDialectsPDL import OperationType - +from ..extras.meta import region_op try: from ..ir import * @@ -127,6 +127,9 @@ def body(self): return self.regions[0].blocks[0] +pattern = region_op(PatternOp.__base__) + + @_ods_cext.register_operation(_Dialect, replace=True) class ReplaceOp(ReplaceOp): """Specialization for PDL replace op class.""" @@ -195,6 +198,9 @@ def body(self): return self.regions[0].blocks[0] +rewrite = region_op(RewriteOp) + + @_ods_cext.register_operation(_Dialect, replace=True) class TypeOp(TypeOp): """Specialization for PDL type op class.""" diff --git a/mlir/python/mlir/rewrite.py b/mlir/python/mlir/rewrite.py new file mode 100644 index 000000000..5bc1bba7a --- 
/dev/null +++ b/mlir/python/mlir/rewrite.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._mlir_libs._mlir.rewrite import * From 4eee08c32ecc432a61878f3b659579e7fe129803 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Wed, 12 Jun 2024 05:17:13 -0700 Subject: [PATCH 729/915] [MLIR][python] include Rewrite.h (#95226) --- mlir/python/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index d03036e17..23187f256 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -456,6 +456,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core IRModule.h Pass.h PybindUtils.h + Rewrite.h PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS From 7d8e8e0644982e7dc4ee5fcc0ec74c23b143bfd7 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Mon, 17 Jun 2024 12:11:49 -0700 Subject: [PATCH 730/915] [mlirc] Add missing extern C (#95829) This was missing being wrapped in extern C block. Don't know why didn't fail elsewhere, but failed on Windows build while linking Python libs. Signed-off-by: Jacques Pienaar --- mlir/include/mlir-c/Rewrite.h | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index 45218a1cd..bed93045f 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -19,6 +19,10 @@ #include "mlir-c/Support.h" #include "mlir/Config/mlir-config.h" +#ifdef __cplusplus +extern "C" { +#endif + //===----------------------------------------------------------------------===// /// Opaque type declarations (see mlir-c/IR.h for more details). 
//===----------------------------------------------------------------------===// @@ -57,4 +61,8 @@ mlirRewritePatternSetFromPDLPatternModule(MlirPDLPatternModule op); #undef DEFINE_C_API_STRUCT +#ifdef __cplusplus +} +#endif + #endif // MLIR_C_REWRITE_H From 754de741ec1d2639f4b43737015591946dda2d6b Mon Sep 17 00:00:00 2001 From: Jonas Rickert Date: Thu, 20 Jun 2024 17:15:08 +0200 Subject: [PATCH 731/915] [mlir] Expose skipRegions option for Op printing in the C and Python bindings (#96150) The MLIR C and Python Bindings expose various methods from `mlir::OpPrintingFlags` . This PR adds a binding for the `skipRegions` method, which allows to skip the printing of Regions when printing Ops. It also exposes this option as parameter in the python `get_asm` and `print` methods --- mlir/include/mlir-c/IR.h | 4 ++++ mlir/lib/Bindings/Python/IRCore.cpp | 22 +++++++++++++++------- mlir/lib/Bindings/Python/IRModule.h | 5 +++-- mlir/lib/CAPI/IR/IR.cpp | 3 +++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 3 +++ 5 files changed, 28 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index e3d69b770..694591fd9 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -450,6 +450,10 @@ mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags); MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags); +/// Skip printing regions. +MLIR_CAPI_EXPORTED void +mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags); + //===----------------------------------------------------------------------===// // Bytecode printing flags API. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 4b6b54dc1..c12f75e7d 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -108,6 +108,7 @@ static const char kOperationPrintDocstring[] = and report failures in a more robust fashion. Set this to True if doing this in order to avoid running a redundant verification. If the IR is actually invalid, behavior is undefined. + skip_regions: Whether to skip printing regions. Defaults to False. )"; static const char kOperationPrintStateDocstring[] = @@ -1221,7 +1222,7 @@ void PyOperationBase::print(std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified, py::object fileObject, - bool binary) { + bool binary, bool skipRegions) { PyOperation &operation = getOperation(); operation.checkValid(); if (fileObject.is_none()) @@ -1239,6 +1240,8 @@ void PyOperationBase::print(std::optional largeElementsLimit, mlirOpPrintingFlagsUseLocalScope(flags); if (assumeVerified) mlirOpPrintingFlagsAssumeVerified(flags); + if (skipRegions) + mlirOpPrintingFlagsSkipRegions(flags); PyFileAccumulator accum(fileObject, binary); mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), @@ -1314,7 +1317,7 @@ py::object PyOperationBase::getAsm(bool binary, std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified) { + bool assumeVerified, bool skipRegions) { py::object fileObject; if (binary) { fileObject = py::module::import("io").attr("BytesIO")(); @@ -1328,7 +1331,8 @@ py::object PyOperationBase::getAsm(bool binary, /*useLocalScope=*/useLocalScope, /*assumeVerified=*/assumeVerified, /*fileObject=*/fileObject, - /*binary=*/binary); + /*binary=*/binary, + /*skipRegions=*/skipRegions); return 
fileObject.attr("getvalue")(); } @@ -3043,7 +3047,8 @@ void mlir::python::populateIRCore(py::module &m) { /*prettyDebugInfo=*/false, /*printGenericOpForm=*/false, /*useLocalScope=*/false, - /*assumeVerified=*/false); + /*assumeVerified=*/false, + /*skipRegions=*/false); }, "Returns the assembly form of the operation.") .def("print", @@ -3053,7 +3058,8 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("binary") = false, kOperationPrintStateDocstring) .def("print", py::overload_cast, bool, bool, bool, bool, - bool, py::object, bool>(&PyOperationBase::print), + bool, py::object, bool, bool>( + &PyOperationBase::print), // Careful: Lots of arguments must match up with print method. py::arg("large_elements_limit") = py::none(), py::arg("enable_debug_info") = false, @@ -3061,7 +3067,8 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("print_generic_op_form") = false, py::arg("use_local_scope") = false, py::arg("assume_verified") = false, py::arg("file") = py::none(), - py::arg("binary") = false, kOperationPrintDocstring) + py::arg("binary") = false, py::arg("skip_regions") = false, + kOperationPrintDocstring) .def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"), py::arg("desired_version") = py::none(), kOperationPrintBytecodeDocstring) @@ -3073,7 +3080,8 @@ void mlir::python::populateIRCore(py::module &m) { py::arg("pretty_debug_info") = false, py::arg("print_generic_op_form") = false, py::arg("use_local_scope") = false, - py::arg("assume_verified") = false, kOperationGetAsmDocstring) + py::arg("assume_verified") = false, py::arg("skip_regions") = false, + kOperationGetAsmDocstring) .def("verify", &PyOperationBase::verify, "Verify the operation. 
Raises MLIRError if verification fails, and " "returns true otherwise.") diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index f49efcd50..172898cfd 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -574,14 +574,15 @@ class PyOperationBase { /// Implements the bound 'print' method and helps with others. void print(std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified, py::object fileObject, bool binary); + bool assumeVerified, py::object fileObject, bool binary, + bool skipRegions); void print(PyAsmState &state, py::object fileObject, bool binary); pybind11::object getAsm(bool binary, std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified); + bool assumeVerified, bool skipRegions); // Implement the bound 'writeBytecode' method. void writeBytecode(const pybind11::object &fileObject, diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 4e823c866..2edc311e2 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -219,6 +219,9 @@ void mlirOpPrintingFlagsAssumeVerified(MlirOpPrintingFlags flags) { unwrap(flags)->assumeVerified(); } +void mlirOpPrintingFlagsSkipRegions(MlirOpPrintingFlags flags) { + unwrap(flags)->skipRegions(); +} //===----------------------------------------------------------------------===// // Bytecode printing flags API. 
//===----------------------------------------------------------------------===// diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 1e1b2a834..317e68807 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -209,6 +209,7 @@ class _OperationBase: print_generic_op_form: bool = False, use_local_scope: bool = False, assume_verified: bool = False, + skip_regions: bool = False, ) -> Union[io.BytesIO, io.StringIO]: """ Gets the assembly form of the operation with all options available. @@ -256,6 +257,7 @@ class _OperationBase: assume_verified: bool = False, file: Optional[Any] = None, binary: bool = False, + skip_regions: bool = False, ) -> None: """ Prints the assembly form of the operation to a file like object. @@ -281,6 +283,7 @@ class _OperationBase: and report failures in a more robust fashion. Set this to True if doing this in order to avoid running a redundant verification. If the IR is actually invalid, behavior is undefined. + skip_regions: Whether to skip printing regions. Defaults to False. """ def verify(self) -> bool: """ From 72ce09886efff266ef2ba80d3a002415b1ca50a3 Mon Sep 17 00:00:00 2001 From: muneebkhan85 <150162960+muneebkhan85@users.noreply.github.com> Date: Fri, 21 Jun 2024 15:39:43 +0100 Subject: [PATCH 732/915] [MLIR] Add continuous tiling to transform dialect (#82792) This patch enables continuous tiling of a target structured op using diminishing tile sizes. In cases where the tensor dimensions are not exactly divisible by the tile size, we are left with leftover tensor chunks that are irregularly tiled. This approach enables tiling of the leftover chunk with a smaller tile size and repeats this process recursively using exponentially diminishing tile sizes. This eventually generates a chain of loops that apply tiling using diminishing tile sizes. Adds `continuous_tile_sizes` op to the transform dialect. 
This op, when given a tile size and a dimension, computes a series of diminishing tile sizes that can be used to tile the target along the given dimension. Additionally, this op also generates a series of chunk sizes that the corresponding tile sizes should be applied to along the given dimension. Adds `multiway` attribute to `transform.structured.split` that enables multiway splitting of a single target op along the given dimension, as specified in a list enumerating the chunk sizes. --- .../python/mlir/dialects/transform/structured.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py index 2c49ef096..41051c0d5 100644 --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -432,25 +432,25 @@ def __init__( self, target: Union[Operation, Value], dimension: Union[int, Attribute], - split_point: Union[int, Operation, Value, Attribute], + chunk_sizes: Union[int, Operation, Value, Attribute], *, loc=None, ip=None, ): - if isinstance(split_point, int): - static_split_point = split_point - dynamic_split_point = None + if isinstance(chunk_sizes, int): + static_chunk_sizes = chunk_sizes + dynamic_chunk_sizes = None else: - static_split_point = ShapedType.get_dynamic_size() - dynamic_split_point = split_point + static_chunk_sizes = ShapedType.get_dynamic_size() + dynamic_chunk_sizes = chunk_sizes super().__init__( target.type, target.type, target, dimension=dimension, - static_split_point=static_split_point, - dynamic_split_point=dynamic_split_point, + static_chunk_sizes=static_chunk_sizes, + dynamic_chunk_sizes=dynamic_chunk_sizes, loc=loc, ip=ip, ) From 4b2d7d2e33dbee6a8c08e29d36261f417be794dc Mon Sep 17 00:00:00 2001 From: Guray Ozen Date: Mon, 24 Jun 2024 11:39:22 +0200 Subject: [PATCH 733/915] [mlir][gpu] Add py binding for AsyncTokenType (#96466) The PR adds py binding for 
`AsyncTokenType` --- mlir/include/mlir-c/Dialect/GPU.h | 8 ++++++++ mlir/lib/Bindings/Python/DialectGPU.cpp | 14 ++++++++++++++ mlir/lib/CAPI/Dialect/GPU.cpp | 12 ++++++++++++ 3 files changed, 34 insertions(+) diff --git a/mlir/include/mlir-c/Dialect/GPU.h b/mlir/include/mlir-c/Dialect/GPU.h index 2adf73ddf..c42ff61f9 100644 --- a/mlir/include/mlir-c/Dialect/GPU.h +++ b/mlir/include/mlir-c/Dialect/GPU.h @@ -19,6 +19,14 @@ extern "C" { MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(GPU, gpu); +//===-------------------------------------------------------------------===// +// AsyncTokenType +//===-------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAGPUAsyncTokenType(MlirType type); + +MLIR_CAPI_EXPORTED MlirType mlirGPUAsyncTokenTypeGet(MlirContext ctx); + //===---------------------------------------------------------------------===// // ObjectAttr //===---------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp index 1f68bfc6f..a9e339b50 100644 --- a/mlir/lib/Bindings/Python/DialectGPU.cpp +++ b/mlir/lib/Bindings/Python/DialectGPU.cpp @@ -25,6 +25,20 @@ using namespace mlir::python::adaptors; PYBIND11_MODULE(_mlirDialectsGPU, m) { m.doc() = "MLIR GPU Dialect"; + //===-------------------------------------------------------------------===// + // AsyncTokenType + //===-------------------------------------------------------------------===// + + auto mlirGPUAsyncTokenType = + mlir_type_subclass(m, "AsyncTokenType", mlirTypeIsAGPUAsyncTokenType); + + mlirGPUAsyncTokenType.def_classmethod( + "get", + [](py::object cls, MlirContext ctx) { + return cls(mlirGPUAsyncTokenTypeGet(ctx)); + }, + "Gets an instance of AsyncTokenType in the same context", py::arg("cls"), + py::arg("ctx") = py::none()); //===-------------------------------------------------------------------===// // ObjectAttr diff --git 
a/mlir/lib/CAPI/Dialect/GPU.cpp b/mlir/lib/CAPI/Dialect/GPU.cpp index e471e8cd9..0acebb230 100644 --- a/mlir/lib/CAPI/Dialect/GPU.cpp +++ b/mlir/lib/CAPI/Dialect/GPU.cpp @@ -15,6 +15,18 @@ using namespace mlir; MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(GPU, gpu, gpu::GPUDialect) +//===-------------------------------------------------------------------===// +// AsyncTokenType +//===-------------------------------------------------------------------===// + +bool mlirTypeIsAGPUAsyncTokenType(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirGPUAsyncTokenTypeGet(MlirContext ctx) { + return wrap(gpu::AsyncTokenType::get(unwrap(ctx))); +} + //===---------------------------------------------------------------------===// // ObjectAttr //===---------------------------------------------------------------------===// From a2f21fd67ebc74b427f937df2dfde2de92459656 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Sat, 29 Jun 2024 12:48:11 -0700 Subject: [PATCH 734/915] Rename f8E4M3 to f8E4M3FN in mlir.extras.types py package (#97102) Currently `f8E4M3` is mapped to `Float8E4M3FNType`. This PR renames `f8E4M3` to `f8E4M3FN` to accurately reflect the actual type. This PR is needed to avoid names conflict in upcoming PR which will add IEEE 754 `Float8E4M3Type`. https://github.com/llvm/llvm-project/pull/97118 Add f8E4M3 IEEE 754 type Maksim, can you review this PR? @makslevental ? 
--- mlir/python/mlir/extras/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py index db9e8229f..b93c46b17 100644 --- a/mlir/python/mlir/extras/types.py +++ b/mlir/python/mlir/extras/types.py @@ -68,7 +68,7 @@ def ui(width): bf16 = lambda: BF16Type.get() f8E5M2 = lambda: Float8E5M2Type.get() -f8E4M3 = lambda: Float8E4M3FNType.get() +f8E4M3FN = lambda: Float8E4M3FNType.get() f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get() none = lambda: NoneType.get() From 4104409c8b444db657ef6a98761208ea8653efa2 Mon Sep 17 00:00:00 2001 From: Bimo Date: Mon, 1 Jul 2024 23:44:40 +0800 Subject: [PATCH 735/915] [MLIR][Python] add value attr for PyAffineMapAttribute (#97254) Similar to other attributes in Binding, the `PyAffineMapAttribute` should include a value attribute to enable users to directly retrieve the `AffineMap` from the `AffineMapAttr`. --- mlir/lib/Bindings/Python/IRAttributes.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index b5f31aa5d..b4049bd79 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -142,6 +142,8 @@ class PyAffineMapAttribute : public PyConcreteAttribute { return PyAffineMapAttribute(affineMap.getContext(), attr); }, py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); + c.def_property_readonly("value", mlirAffineMapAttrGetValue, + "Returns the value of the AffineMap attribute"); } }; From 8f8c4766552394c80436c685dfb0731a560fb837 Mon Sep 17 00:00:00 2001 From: Ramkumar Ramachandra Date: Tue, 2 Jul 2024 10:42:33 +0100 Subject: [PATCH 736/915] mlir/LogicalResult: move into llvm (#97309) This patch is part of a project to move the Presburger library into LLVM. 
--- mlir/include/mlir/CAPI/Support.h | 6 +++--- mlir/lib/CAPI/IR/BuiltinTypes.cpp | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/CAPI/Support.h b/mlir/include/mlir/CAPI/Support.h index 622745256..89a460375 100644 --- a/mlir/include/mlir/CAPI/Support.h +++ b/mlir/include/mlir/CAPI/Support.h @@ -17,9 +17,9 @@ #include "mlir-c/Support.h" #include "mlir/CAPI/Wrap.h" -#include "mlir/Support/LogicalResult.h" #include "mlir/Support/TypeID.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/LogicalResult.h" namespace llvm { class ThreadPoolInterface; @@ -35,13 +35,13 @@ inline llvm::StringRef unwrap(MlirStringRef ref) { return llvm::StringRef(ref.data, ref.length); } -inline MlirLogicalResult wrap(mlir::LogicalResult res) { +inline MlirLogicalResult wrap(llvm::LogicalResult res) { if (mlir::succeeded(res)) return mlirLogicalResultSuccess(); return mlirLogicalResultFailure(); } -inline mlir::LogicalResult unwrap(MlirLogicalResult res) { +inline llvm::LogicalResult unwrap(MlirLogicalResult res) { return mlir::success(mlirLogicalResultIsSuccess(res)); } diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index c94c07014..01bb71d92 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -16,7 +16,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Types.h" -#include "mlir/Support/LogicalResult.h" #include From 33a19d339b44dc18f60be9472e02989331490ec3 Mon Sep 17 00:00:00 2001 From: Bimo Date: Fri, 5 Jul 2024 09:23:12 +0800 Subject: [PATCH 737/915] [MLIR][Python] fix class name of powf and negf in linalg (#97696) The following logic can lead to a class name mismatch when using `linalg.powf` in Python. This PR fixed the issue and also renamed `NegfOp` to `NegFOp` in linalg to adhere to the naming convention, as exemplified by `arith::NegFOp`. 
https://github.com/llvm/llvm-project/blob/173514d58ec4e6166670f1e37a038df3865c8b96/mlir/python/mlir/dialects/linalg/opdsl/lang/dsl.py#L140-L143 ``` # linalg.powf(arg0, arg1, outs=[init_result.result]) NotImplementedError: Unknown named op_name / op_class_name: powf / PowfOp ``` --- mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 43410aaa6..cbb2d1cec 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -96,7 +96,7 @@ def floor( O[None] = UnaryFn.floor(I[None]) -@linalg_structured_op +@linalg_structured_op(op_class_name="NegFOp") def negf( I=TensorDef(T1), O=TensorDef(T1, output=True), @@ -330,7 +330,7 @@ def min( O[None] = BinaryFn.min_signed(lhs[None], rhs[None]) -@linalg_structured_op +@linalg_structured_op(op_class_name="PowFOp") def powf( lhs=TensorDef(T1), rhs=TensorDef(T1), From 4e84104b581dddad58cb0829b9752a7d15826719 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Fri, 5 Jul 2024 10:43:51 -0500 Subject: [PATCH 738/915] [mlir][python] auto attribute casting (#97786) --- .../mlir/Bindings/Python/PybindAdaptors.h | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index ebf50109f..df4b9bf71 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -406,21 +406,25 @@ class pure_subclass { class mlir_attribute_subclass : public pure_subclass { public: using IsAFunctionTy = bool (*)(MlirAttribute); + using GetTypeIDFunctionTy = MlirTypeID (*)(); /// Subclasses by looking up the super-class dynamically. 
mlir_attribute_subclass(py::handle scope, const char *attrClassName, - IsAFunctionTy isaFunction) + IsAFunctionTy isaFunction, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) : mlir_attribute_subclass( scope, attrClassName, isaFunction, py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Attribute")) {} + .attr("Attribute"), + getTypeIDFunction) {} /// Subclasses with a provided mlir.ir.Attribute super-class. This must /// be used if the subclass is being defined in the same extension module /// as the mlir.ir class (otherwise, it will trigger a recursive /// initialization). mlir_attribute_subclass(py::handle scope, const char *typeClassName, - IsAFunctionTy isaFunction, const py::object &superCls) + IsAFunctionTy isaFunction, const py::object &superCls, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) : pure_subclass(scope, typeClassName, superCls) { // Casting constructor. Note that it hard, if not impossible, to properly // call chain to parent `__init__` in pybind11 due to its special handling @@ -454,6 +458,20 @@ class mlir_attribute_subclass : public pure_subclass { "isinstance", [isaFunction](MlirAttribute other) { return isaFunction(other); }, py::arg("other_attribute")); + def("__repr__", [superCls, captureTypeName](py::object self) { + return py::repr(superCls(self)) + .attr("replace")(superCls.attr("__name__"), captureTypeName); + }); + if (getTypeIDFunction) { + def_staticmethod("get_static_typeid", + [getTypeIDFunction]() { return getTypeIDFunction(); }); + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( + getTypeIDFunction())(pybind11::cpp_function( + [thisClass = thisClass](const py::object &mlirAttribute) { + return thisClass(mlirAttribute); + })); + } } }; From 78328d8d8e50767b980d61cd2a03b373e0ef5be3 Mon Sep 17 00:00:00 2001 From: Brendan Hansknecht Date: Mon, 8 Jul 2024 14:12:49 -0700 Subject: [PATCH 739/915] [mlir][c-api] expose elideLargeResourceString (#98050) Expose 
`elideLargeResourceString` to the c api. This was done in the same way as `elideLargeElementsAttrs` is exposed. The docs were grabbed from the `elideLargeResourceString` method and forwarded here. --- mlir/include/mlir-c/IR.h | 8 ++++++++ mlir/lib/CAPI/IR/IR.cpp | 5 +++++ 2 files changed, 13 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 694591fd9..b8a6f08b1 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -428,6 +428,14 @@ MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, intptr_t largeElementLimit); +/// Enables the elision of large resource strings by omitting them from the +/// `dialect_resources` section. The `largeResourceLimit` is used to configure +/// what is considered to be a "large" resource by providing an upper limit to +/// the string size. +MLIR_CAPI_EXPORTED void +mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags, + intptr_t largeResourceLimit); + /// Enable or disable printing of debug information (based on `enable`). If /// 'prettyForm' is set to true, debug information is printed in a more readable /// 'pretty' form. Note: The IR generated with 'prettyForm' is not parsable. 
diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 2edc311e2..5eb531b70 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -202,6 +202,11 @@ void mlirOpPrintingFlagsElideLargeElementsAttrs(MlirOpPrintingFlags flags, unwrap(flags)->elideLargeElementsAttrs(largeElementLimit); } +void mlirOpPrintingFlagsElideLargeResourceString(MlirOpPrintingFlags flags, + intptr_t largeResourceLimit) { + unwrap(flags)->elideLargeResourceString(largeResourceLimit); +} + void mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, bool prettyForm) { unwrap(flags)->enableDebugInfo(enable, /*prettyForm=*/prettyForm); From e61b966fc918d0ec0699fbece977a83084be1525 Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Sat, 13 Jul 2024 14:46:44 +0100 Subject: [PATCH 740/915] Remove redundant linalg.matmul_signed (#98615) `linalg.matmul` already has an attribute for casts, defaults to signed but allowed unsigned, so the operation `linalg.matmul_unsigned` is redundant. The generalization test has an example on how to lower to unsigned matmul in linalg. This is the first PR in a list of many that will simplify the linalg operations by using similar attributes. 
Ref: https://discourse.llvm.org/t/rfc-transpose-attribute-for-linalg-matmul-operations/80092 --- .../linalg/opdsl/ops/core_named_ops.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index cbb2d1cec..3ceee8e37 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -388,24 +388,6 @@ def matmul( C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) -@linalg_structured_op -def matmul_unsigned( - A=TensorDef(T1, S.M, S.K), - B=TensorDef(T2, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True), -): - """Performs an unsigned matrix multiplication of two 2D inputs. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.m, D.n] += TypeFn.cast_unsigned(U, A[D.m, D.k]) * TypeFn.cast_unsigned( - U, B[D.k, D.n] - ) - - @linalg_structured_op def quantized_matmul( A=TensorDef(T1, S.M, S.K), From 303952a812aaf31e4f2ceca7a7da348a09dfaef8 Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Mon, 15 Jul 2024 09:30:56 +0100 Subject: [PATCH 741/915] [MLIR][Linalg] Fix named structured ops yaml file (#98865) Added missing reciprocal to Python file and fixed ErfOp name in yaml file. Now running the bash script yields the same output. 
--- .../mlir/dialects/linalg/opdsl/lang/comprehension.py | 1 + .../mlir/dialects/linalg/opdsl/ops/core_named_ops.py | 12 ++++++++++++ 2 files changed, 13 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py index 1a198fc5e..4f81a3874 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py @@ -292,6 +292,7 @@ class UnaryFn: ceil = UnaryFnType("ceil") floor = UnaryFnType("floor") negf = UnaryFnType("negf") + reciprocal = UnaryFnType("reciprocal") round = UnaryFnType("round") sqrt = UnaryFnType("sqrt") rsqrt = UnaryFnType("rsqrt") diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 3ceee8e37..67bde8f73 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -108,6 +108,18 @@ def negf( O[None] = UnaryFn.negf(I[None]) +@linalg_structured_op(op_class_name="ReciprocalOp") +def reciprocal( + I=TensorDef(T1), + O=TensorDef(T1, output=True), +): + """Applies reciprocal(x) elementwise. + + No numeric casting is performed on the input operand. + """ + O[None] = UnaryFn.reciprocal(I[None]) + + @linalg_structured_op def round( I=TensorDef(T1), From fba770a277e62ecb0fd6ca61712c6a2894245156 Mon Sep 17 00:00:00 2001 From: Fehr Mathieu Date: Tue, 16 Jul 2024 20:37:11 +0100 Subject: [PATCH 742/915] [mlir] Add RewriterBase to the C API (#98962) This exposes most of the `RewriterBase` methods to the C API. This allows to manipulate both the `IRRewriter` and the `PatternRewriter`. The `IRRewriter` can be created from the C API, while the `PatternRewriter` cannot. The missing operations are the ones taking `Block::iterator` and `Region::iterator` as parameters, as they are not exposed by the C API yet AFAIK. 
The Python bindings for these methods and classes are not implemented. --- mlir/include/mlir-c/Rewrite.h | 260 +++++++++++++++++++++++++++ mlir/include/mlir/CAPI/Rewrite.h | 23 +++ mlir/lib/CAPI/Transforms/Rewrite.cpp | 243 +++++++++++++++++++++++++ 3 files changed, 526 insertions(+) create mode 100644 mlir/include/mlir/CAPI/Rewrite.h diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index bed93045f..d8f2275b6 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -33,10 +33,266 @@ extern "C" { }; \ typedef struct name name +DEFINE_C_API_STRUCT(MlirRewriterBase, void); DEFINE_C_API_STRUCT(MlirFrozenRewritePatternSet, void); DEFINE_C_API_STRUCT(MlirGreedyRewriteDriverConfig, void); DEFINE_C_API_STRUCT(MlirRewritePatternSet, void); +//===----------------------------------------------------------------------===// +/// RewriterBase API inherited from OpBuilder +//===----------------------------------------------------------------------===// + +/// Get the MLIR context referenced by the rewriter. +MLIR_CAPI_EXPORTED MlirContext +mlirRewriterBaseGetContext(MlirRewriterBase rewriter); + +//===----------------------------------------------------------------------===// +/// Insertion points methods + +// These do not include functions using Block::iterator or Region::iterator, as +// they are not exposed by the C API yet. Similarly for methods using +// `InsertPoint` directly. + +/// Reset the insertion point to no location. Creating an operation without a +/// set insertion point is an error, but this can still be useful when the +/// current insertion point a builder refers to is being removed. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter); + +/// Sets the insertion point to the specified operation, which will cause +/// subsequent insertions to go right before it. 
+MLIR_CAPI_EXPORTED void +mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter, + MlirOperation op); + +/// Sets the insertion point to the node after the specified operation, which +/// will cause subsequent insertions to go right after it. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter, + MlirOperation op); + +/// Sets the insertion point to the node after the specified value. If value +/// has a defining operation, sets the insertion point to the node after such +/// defining operation. This will cause subsequent insertions to go right +/// after it. Otherwise, value is a BlockArgument. Sets the insertion point to +/// the start of its block. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter, + MlirValue value); + +/// Sets the insertion point to the start of the specified block. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter, + MlirBlock block); + +/// Sets the insertion point to the end of the specified block. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter, + MlirBlock block); + +/// Return the block the current insertion point belongs to. Note that the +/// insertion point is not necessarily the end of the block. +MLIR_CAPI_EXPORTED MlirBlock +mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter); + +/// Returns the current block of the rewriter. +MLIR_CAPI_EXPORTED MlirBlock +mlirRewriterBaseGetBlock(MlirRewriterBase rewriter); + +//===----------------------------------------------------------------------===// +/// Block and operation creation/insertion/cloning + +// These functions do not include the IRMapper, as it is not yet exposed by the +// C API. + +/// Add new block with 'argTypes' arguments and set the insertion point to the +/// end of it. The block is placed before 'insertBefore'. 
`locs` contains the +/// locations of the inserted arguments, and should match the size of +/// `argTypes`. +MLIR_CAPI_EXPORTED MlirBlock mlirRewriterBaseCreateBlockBefore( + MlirRewriterBase rewriter, MlirBlock insertBefore, intptr_t nArgTypes, + MlirType const *argTypes, MlirLocation const *locations); + +/// Insert the given operation at the current insertion point and return it. +MLIR_CAPI_EXPORTED MlirOperation +mlirRewriterBaseInsert(MlirRewriterBase rewriter, MlirOperation op); + +/// Creates a deep copy of the specified operation. +MLIR_CAPI_EXPORTED MlirOperation +mlirRewriterBaseClone(MlirRewriterBase rewriter, MlirOperation op); + +/// Creates a deep copy of this operation but keep the operation regions +/// empty. +MLIR_CAPI_EXPORTED MlirOperation mlirRewriterBaseCloneWithoutRegions( + MlirRewriterBase rewriter, MlirOperation op); + +/// Clone the blocks that belong to "region" before the given position in +/// another region "parent". +MLIR_CAPI_EXPORTED void +mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter, MlirRegion region, + MlirBlock before); + +//===----------------------------------------------------------------------===// +/// RewriterBase API +//===----------------------------------------------------------------------===// + +/// Move the blocks that belong to "region" before the given position in +/// another region "parent". The two regions must be different. The caller +/// is responsible for creating or updating the operation transferring flow +/// of control to the region and passing it the correct block arguments. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter, MlirRegion region, + MlirBlock before); + +/// Replace the results of the given (original) operation with the specified +/// list of values (replacements). The result types of the given op and the +/// replacements must match. The original op is erased. 
+MLIR_CAPI_EXPORTED void +mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, MlirOperation op, + intptr_t nValues, MlirValue const *values); + +/// Replace the results of the given (original) operation with the specified +/// new op (replacement). The result types of the two ops must match. The +/// original op is erased. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter, + MlirOperation op, MlirOperation newOp); + +/// Erases an operation that is known to have no uses. +MLIR_CAPI_EXPORTED void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, + MlirOperation op); + +/// Erases a block along with all operations inside it. +MLIR_CAPI_EXPORTED void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, + MlirBlock block); + +/// Inline the operations of block 'source' before the operation 'op'. The +/// source block will be deleted and must have no uses. 'argValues' is used to +/// replace the block arguments of 'source' +/// +/// The source block must have no successors. Otherwise, the resulting IR +/// would have unreachable operations. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter, MlirBlock source, + MlirOperation op, intptr_t nArgValues, + MlirValue const *argValues); + +/// Inline the operations of block 'source' into the end of block 'dest'. The +/// source block will be deleted and must have no uses. 'argValues' is used to +/// replace the block arguments of 'source' +/// +/// The dest block must have no successors. Otherwise, the resulting IR would +/// have unreachable operation. +MLIR_CAPI_EXPORTED void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, + MlirBlock source, + MlirBlock dest, + intptr_t nArgValues, + MlirValue const *argValues); + +/// Unlink this operation from its current block and insert it right before +/// `existingOp` which may be in the same or another block in the same +/// function. 
+MLIR_CAPI_EXPORTED void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, + MlirOperation op, + MlirOperation existingOp); + +/// Unlink this operation from its current block and insert it right after +/// `existingOp` which may be in the same or another block in the same +/// function. +MLIR_CAPI_EXPORTED void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, + MlirOperation op, + MlirOperation existingOp); + +/// Unlink this block and insert it right before `existingBlock`. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block, + MlirBlock existingBlock); + +/// This method is used to notify the rewriter that an in-place operation +/// modification is about to happen. A call to this function *must* be +/// followed by a call to either `finalizeOpModification` or +/// `cancelOpModification`. This is a minor efficiency win (it avoids creating +/// a new operation and removing the old one) but also often allows simpler +/// code in the client. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter, + MlirOperation op); + +/// This method is used to signal the end of an in-place modification of the +/// given operation. This can only be called on operations that were provided +/// to a call to `startOpModification`. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter, + MlirOperation op); + +/// This method cancels a pending in-place modification. This can only be +/// called on operations that were provided to a call to +/// `startOpModification`. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter, + MlirOperation op); + +/// Find uses of `from` and replace them with `to`. Also notify the listener +/// about every in-place op modification (for every use that was replaced). 
+MLIR_CAPI_EXPORTED void +mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter, MlirValue from, + MlirValue to); + +/// Find uses of `from` and replace them with `to`. Also notify the listener +/// about every in-place op modification (for every use that was replaced). +MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceAllValueRangeUsesWith( + MlirRewriterBase rewriter, intptr_t nValues, MlirValue const *from, + MlirValue const *to); + +/// Find uses of `from` and replace them with `to`. Also notify the listener +/// about every in-place op modification (for every use that was replaced) +/// and that the `from` operation is about to be replaced. +MLIR_CAPI_EXPORTED void +mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter, + MlirOperation from, intptr_t nTo, + MlirValue const *to); + +/// Find uses of `from` and replace them with `to`. Also notify the listener +/// about every in-place op modification (for every use that was replaced) +/// and that the `from` operation is about to be replaced. +MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceAllOpUsesWithOperation( + MlirRewriterBase rewriter, MlirOperation from, MlirOperation to); + +/// Find uses of `from` within `block` and replace them with `to`. Also notify +/// the listener about every in-place op modification (for every use that was +/// replaced). The optional `allUsesReplaced` flag is set to "true" if all +/// uses were replaced. +MLIR_CAPI_EXPORTED void mlirRewriterBaseReplaceOpUsesWithinBlock( + MlirRewriterBase rewriter, MlirOperation op, intptr_t nNewValues, + MlirValue const *newValues, MlirBlock block); + +/// Find uses of `from` and replace them with `to` except if the user is +/// `exceptedUser`. Also notify the listener about every in-place op +/// modification (for every use that was replaced). 
+MLIR_CAPI_EXPORTED void +mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter, MlirValue from, + MlirValue to, MlirOperation exceptedUser); + +//===----------------------------------------------------------------------===// +/// IRRewriter API +//===----------------------------------------------------------------------===// + +/// Create an IRRewriter and transfer ownership to the caller. +MLIR_CAPI_EXPORTED MlirRewriterBase mlirIRRewriterCreate(MlirContext context); + +/// Create an IRRewriter and transfer ownership to the caller. Additionally +/// set the insertion point before the operation. +MLIR_CAPI_EXPORTED MlirRewriterBase +mlirIRRewriterCreateFromOp(MlirOperation op); + +/// Takes an IRRewriter owned by the caller and destroys it. It is the +/// responsibility of the user to only pass an IRRewriter class. +MLIR_CAPI_EXPORTED void mlirIRRewriterDestroy(MlirRewriterBase rewriter); + +//===----------------------------------------------------------------------===// +/// FrozenRewritePatternSet API +//===----------------------------------------------------------------------===// + MLIR_CAPI_EXPORTED MlirFrozenRewritePatternSet mlirFreezeRewritePattern(MlirRewritePatternSet op); @@ -47,6 +303,10 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirApplyPatternsAndFoldGreedily( MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig); +//===----------------------------------------------------------------------===// +/// PDLPatternModule API +//===----------------------------------------------------------------------===// + #if MLIR_ENABLE_PDL_IN_PATTERNMATCH DEFINE_C_API_STRUCT(MlirPDLPatternModule, void); diff --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h new file mode 100644 index 000000000..0e6dcb247 --- /dev/null +++ b/mlir/include/mlir/CAPI/Rewrite.h @@ -0,0 +1,23 @@ +//===- Rewrite.h - C API Utils for Core MLIR classes ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache 
License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains declarations of implementation details of the C API for +// rewrite patterns. This file should not be included from C++ code other than +// C API implementation nor from C code. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_CAPI_REWRITE_H +#define MLIR_CAPI_REWRITE_H + +#include "mlir/CAPI/Wrap.h" +#include "mlir/IR/PatternMatch.h" + +DEFINE_C_API_PTR_METHODS(MlirRewriterBase, mlir::RewriterBase); + +#endif // MLIR_CAPIREWRITER_H diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 0de195839..379f09cf5 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -7,15 +7,254 @@ //===----------------------------------------------------------------------===// #include "mlir-c/Rewrite.h" + #include "mlir-c/Transforms.h" #include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Rewrite.h" #include "mlir/CAPI/Support.h" +#include "mlir/CAPI/Wrap.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; +//===----------------------------------------------------------------------===// +/// RewriterBase API inherited from OpBuilder +//===----------------------------------------------------------------------===// + +MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) { + return wrap(unwrap(rewriter)->getContext()); +} + +//===----------------------------------------------------------------------===// +/// Insertion points methods + +void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter) { + unwrap(rewriter)->clearInsertionPoint(); +} + +void 
mlirRewriterBaseSetInsertionPointBefore(MlirRewriterBase rewriter, + MlirOperation op) { + unwrap(rewriter)->setInsertionPoint(unwrap(op)); +} + +void mlirRewriterBaseSetInsertionPointAfter(MlirRewriterBase rewriter, + MlirOperation op) { + unwrap(rewriter)->setInsertionPointAfter(unwrap(op)); +} + +void mlirRewriterBaseSetInsertionPointAfterValue(MlirRewriterBase rewriter, + MlirValue value) { + unwrap(rewriter)->setInsertionPointAfterValue(unwrap(value)); +} + +void mlirRewriterBaseSetInsertionPointToStart(MlirRewriterBase rewriter, + MlirBlock block) { + unwrap(rewriter)->setInsertionPointToStart(unwrap(block)); +} + +void mlirRewriterBaseSetInsertionPointToEnd(MlirRewriterBase rewriter, + MlirBlock block) { + unwrap(rewriter)->setInsertionPointToEnd(unwrap(block)); +} + +MlirBlock mlirRewriterBaseGetInsertionBlock(MlirRewriterBase rewriter) { + return wrap(unwrap(rewriter)->getInsertionBlock()); +} + +MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) { + return wrap(unwrap(rewriter)->getBlock()); +} + +//===----------------------------------------------------------------------===// +/// Block and operation creation/insertion/cloning + +MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter, + MlirBlock insertBefore, + intptr_t nArgTypes, + MlirType const *argTypes, + MlirLocation const *locations) { + SmallVector args; + ArrayRef unwrappedArgs = unwrapList(nArgTypes, argTypes, args); + SmallVector locs; + ArrayRef unwrappedLocs = unwrapList(nArgTypes, locations, locs); + return wrap(unwrap(rewriter)->createBlock(unwrap(insertBefore), unwrappedArgs, + unwrappedLocs)); +} + +MlirOperation mlirRewriterBaseInsert(MlirRewriterBase rewriter, + MlirOperation op) { + return wrap(unwrap(rewriter)->insert(unwrap(op))); +} + +// Other methods of OpBuilder + +MlirOperation mlirRewriterBaseClone(MlirRewriterBase rewriter, + MlirOperation op) { + return wrap(unwrap(rewriter)->clone(*unwrap(op))); +} + +MlirOperation 
mlirRewriterBaseCloneWithoutRegions(MlirRewriterBase rewriter, + MlirOperation op) { + return wrap(unwrap(rewriter)->cloneWithoutRegions(*unwrap(op))); +} + +void mlirRewriterBaseCloneRegionBefore(MlirRewriterBase rewriter, + MlirRegion region, MlirBlock before) { + + unwrap(rewriter)->cloneRegionBefore(*unwrap(region), unwrap(before)); +} + +//===----------------------------------------------------------------------===// +/// RewriterBase API +//===----------------------------------------------------------------------===// + +void mlirRewriterBaseInlineRegionBefore(MlirRewriterBase rewriter, + MlirRegion region, MlirBlock before) { + unwrap(rewriter)->inlineRegionBefore(*unwrap(region), unwrap(before)); +} + +void mlirRewriterBaseReplaceOpWithValues(MlirRewriterBase rewriter, + MlirOperation op, intptr_t nValues, + MlirValue const *values) { + SmallVector vals; + ArrayRef unwrappedVals = unwrapList(nValues, values, vals); + unwrap(rewriter)->replaceOp(unwrap(op), unwrappedVals); +} + +void mlirRewriterBaseReplaceOpWithOperation(MlirRewriterBase rewriter, + MlirOperation op, + MlirOperation newOp) { + unwrap(rewriter)->replaceOp(unwrap(op), unwrap(newOp)); +} + +void mlirRewriterBaseEraseOp(MlirRewriterBase rewriter, MlirOperation op) { + unwrap(rewriter)->eraseOp(unwrap(op)); +} + +void mlirRewriterBaseEraseBlock(MlirRewriterBase rewriter, MlirBlock block) { + unwrap(rewriter)->eraseBlock(unwrap(block)); +} + +void mlirRewriterBaseInlineBlockBefore(MlirRewriterBase rewriter, + MlirBlock source, MlirOperation op, + intptr_t nArgValues, + MlirValue const *argValues) { + SmallVector vals; + ArrayRef unwrappedVals = unwrapList(nArgValues, argValues, vals); + + unwrap(rewriter)->inlineBlockBefore(unwrap(source), unwrap(op), + unwrappedVals); +} + +void mlirRewriterBaseMergeBlocks(MlirRewriterBase rewriter, MlirBlock source, + MlirBlock dest, intptr_t nArgValues, + MlirValue const *argValues) { + SmallVector args; + ArrayRef unwrappedArgs = unwrapList(nArgValues, 
argValues, args); + unwrap(rewriter)->mergeBlocks(unwrap(source), unwrap(dest), unwrappedArgs); +} + +void mlirRewriterBaseMoveOpBefore(MlirRewriterBase rewriter, MlirOperation op, + MlirOperation existingOp) { + unwrap(rewriter)->moveOpBefore(unwrap(op), unwrap(existingOp)); +} + +void mlirRewriterBaseMoveOpAfter(MlirRewriterBase rewriter, MlirOperation op, + MlirOperation existingOp) { + unwrap(rewriter)->moveOpAfter(unwrap(op), unwrap(existingOp)); +} + +void mlirRewriterBaseMoveBlockBefore(MlirRewriterBase rewriter, MlirBlock block, + MlirBlock existingBlock) { + unwrap(rewriter)->moveBlockBefore(unwrap(block), unwrap(existingBlock)); +} + +void mlirRewriterBaseStartOpModification(MlirRewriterBase rewriter, + MlirOperation op) { + unwrap(rewriter)->startOpModification(unwrap(op)); +} + +void mlirRewriterBaseFinalizeOpModification(MlirRewriterBase rewriter, + MlirOperation op) { + unwrap(rewriter)->finalizeOpModification(unwrap(op)); +} + +void mlirRewriterBaseCancelOpModification(MlirRewriterBase rewriter, + MlirOperation op) { + unwrap(rewriter)->cancelOpModification(unwrap(op)); +} + +void mlirRewriterBaseReplaceAllUsesWith(MlirRewriterBase rewriter, + MlirValue from, MlirValue to) { + unwrap(rewriter)->replaceAllUsesWith(unwrap(from), unwrap(to)); +} + +void mlirRewriterBaseReplaceAllValueRangeUsesWith(MlirRewriterBase rewriter, + intptr_t nValues, + MlirValue const *from, + MlirValue const *to) { + SmallVector fromVals; + ArrayRef unwrappedFromVals = unwrapList(nValues, from, fromVals); + SmallVector toVals; + ArrayRef unwrappedToVals = unwrapList(nValues, to, toVals); + unwrap(rewriter)->replaceAllUsesWith(unwrappedFromVals, unwrappedToVals); +} + +void mlirRewriterBaseReplaceAllOpUsesWithValueRange(MlirRewriterBase rewriter, + MlirOperation from, + intptr_t nTo, + MlirValue const *to) { + SmallVector toVals; + ArrayRef unwrappedToVals = unwrapList(nTo, to, toVals); + unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrappedToVals); +} + +void 
mlirRewriterBaseReplaceAllOpUsesWithOperation(MlirRewriterBase rewriter, + MlirOperation from, + MlirOperation to) { + unwrap(rewriter)->replaceAllOpUsesWith(unwrap(from), unwrap(to)); +} + +void mlirRewriterBaseReplaceOpUsesWithinBlock(MlirRewriterBase rewriter, + MlirOperation op, + intptr_t nNewValues, + MlirValue const *newValues, + MlirBlock block) { + SmallVector vals; + ArrayRef unwrappedVals = unwrapList(nNewValues, newValues, vals); + unwrap(rewriter)->replaceOpUsesWithinBlock(unwrap(op), unwrappedVals, + unwrap(block)); +} + +void mlirRewriterBaseReplaceAllUsesExcept(MlirRewriterBase rewriter, + MlirValue from, MlirValue to, + MlirOperation exceptedUser) { + unwrap(rewriter)->replaceAllUsesExcept(unwrap(from), unwrap(to), + unwrap(exceptedUser)); +} + +//===----------------------------------------------------------------------===// +/// IRRewriter API +//===----------------------------------------------------------------------===// + +MlirRewriterBase mlirIRRewriterCreate(MlirContext context) { + return wrap(new IRRewriter(unwrap(context))); +} + +MlirRewriterBase mlirIRRewriterCreateFromOp(MlirOperation op) { + return wrap(new IRRewriter(unwrap(op))); +} + +void mlirIRRewriterDestroy(MlirRewriterBase rewriter) { + delete static_cast(unwrap(rewriter)); +} + +//===----------------------------------------------------------------------===// +/// RewritePatternSet and FrozenRewritePatternSet API +//===----------------------------------------------------------------------===// + inline mlir::RewritePatternSet &unwrap(MlirRewritePatternSet module) { assert(module.ptr && "unexpected null module"); return *(static_cast(module.ptr)); @@ -54,6 +293,10 @@ mlirApplyPatternsAndFoldGreedily(MlirModule op, mlir::applyPatternsAndFoldGreedily(unwrap(op), *unwrap(patterns))); } +//===----------------------------------------------------------------------===// +/// PDLPatternModule API +//===----------------------------------------------------------------------===// + #if 
MLIR_ENABLE_PDL_IN_PATTERNMATCH inline mlir::PDLPatternModule *unwrap(MlirPDLPatternModule module) { assert(module.ptr && "unexpected null module"); From 0d485d6364c31cf43e68a87b4228479ffe991451 Mon Sep 17 00:00:00 2001 From: Jie Fu Date: Wed, 17 Jul 2024 07:13:16 +0800 Subject: [PATCH 743/915] [mlir] Fix build error (NFC) /llvm-project/mlir/include/mlir/CAPI/Rewrite.h:21:63: error: extra ';' outside of a function is incompatible with C++98 [-Werror,-Wc++98-compat-extra-semi] DEFINE_C_API_PTR_METHODS(MlirRewriterBase, mlir::RewriterBase); ^ 1 error generated. --- mlir/include/mlir/CAPI/Rewrite.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h index 0e6dcb247..f0bb9337e 100644 --- a/mlir/include/mlir/CAPI/Rewrite.h +++ b/mlir/include/mlir/CAPI/Rewrite.h @@ -18,6 +18,6 @@ #include "mlir/CAPI/Wrap.h" #include "mlir/IR/PatternMatch.h" -DEFINE_C_API_PTR_METHODS(MlirRewriterBase, mlir::RewriterBase); +DEFINE_C_API_PTR_METHODS(MlirRewriterBase, mlir::RewriterBase) #endif // MLIR_CAPIREWRITER_H From ae97d32c4f8e44f434a439ff2fefa4d41893b845 Mon Sep 17 00:00:00 2001 From: Jordan Rupprecht Date: Wed, 17 Jul 2024 13:50:35 -0500 Subject: [PATCH 744/915] [mlir][NFC] Add rewrite header to fix standalone header compile (#99370) This uses `MlirRewriterBase` from from `mlir-c/Rewrite.h` without including it. 
--- mlir/include/mlir/CAPI/Rewrite.h | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/include/mlir/CAPI/Rewrite.h b/mlir/include/mlir/CAPI/Rewrite.h index f0bb9337e..1038c0a57 100644 --- a/mlir/include/mlir/CAPI/Rewrite.h +++ b/mlir/include/mlir/CAPI/Rewrite.h @@ -15,6 +15,7 @@ #ifndef MLIR_CAPI_REWRITE_H #define MLIR_CAPI_REWRITE_H +#include "mlir-c/Rewrite.h" #include "mlir/CAPI/Wrap.h" #include "mlir/IR/PatternMatch.h" From 38432e58baaec82e8c79c269b270377f9abce41f Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Mon, 22 Jul 2024 23:20:28 -0700 Subject: [PATCH 745/915] [MLIR] Add f8E4M3 IEEE 754 type (#97118) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds `f8E4M3` type to mlir. `f8E4M3` type follows IEEE 754 convention ```c f8E4M3 (IEEE 754) - Exponent bias: 7 - Maximum stored exponent value: 14 (binary 1110) - Maximum unbiased exponent value: 14 - 7 = 7 - Minimum stored exponent value: 1 (binary 0001) - Minimum unbiased exponent value: 1 − 7 = −6 - Precision specifies the total number of bits used for the significand (mantisa), including implicit leading integer bit = 3 + 1 = 4 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 7 - Min exp (unbiased): -6 - Infinities (+/-): S.1111.000 - Zeros (+/-): S.0000.000 - NaNs: S.1111.{001, 010, 011, 100, 101, 110, 111} - Max normal number: S.1110.111 = +/-2^(7) x (1 + 0.875) = +/-240 - Min normal number: S.0001.000 = +/-2^(-6) - Max subnormal number: S.0000.111 = +/-2^(-6) x 0.875 = +/-2^(-9) x 7 - Min subnormal number: S.0000.001 = +/-2^(-6) x 0.125 = +/-2^(-9) ``` Related PRs: - [PR-97179](https://github.com/llvm/llvm-project/pull/97179) [APFloat] Add support for f8E4M3 IEEE 754 type --- mlir/include/mlir-c/BuiltinTypes.h | 10 ++++++++++ mlir/lib/Bindings/Python/IRTypes.cpp | 23 ++++++++++++++++++++++- 
mlir/lib/CAPI/IR/BuiltinTypes.cpp | 12 ++++++++++++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 14 ++++++++++++++ mlir/python/mlir/extras/types.py | 2 ++ 5 files changed, 60 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index 99c5e3f46..2212087b9 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -89,6 +89,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E5M2(MlirType type); /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E5M2TypeGet(MlirContext ctx); +/// Returns the typeID of an Float8E4M3 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3TypeGetTypeID(void); + +/// Checks whether the given type is an f8E4M3 type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3(MlirType type); + +/// Creates an f8E4M3 type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3TypeGet(MlirContext ctx); + /// Returns the typeID of an Float8E4M3FN type. MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E4M3FNTypeGetTypeID(void); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index e1e4eb999..5e0aebc03 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -143,7 +143,7 @@ class PyFloat8E4M3FNType } }; -/// Floating Point Type subclass - Float8M5E2Type. +/// Floating Point Type subclass - Float8E5M2Type. class PyFloat8E5M2Type : public PyConcreteType { public: static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; @@ -163,6 +163,26 @@ class PyFloat8E5M2Type : public PyConcreteType { } }; +/// Floating Point Type subclass - Float8E4M3Type. 
+class PyFloat8E4M3Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3TypeGetTypeID; + static constexpr const char *pyClassName = "Float8E4M3Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3TypeGet(context->get()); + return PyFloat8E4M3Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e4m3 type."); + } +}; + /// Floating Point Type subclass - Float8E4M3FNUZ. class PyFloat8E4M3FNUZType : public PyConcreteType { @@ -840,6 +860,7 @@ void mlir::python::populateIRTypes(py::module &m) { PyIndexType::bind(m); PyFloat8E4M3FNType::bind(m); PyFloat8E5M2Type::bind(m); + PyFloat8E4M3Type::bind(m); PyFloat8E4M3FNUZType::bind(m); PyFloat8E4M3B11FNUZType::bind(m); PyFloat8E5M2FNUZType::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 01bb71d92..d50702735 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -97,6 +97,18 @@ MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) { return wrap(FloatType::getFloat8E5M2(unwrap(ctx))); } +MlirTypeID mlirFloat8E4M3TypeGetTypeID() { + return wrap(Float8E4M3Type::getTypeID()); +} + +bool mlirTypeIsAFloat8E4M3(MlirType type) { + return unwrap(type).isFloat8E4M3(); +} + +MlirType mlirFloat8E4M3TypeGet(MlirContext ctx) { + return wrap(FloatType::getFloat8E4M3(unwrap(ctx))); +} + MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() { return wrap(Float8E4M3FNType::getTypeID()); } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 317e68807..224e77a3f 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -123,6 +123,7 @@ __all__ = [ "Float8E4M3B11FNUZType", "Float8E4M3FNType", 
"Float8E4M3FNUZType", + "Float8E4M3Type", "Float8E5M2FNUZType", "Float8E5M2Type", "FloatAttr", @@ -1575,6 +1576,19 @@ class Float8E4M3FNUZType(FloatType): @property def typeid(self) -> TypeID: ... +class Float8E4M3Type(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Optional[Context] = None) -> Float8E4M3Type: + """ + Create a float8_e4m3 type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... + class Float8E5M2FNUZType(FloatType): static_typeid: ClassVar[TypeID] @staticmethod diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py index b93c46b17..fde9909a8 100644 --- a/mlir/python/mlir/extras/types.py +++ b/mlir/python/mlir/extras/types.py @@ -14,6 +14,7 @@ F64Type, Float8E4M3B11FNUZType, Float8E4M3FNType, + Float8E4M3Type, Float8E5M2Type, FunctionType, IndexType, @@ -68,6 +69,7 @@ def ui(width): bf16 = lambda: BF16Type.get() f8E5M2 = lambda: Float8E5M2Type.get() +f8E4M3 = lambda: Float8E4M3Type.get() f8E4M3FN = lambda: Float8E4M3FNType.get() f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get() From d2d859f8789b2b108e72a38a6077081aabe0cd1d Mon Sep 17 00:00:00 2001 From: Walter Erquinigo Date: Tue, 23 Jul 2024 19:40:22 -0400 Subject: [PATCH 746/915] [MLIR][DebugInfo] Enable the use of DILocalVariable DIFlags (#100190) This patch enables the use of flags for local variables in debug info. They were defaulted as always zero, but allowing them is pretty trivial. 
--- mlir/include/mlir-c/Dialect/LLVM.h | 2 +- mlir/lib/CAPI/Dialect/LLVM.cpp | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index 902b45444..631b56461 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -309,7 +309,7 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDILexicalBlockFileAttrGet( MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDILocalVariableAttrGet( MlirContext ctx, MlirAttribute scope, MlirAttribute name, MlirAttribute diFile, unsigned int line, unsigned int arg, - unsigned int alignInBits, MlirAttribute diType); + unsigned int alignInBits, MlirAttribute diType, int64_t flags); /// Creates a LLVM DISubprogramAttr attribute. MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDISubprogramAttrGet( diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index 754c94511..03e2f2be2 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -266,15 +266,14 @@ MlirAttribute mlirLLVMDILexicalBlockFileAttrGet(MlirContext ctx, cast(unwrap(file)), discriminator)); } -MlirAttribute -mlirLLVMDILocalVariableAttrGet(MlirContext ctx, MlirAttribute scope, - MlirAttribute name, MlirAttribute diFile, - unsigned int line, unsigned int arg, - unsigned int alignInBits, MlirAttribute diType) { +MlirAttribute mlirLLVMDILocalVariableAttrGet( + MlirContext ctx, MlirAttribute scope, MlirAttribute name, + MlirAttribute diFile, unsigned int line, unsigned int arg, + unsigned int alignInBits, MlirAttribute diType, int64_t flags) { return wrap(DILocalVariableAttr::get( unwrap(ctx), cast(unwrap(scope)), cast(unwrap(name)), cast(unwrap(diFile)), line, - arg, alignInBits, cast(unwrap(diType)))); + arg, alignInBits, cast(unwrap(diType)), DIFlags(flags))); } MlirAttribute mlirLLVMDISubroutineTypeAttrGet(MlirContext ctx, From d7f474bceac7e6d3209f6e06301e559a01a3b07b Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Tue, 30 
Jul 2024 11:19:26 +0000 Subject: [PATCH 747/915] [mlir] Apply ClangTidy performance finding. --- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index 19171d64d..f2a0ab33c 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -107,6 +107,7 @@ MlirSparseTensorLevelType mlirSparseTensorEncodingAttrBuildLvlType( unsigned size, unsigned n, unsigned m) { std::vector props; + props.reserve(size); for (unsigned i = 0; i < size; i++) props.push_back(static_cast(properties[i])); From f559b247a86a9c074539611ec7418e929b9aa994 Mon Sep 17 00:00:00 2001 From: Bimo Date: Wed, 31 Jul 2024 16:24:27 +0800 Subject: [PATCH 748/915] Reapply "[MLIR][Python] add ctype python binding support for bf16" (#101271) Reapply the PR which was reverted due to built-bots, and now the bots get updated. https://discourse.llvm.org/t/need-a-help-with-the-built-bots/79437 original PR: https://github.com/llvm/llvm-project/pull/92489, reverted in https://github.com/llvm/llvm-project/pull/93771 --- mlir/python/mlir/runtime/np_to_memref.py | 19 +++++++++++++++++++ mlir/python/requirements.txt | 3 ++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py index f6b706f9b..882b27519 100644 --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -7,6 +7,12 @@ import numpy as np import ctypes +try: + import ml_dtypes +except ModuleNotFoundError: + # The third-party ml_dtypes provides some optional low precision data-types for NumPy. 
+ ml_dtypes = None + class C128(ctypes.Structure): """A ctype representation for MLIR's Double Complex.""" @@ -26,6 +32,12 @@ class F16(ctypes.Structure): _fields_ = [("f16", ctypes.c_int16)] +class BF16(ctypes.Structure): + """A ctype representation for MLIR's BFloat16.""" + + _fields_ = [("bf16", ctypes.c_int16)] + + # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype def as_ctype(dtp): """Converts dtype to ctype.""" @@ -35,6 +47,8 @@ def as_ctype(dtp): return C64 if dtp == np.dtype(np.float16): return F16 + if ml_dtypes is not None and dtp == ml_dtypes.bfloat16: + return BF16 return np.ctypeslib.as_ctypes_type(dtp) @@ -46,6 +60,11 @@ def to_numpy(array): return array.view("complex64") if array.dtype == F16: return array.view("float16") + assert not ( + array.dtype == BF16 and ml_dtypes is None + ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n" + if array.dtype == BF16: + return array.view("bfloat16") return array diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index acd6dbb25..6ec63e43a 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,3 +1,4 @@ numpy>=1.19.5, <=1.26 pybind11>=2.9.0, <=2.10.3 -PyYAML>=5.3.1, <=6.0.1 \ No newline at end of file +PyYAML>=5.3.1, <=6.0.1 +ml_dtypes # provides several NumPy dtype extensions, including the bf16 \ No newline at end of file From 5758a738d6373f931a4f25fabac346b845d2f2d6 Mon Sep 17 00:00:00 2001 From: Alexander Pivovarov Date: Fri, 2 Aug 2024 00:22:11 -0700 Subject: [PATCH 749/915] [MLIR] Add f8E3M4 IEEE 754 type (#101230) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds `f8E3M4` type to mlir. 
`f8E3M4` type follows IEEE 754 convention ```c f8E3M4 (IEEE 754) - Exponent bias: 3 - Maximum stored exponent value: 6 (binary 110) - Maximum unbiased exponent value: 6 - 3 = 3 - Minimum stored exponent value: 1 (binary 001) - Minimum unbiased exponent value: 1 − 3 = −2 - Precision specifies the total number of bits used for the significand (mantissa), including implicit leading integer bit = 4 + 1 = 5 - Follows IEEE 754 conventions for representation of special values - Has Positive and Negative zero - Has Positive and Negative infinity - Has NaNs Additional details: - Max exp (unbiased): 3 - Min exp (unbiased): -2 - Infinities (+/-): S.111.0000 - Zeros (+/-): S.000.0000 - NaNs: S.111.{0,1}⁴ except S.111.0000 - Max normal number: S.110.1111 = +/-2^(6-3) x (1 + 15/16) = +/-2^3 x 31 x 2^(-4) = +/-15.5 - Min normal number: S.001.0000 = +/-2^(1-3) x (1 + 0) = +/-2^(-2) - Max subnormal number: S.000.1111 = +/-2^(-2) x 15/16 = +/-2^(-2) x 15 x 2^(-4) = +/-15 x 2^(-6) - Min subnormal number: S.000.0001 = +/-2^(-2) x 1/16 = +/-2^(-2) x 2^(-4) = +/-2^(-6) ``` Related PRs: - [PR-99698](https://github.com/llvm/llvm-project/pull/99698) [APFloat] Add support for f8E3M4 IEEE 754 type - [PR-97118](https://github.com/llvm/llvm-project/pull/97118) [MLIR] Add f8E4M3 IEEE 754 type --- mlir/include/mlir-c/BuiltinTypes.h | 10 ++++++++++ mlir/lib/Bindings/Python/IRTypes.cpp | 21 +++++++++++++++++++++ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 12 ++++++++++++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 14 ++++++++++++++ mlir/python/mlir/extras/types.py | 2 ++ 5 files changed, 59 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index 2212087b9..d698bf476 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -139,6 +139,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type); /// context. 
MLIR_CAPI_EXPORTED MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx); +/// Returns the typeID of an Float8E3M4 type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E3M4TypeGetTypeID(void); + +/// Checks whether the given type is an f8E3M4 type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E3M4(MlirType type); + +/// Creates an f8E3M4 type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat8E3M4TypeGet(MlirContext ctx); + /// Returns the typeID of an BFloat16 type. MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 5e0aebc03..c3d42c0ef 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -246,6 +246,26 @@ class PyFloat8E5M2FNUZType } }; +/// Floating Point Type subclass - Float8E3M4Type. +class PyFloat8E3M4Type : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E3M4TypeGetTypeID; + static constexpr const char *pyClassName = "Float8E3M4Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E3M4TypeGet(context->get()); + return PyFloat8E3M4Type(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e3m4 type."); + } +}; + /// Floating Point Type subclass - BF16Type. 
class PyBF16Type : public PyConcreteType { public: @@ -864,6 +884,7 @@ void mlir::python::populateIRTypes(py::module &m) { PyFloat8E4M3FNUZType::bind(m); PyFloat8E4M3B11FNUZType::bind(m); PyFloat8E5M2FNUZType::bind(m); + PyFloat8E3M4Type::bind(m); PyBF16Type::bind(m); PyF16Type::bind(m); PyTF32Type::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index d50702735..2aa2e922f 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -157,6 +157,18 @@ MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) { return wrap(FloatType::getFloat8E4M3B11FNUZ(unwrap(ctx))); } +MlirTypeID mlirFloat8E3M4TypeGetTypeID() { + return wrap(Float8E3M4Type::getTypeID()); +} + +bool mlirTypeIsAFloat8E3M4(MlirType type) { + return unwrap(type).isFloat8E3M4(); +} + +MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) { + return wrap(FloatType::getFloat8E3M4(unwrap(ctx))); +} + MlirTypeID mlirBFloat16TypeGetTypeID() { return wrap(BFloat16Type::getTypeID()); } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 224e77a3f..e3599d3c8 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -120,6 +120,7 @@ __all__ = [ "F32Type", "F64Type", "FlatSymbolRefAttr", + "Float8E3M4Type", "Float8E4M3B11FNUZType", "Float8E4M3FNType", "Float8E4M3FNUZType", @@ -1537,6 +1538,19 @@ class FlatSymbolRefAttr(Attribute): Returns the value of the FlatSymbolRef attribute as a string """ +class Float8E3M4Type(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Optional[Context] = None) -> Float8E3M4Type: + """ + Create a float8_e3m4 type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... 
+ class Float8E4M3B11FNUZType(FloatType): static_typeid: ClassVar[TypeID] @staticmethod diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py index fde9909a8..fe7c3e25d 100644 --- a/mlir/python/mlir/extras/types.py +++ b/mlir/python/mlir/extras/types.py @@ -12,6 +12,7 @@ F16Type, F32Type, F64Type, + Float8E3M4Type, Float8E4M3B11FNUZType, Float8E4M3FNType, Float8E4M3Type, @@ -72,6 +73,7 @@ def ui(width): f8E4M3 = lambda: Float8E4M3Type.get() f8E4M3FN = lambda: Float8E4M3FNType.get() f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get() +f8E3M4 = lambda: Float8E3M4Type.get() none = lambda: NoneType.get() From f4ad9c0bcbf98cb00f73c102ef6a1e184963bf82 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Wed, 7 Aug 2024 11:24:15 -0400 Subject: [PATCH 750/915] [mlir] [python] Update PyYAML minimum version to 5.4 and limit ml_dtypes to 0.4.0 (#102178) PyYAML 5.3.1 has a security vulnerability as described here: https://nvd.nist.gov/vuln/detail/CVE-2020-14343. Update the minimum PyYAML version to 5.4. Also limit ml_dtypes version to 0.4.0. 
--- mlir/python/requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index 6ec63e43a..d1b5418cc 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,4 +1,4 @@ numpy>=1.19.5, <=1.26 pybind11>=2.9.0, <=2.10.3 -PyYAML>=5.3.1, <=6.0.1 -ml_dtypes # provides several NumPy dtype extensions, including the bf16 \ No newline at end of file +PyYAML>=5.4.0, <=6.0.1 +ml_dtypes>=0.1.0, <=0.4.0 # provides several NumPy dtype extensions, including the bf16 From 8fcaddff141706e919b3e90a772d9dbd80a3d30c Mon Sep 17 00:00:00 2001 From: zhicong zhong Date: Mon, 12 Aug 2024 14:37:57 +0800 Subject: [PATCH 751/915] [mlir][linalg] fix linalg.batch_reduce_matmul auto cast (#102585) Fix the auto-cast of `linalg.batch_reduce_matmul` from `cast_to_T(A * cast_to_T(B)) + C` to `cast_to_T(A) * cast_to_T(B) + C` --- mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 67bde8f73..e4a6ec748 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -592,8 +592,8 @@ def batch_reduce_matmul( """ domain(D.b, D.m, D.n, D.k) implements(ContractionOpInterface) - C[D.m, D.n] += TypeFn.cast_signed( - U, A[D.b, D.m, D.k] * TypeFn.cast_signed(U, B[D.b, D.k, D.n]) + C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( + U, B[D.b, D.k, D.n] ) From 3c1dccb8acb2ca0ad06c3f41777778ef3f40eda4 Mon Sep 17 00:00:00 2001 From: Bimo Date: Mon, 19 Aug 2024 09:06:48 +0800 Subject: [PATCH 752/915] [MLIR][Python] enhance python api for tensor.empty (#103087) Since we have extended `EmptyOp`, maybe we should also provide a corresponding `tensor.empty` method. 
In the downstream usage, I tend to use APIs with all lowercase letters to create ops, so having a `tensor.empty` to replace the extended `tensor.EmptyOp` would keep my code style consistent. --- mlir/python/mlir/dialects/tensor.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py index 79dd9476a..0b30d1020 100644 --- a/mlir/python/mlir/dialects/tensor.py +++ b/mlir/python/mlir/dialects/tensor.py @@ -13,6 +13,7 @@ from typing import Sequence, Union from ._ods_common import _cext as _ods_cext +from ._ods_common import get_op_result_or_op_results as _get_op_result_or_op_results @_ods_cext.register_operation(_Dialect, replace=True) @@ -43,6 +44,18 @@ def __init__( super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip) +def empty( + sizes: Sequence[Union[int, Value]], + element_type: Type, + *, + loc=None, + ip=None, +) -> _ods_cext.ir.Value: + return _get_op_result_or_op_results( + EmptyOp(sizes=sizes, element_type=element_type, loc=loc, ip=ip) + ) + + generate = region_op( lambda result, dynamic_extents: GenerateOp(result, dynamic_extents), terminator=lambda args: YieldOp(args[0]), From 857d7bf39dbbc18ee528aff1818b22132ca11e8e Mon Sep 17 00:00:00 2001 From: Bimo Date: Tue, 20 Aug 2024 16:13:08 +0800 Subject: [PATCH 753/915] [MLIR][Python] remove unused init python file (#104890) remove unused `__init__.py` under `mlir/python/mlir/extras` --- mlir/python/mlir/extras/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 mlir/python/mlir/extras/__init__.py diff --git a/mlir/python/mlir/extras/__init__.py b/mlir/python/mlir/extras/__init__.py deleted file mode 100644 index e69de29bb..000000000 From 6cc53509412b48cb7e5c13e3a5cfb376286af0e5 Mon Sep 17 00:00:00 2001 From: Christopher Bate Date: Sun, 25 Aug 2024 18:51:22 -0600 Subject: [PATCH 754/915] [mlir] NFC: add missing 'FloatType' to core Python stub file (#105554) The stub class for `FloatType` is 
present in `ir.pyi`, but it is missing from the `__all__` export list. --- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index e3599d3c8..4a2d0e977 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -129,6 +129,7 @@ __all__ = [ "Float8E5M2Type", "FloatAttr", "FloatTF32Type", + "FloatType", "FunctionType", "IndexType", "InferShapedTypeOpInterface", From 306daa4dc87dc744bd3793d3c1516895fa0800a1 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 26 Aug 2024 15:46:11 -0700 Subject: [PATCH 755/915] [mlir][Python] Make `PyShapedType` public (#106105) Make `PyShapedType` public, so that downstream projects can define types that implement the `ShapedType` type interface in Python. --- mlir/include/mlir/Bindings/Python/IRTypes.h | 31 ++++ mlir/lib/Bindings/Python/IRTypes.cpp | 180 ++++++++++---------- 2 files changed, 122 insertions(+), 89 deletions(-) create mode 100644 mlir/include/mlir/Bindings/Python/IRTypes.h diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h new file mode 100644 index 000000000..9afad4c23 --- /dev/null +++ b/mlir/include/mlir/Bindings/Python/IRTypes.h @@ -0,0 +1,31 @@ +//===- IRTypes.h - Type Interfaces ----------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H +#define MLIR_BINDINGS_PYTHON_IRTYPES_H + +#include "mlir/Bindings/Python/PybindAdaptors.h" + +namespace mlir { + +/// Shaped Type Interface - ShapedType +class PyShapedType : public python::PyConcreteType { +public: + static const IsAFunctionTy isaFunction; + static constexpr const char *pyClassName = "ShapedType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); + +private: + void requireHasRank(); +}; + +} // namespace mlir + +#endif // MLIR_BINDINGS_PYTHON_IRTYPES_H diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index c3d42c0ef..2ee1d89c3 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -10,6 +10,8 @@ #include "PybindUtils.h" +#include "mlir/Bindings/Python/IRTypes.h" + #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Support.h" @@ -418,98 +420,98 @@ class PyComplexType : public PyConcreteType { } }; -class PyShapedType : public PyConcreteType { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped; - static constexpr const char *pyClassName = "ShapedType"; - using PyConcreteType::PyConcreteType; +} // namespace - static void bindDerived(ClassTy &c) { - c.def_property_readonly( - "element_type", - [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); }, - "Returns the element type of the shaped type."); - c.def_property_readonly( - "has_rank", - [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, - "Returns whether the given shaped type is ranked."); - c.def_property_readonly( - "rank", - [](PyShapedType &self) { - self.requireHasRank(); - return mlirShapedTypeGetRank(self); - }, - "Returns the rank of the given ranked shaped type."); - c.def_property_readonly( - 
"has_static_shape", - [](PyShapedType &self) -> bool { - return mlirShapedTypeHasStaticShape(self); - }, - "Returns whether the given shaped type has a static shape."); - c.def( - "is_dynamic_dim", - [](PyShapedType &self, intptr_t dim) -> bool { - self.requireHasRank(); - return mlirShapedTypeIsDynamicDim(self, dim); - }, - py::arg("dim"), - "Returns whether the dim-th dimension of the given shaped type is " - "dynamic."); - c.def( - "get_dim_size", - [](PyShapedType &self, intptr_t dim) { - self.requireHasRank(); - return mlirShapedTypeGetDimSize(self, dim); - }, - py::arg("dim"), - "Returns the dim-th dimension of the given ranked shaped type."); - c.def_static( - "is_dynamic_size", - [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, - py::arg("dim_size"), - "Returns whether the given dimension size indicates a dynamic " - "dimension."); - c.def( - "is_dynamic_stride_or_offset", - [](PyShapedType &self, int64_t val) -> bool { - self.requireHasRank(); - return mlirShapedTypeIsDynamicStrideOrOffset(val); - }, - py::arg("dim_size"), - "Returns whether the given value is used as a placeholder for dynamic " - "strides and offsets in shaped types."); - c.def_property_readonly( - "shape", - [](PyShapedType &self) { - self.requireHasRank(); - - std::vector shape; - int64_t rank = mlirShapedTypeGetRank(self); - shape.reserve(rank); - for (int64_t i = 0; i < rank; ++i) - shape.push_back(mlirShapedTypeGetDimSize(self, i)); - return shape; - }, - "Returns the shape of the ranked shaped type as a list of integers."); - c.def_static( - "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); }, - "Returns the value used to indicate dynamic dimensions in shaped " - "types."); - c.def_static( - "get_dynamic_stride_or_offset", - []() { return mlirShapedTypeGetDynamicStrideOrOffset(); }, - "Returns the value used to indicate dynamic strides or offsets in " - "shaped types."); - } +// Shaped Type Interface - ShapedType +void 
mlir::PyShapedType::bindDerived(ClassTy &c) { + c.def_property_readonly( + "element_type", + [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); }, + "Returns the element type of the shaped type."); + c.def_property_readonly( + "has_rank", + [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, + "Returns whether the given shaped type is ranked."); + c.def_property_readonly( + "rank", + [](PyShapedType &self) { + self.requireHasRank(); + return mlirShapedTypeGetRank(self); + }, + "Returns the rank of the given ranked shaped type."); + c.def_property_readonly( + "has_static_shape", + [](PyShapedType &self) -> bool { + return mlirShapedTypeHasStaticShape(self); + }, + "Returns whether the given shaped type has a static shape."); + c.def( + "is_dynamic_dim", + [](PyShapedType &self, intptr_t dim) -> bool { + self.requireHasRank(); + return mlirShapedTypeIsDynamicDim(self, dim); + }, + py::arg("dim"), + "Returns whether the dim-th dimension of the given shaped type is " + "dynamic."); + c.def( + "get_dim_size", + [](PyShapedType &self, intptr_t dim) { + self.requireHasRank(); + return mlirShapedTypeGetDimSize(self, dim); + }, + py::arg("dim"), + "Returns the dim-th dimension of the given ranked shaped type."); + c.def_static( + "is_dynamic_size", + [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, + py::arg("dim_size"), + "Returns whether the given dimension size indicates a dynamic " + "dimension."); + c.def( + "is_dynamic_stride_or_offset", + [](PyShapedType &self, int64_t val) -> bool { + self.requireHasRank(); + return mlirShapedTypeIsDynamicStrideOrOffset(val); + }, + py::arg("dim_size"), + "Returns whether the given value is used as a placeholder for dynamic " + "strides and offsets in shaped types."); + c.def_property_readonly( + "shape", + [](PyShapedType &self) { + self.requireHasRank(); + + std::vector shape; + int64_t rank = mlirShapedTypeGetRank(self); + shape.reserve(rank); + for (int64_t i = 0; i < 
rank; ++i) + shape.push_back(mlirShapedTypeGetDimSize(self, i)); + return shape; + }, + "Returns the shape of the ranked shaped type as a list of integers."); + c.def_static( + "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); }, + "Returns the value used to indicate dynamic dimensions in shaped " + "types."); + c.def_static( + "get_dynamic_stride_or_offset", + []() { return mlirShapedTypeGetDynamicStrideOrOffset(); }, + "Returns the value used to indicate dynamic strides or offsets in " + "shaped types."); +} -private: - void requireHasRank() { - if (!mlirShapedTypeHasRank(*this)) { - throw py::value_error( - "calling this method requires that the type has a rank."); - } +void mlir::PyShapedType::requireHasRank() { + if (!mlirShapedTypeHasRank(*this)) { + throw py::value_error( + "calling this method requires that the type has a rank."); } -}; +} + +const mlir::PyShapedType::IsAFunctionTy mlir::PyShapedType::isaFunction = + mlirTypeIsAShaped; + +namespace { /// Vector Type subclass - VectorType. 
class PyVectorType : public PyConcreteType { From 8071c1536975e34559796ff101d7a0f9799fbd97 Mon Sep 17 00:00:00 2001 From: PhrygianGates <69254262+PhrygianGates@users.noreply.github.com> Date: Tue, 27 Aug 2024 10:40:38 +0800 Subject: [PATCH 756/915] [MLIR][Python] add f8E5M2 and tests for np_to_memref (#106028) add f8E5M2 and tests for np_to_memref --------- Co-authored-by: Zhicheng Xiong --- mlir/python/mlir/runtime/np_to_memref.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mlir/python/mlir/runtime/np_to_memref.py b/mlir/python/mlir/runtime/np_to_memref.py index 882b27519..8cca1e7ad 100644 --- a/mlir/python/mlir/runtime/np_to_memref.py +++ b/mlir/python/mlir/runtime/np_to_memref.py @@ -37,6 +37,11 @@ class BF16(ctypes.Structure): _fields_ = [("bf16", ctypes.c_int16)] +class F8E5M2(ctypes.Structure): + """A ctype representation for MLIR's Float8E5M2.""" + + _fields_ = [("f8E5M2", ctypes.c_int8)] + # https://stackoverflow.com/questions/26921836/correct-way-to-test-for-numpy-dtype def as_ctype(dtp): @@ -49,6 +54,8 @@ def as_ctype(dtp): return F16 if ml_dtypes is not None and dtp == ml_dtypes.bfloat16: return BF16 + if ml_dtypes is not None and dtp == ml_dtypes.float8_e5m2: + return F8E5M2 return np.ctypeslib.as_ctypes_type(dtp) @@ -65,6 +72,11 @@ def to_numpy(array): ), f"bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n" if array.dtype == BF16: return array.view("bfloat16") + assert not ( + array.dtype == F8E5M2 and ml_dtypes is None + ), f"float8_e5m2 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n" + if array.dtype == F8E5M2: + return array.view("float8_e5m2") return array From 403afb79572e65b635a4fc96339ff210746c421e Mon Sep 17 00:00:00 2001 From: Abid Qadeer Date: Tue, 27 Aug 2024 11:10:11 +0100 Subject: [PATCH 757/915] [mlir][debug] Handle DIImportedEntity. 
(#103055) The `DIImportedEntity` can be used to represent imported entities like C++'s namespace with a using directive or Fortran's module with a use statement. This PR adds `DIImportedEntityAttr` and 2-way translation from `DIImportedEntity` to `DIImportedEntityAttr` and vice versa. When an entity is imported in a function, the `retainedNodes` field of the `DISubprogram` contains all the imported nodes. See the C++ code and the LLVM IR below. ``` void test() { using namespace n1; ... } !2 = !DINamespace(name: "n1", scope: null) !16 = distinct !DISubprogram(name: "test", ..., retainedNodes: !19) !19 = !{!20} !20 = !DIImportedEntity(tag: DW_TAG_imported_module, scope: !16, entity: !2 ...) ``` This PR makes sure that the translation from mlir to `retainedNodes` field happens correctly both ways. To sidestep the cyclic dependency between `DISubprogramAttr` and `DIImportedEntityAttr`, we have decided to not have `scope` field in the `DIImportedEntityAttr` and it is inferred from the entity which holds the list of `DIImportedEntityAttr`. A `retainedNodes` field has been added in the `DISubprogramAttr` which contains the list of `DIImportedEntityAttr` for that function. This PR currently does not handle entities imported in a global scope but that should be easy to handle in a subsequent PR. 
--- mlir/include/mlir-c/Dialect/LLVM.h | 9 ++++++++- mlir/lib/CAPI/Dialect/LLVM.cpp | 23 +++++++++++++++++++++-- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index 631b56461..5eb96a86e 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -316,7 +316,8 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDISubprogramAttrGet( MlirContext ctx, MlirAttribute id, MlirAttribute compileUnit, MlirAttribute scope, MlirAttribute name, MlirAttribute linkageName, MlirAttribute file, unsigned int line, unsigned int scopeLine, - uint64_t subprogramFlags, MlirAttribute type); + uint64_t subprogramFlags, MlirAttribute type, intptr_t nRetainedNodes, + MlirAttribute const *retainedNodes); /// Gets the scope from this DISubprogramAttr. MLIR_CAPI_EXPORTED MlirAttribute @@ -353,6 +354,12 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIModuleAttrGet( MlirAttribute name, MlirAttribute configMacros, MlirAttribute includePath, MlirAttribute apinotes, unsigned int line, bool isDecl); +/// Creates a LLVM DIImportedEntityAttr attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIImportedEntityAttrGet( + MlirContext ctx, unsigned int tag, MlirAttribute entity, MlirAttribute file, + unsigned int line, MlirAttribute name, intptr_t nElements, + MlirAttribute const *elements); + /// Gets the scope of this DIModuleAttr. 
MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIModuleAttrGetScope(MlirAttribute diModule); diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index 03e2f2be2..13341f0c4 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -293,14 +293,20 @@ MlirAttribute mlirLLVMDISubprogramAttrGet( MlirContext ctx, MlirAttribute id, MlirAttribute compileUnit, MlirAttribute scope, MlirAttribute name, MlirAttribute linkageName, MlirAttribute file, unsigned int line, unsigned int scopeLine, - uint64_t subprogramFlags, MlirAttribute type) { + uint64_t subprogramFlags, MlirAttribute type, intptr_t nRetainedNodes, + MlirAttribute const *retainedNodes) { + SmallVector nodesStorage; + nodesStorage.reserve(nRetainedNodes); return wrap(DISubprogramAttr::get( unwrap(ctx), cast(unwrap(id)), cast(unwrap(compileUnit)), cast(unwrap(scope)), cast(unwrap(name)), cast(unwrap(linkageName)), cast(unwrap(file)), line, scopeLine, DISubprogramFlags(subprogramFlags), - cast(unwrap(type)))); + cast(unwrap(type)), + llvm::map_to_vector( + unwrapList(nRetainedNodes, retainedNodes, nodesStorage), + [](Attribute a) { return cast(a); }))); } MlirAttribute mlirLLVMDISubprogramAttrGetScope(MlirAttribute diSubprogram) { @@ -345,3 +351,16 @@ MlirAttribute mlirLLVMDIModuleAttrGet(MlirContext ctx, MlirAttribute file, MlirAttribute mlirLLVMDIModuleAttrGetScope(MlirAttribute diModule) { return wrap(cast(unwrap(diModule)).getScope()); } + +MlirAttribute mlirLLVMDIImportedEntityAttrGet( + MlirContext ctx, unsigned int tag, MlirAttribute entity, MlirAttribute file, + unsigned int line, MlirAttribute name, intptr_t nElements, + MlirAttribute const *elements) { + SmallVector elementsStorage; + elementsStorage.reserve(nElements); + return wrap(DIImportedEntityAttr::get( + unwrap(ctx), tag, cast(unwrap(entity)), + cast(unwrap(file)), line, cast(unwrap(name)), + llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage), + [](Attribute a) { return cast(a); }))); 
+} From 663d99079a2f88d472d099c646eae1dcfd1eba5b Mon Sep 17 00:00:00 2001 From: Fabian Mora Date: Tue, 27 Aug 2024 18:44:50 -0400 Subject: [PATCH 758/915] [mlir][gpu] Add metadata attributes for storing kernel metadata in GPU objects (#95292) This patch adds the `#gpu.kernel_metadata` and `#gpu.kernel_table` attributes. The `#gpu.kernel_metadata` attribute allows storing metadata related to a compiled kernel, for example, the number of scalar registers used by the kernel. The attribute only has 2 required parameters, the name and function type. It also has 2 optional parameters, the argument attributes and a generic dictionary for storing all other metadata. The `#gpu.kernel_table` stores a table of `#gpu.kernel_metadata`, mapping the name of the kernel to the metadata. Finally, the function `ROCDL::getAMDHSAKernelsELFMetadata` was added to collect ELF metadata from a binary, and to test the class methods in both attributes. Example: ```mlir gpu.binary @binary [#gpu.object<#rocdl.target, kernels = #gpu.kernel_table<[ #gpu.kernel_metadata<"kernel0", (i32) -> (), metadata = {sgpr_count = 255}>, #gpu.kernel_metadata<"kernel1", (i32, f32) -> (), arg_attrs = [{llvm.read_only}, {}]> ]> , bin = "BLOB">] ``` The motivation behind these attributes is to provide useful information for things like tuning. 
--------- Co-authored-by: Mehdi Amini --- mlir/include/mlir-c/Dialect/GPU.h | 11 ++++++++ mlir/lib/Bindings/Python/DialectGPU.cpp | 23 +++++++++++---- mlir/lib/CAPI/Dialect/GPU.cpp | 37 +++++++++++++++++++++++-- 3 files changed, 62 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/GPU.h b/mlir/include/mlir-c/Dialect/GPU.h index c42ff61f9..321c1122c 100644 --- a/mlir/include/mlir-c/Dialect/GPU.h +++ b/mlir/include/mlir-c/Dialect/GPU.h @@ -37,6 +37,11 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirGPUObjectAttrGet(MlirContext mlirCtx, MlirAttribute target, uint32_t format, MlirStringRef objectStrRef, MlirAttribute mlirObjectProps); +MLIR_CAPI_EXPORTED MlirAttribute mlirGPUObjectAttrGetWithKernels( + MlirContext mlirCtx, MlirAttribute target, uint32_t format, + MlirStringRef objectStrRef, MlirAttribute mlirObjectProps, + MlirAttribute mlirKernelsAttr); + MLIR_CAPI_EXPORTED MlirAttribute mlirGPUObjectAttrGetTarget(MlirAttribute mlirObjectAttr); @@ -52,6 +57,12 @@ mlirGPUObjectAttrHasProperties(MlirAttribute mlirObjectAttr); MLIR_CAPI_EXPORTED MlirAttribute mlirGPUObjectAttrGetProperties(MlirAttribute mlirObjectAttr); +MLIR_CAPI_EXPORTED bool +mlirGPUObjectAttrHasKernels(MlirAttribute mlirObjectAttr); + +MLIR_CAPI_EXPORTED MlirAttribute +mlirGPUObjectAttrGetKernels(MlirAttribute mlirObjectAttr); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp index a9e339b50..560a54bcd 100644 --- a/mlir/lib/Bindings/Python/DialectGPU.cpp +++ b/mlir/lib/Bindings/Python/DialectGPU.cpp @@ -48,17 +48,21 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) { .def_classmethod( "get", [](py::object cls, MlirAttribute target, uint32_t format, - py::bytes object, std::optional mlirObjectProps) { + py::bytes object, std::optional mlirObjectProps, + std::optional mlirKernelsAttr) { py::buffer_info info(py::buffer(object).request()); MlirStringRef objectStrRef = mlirStringRefCreate(static_cast(info.ptr), info.size); - 
return cls(mlirGPUObjectAttrGet( + return cls(mlirGPUObjectAttrGetWithKernels( mlirAttributeGetContext(target), target, format, objectStrRef, mlirObjectProps.has_value() ? *mlirObjectProps + : MlirAttribute{nullptr}, + mlirKernelsAttr.has_value() ? *mlirKernelsAttr : MlirAttribute{nullptr})); }, "cls"_a, "target"_a, "format"_a, "object"_a, - "properties"_a = py::none(), "Gets a gpu.object from parameters.") + "properties"_a = py::none(), "kernels"_a = py::none(), + "Gets a gpu.object from parameters.") .def_property_readonly( "target", [](MlirAttribute self) { return mlirGPUObjectAttrGetTarget(self); }) @@ -71,9 +75,16 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) { MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self); return py::bytes(stringRef.data, stringRef.length); }) - .def_property_readonly("properties", [](MlirAttribute self) { - if (mlirGPUObjectAttrHasProperties(self)) - return py::cast(mlirGPUObjectAttrGetProperties(self)); + .def_property_readonly("properties", + [](MlirAttribute self) { + if (mlirGPUObjectAttrHasProperties(self)) + return py::cast( + mlirGPUObjectAttrGetProperties(self)); + return py::none().cast(); + }) + .def_property_readonly("kernels", [](MlirAttribute self) { + if (mlirGPUObjectAttrHasKernels(self)) + return py::cast(mlirGPUObjectAttrGetKernels(self)); return py::none().cast(); }); } diff --git a/mlir/lib/CAPI/Dialect/GPU.cpp b/mlir/lib/CAPI/Dialect/GPU.cpp index 0acebb230..e4796ed14 100644 --- a/mlir/lib/CAPI/Dialect/GPU.cpp +++ b/mlir/lib/CAPI/Dialect/GPU.cpp @@ -43,9 +43,28 @@ MlirAttribute mlirGPUObjectAttrGet(MlirContext mlirCtx, MlirAttribute target, DictionaryAttr objectProps; if (mlirObjectProps.ptr != nullptr) objectProps = llvm::cast(unwrap(mlirObjectProps)); - return wrap(gpu::ObjectAttr::get(ctx, unwrap(target), - static_cast(format), - StringAttr::get(ctx, object), objectProps)); + return wrap(gpu::ObjectAttr::get( + ctx, unwrap(target), static_cast(format), + StringAttr::get(ctx, object), objectProps, nullptr)); +} + 
+MlirAttribute mlirGPUObjectAttrGetWithKernels(MlirContext mlirCtx, + MlirAttribute target, + uint32_t format, + MlirStringRef objectStrRef, + MlirAttribute mlirObjectProps, + MlirAttribute mlirKernelsAttr) { + MLIRContext *ctx = unwrap(mlirCtx); + llvm::StringRef object = unwrap(objectStrRef); + DictionaryAttr objectProps; + if (mlirObjectProps.ptr != nullptr) + objectProps = llvm::cast(unwrap(mlirObjectProps)); + gpu::KernelTableAttr kernels; + if (mlirKernelsAttr.ptr != nullptr) + kernels = llvm::cast(unwrap(mlirKernelsAttr)); + return wrap(gpu::ObjectAttr::get( + ctx, unwrap(target), static_cast(format), + StringAttr::get(ctx, object), objectProps, kernels)); } MlirAttribute mlirGPUObjectAttrGetTarget(MlirAttribute mlirObjectAttr) { @@ -78,3 +97,15 @@ MlirAttribute mlirGPUObjectAttrGetProperties(MlirAttribute mlirObjectAttr) { llvm::cast(unwrap(mlirObjectAttr)); return wrap(objectAttr.getProperties()); } + +bool mlirGPUObjectAttrHasKernels(MlirAttribute mlirObjectAttr) { + gpu::ObjectAttr objectAttr = + llvm::cast(unwrap(mlirObjectAttr)); + return objectAttr.getKernels() != nullptr; +} + +MlirAttribute mlirGPUObjectAttrGetKernels(MlirAttribute mlirObjectAttr) { + gpu::ObjectAttr objectAttr = + llvm::cast(unwrap(mlirObjectAttr)); + return wrap(objectAttr.getKernels()); +} From 38979255701bcfe4af4a0d2ecb70f89939f799d8 Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Sat, 31 Aug 2024 07:23:24 +0200 Subject: [PATCH 759/915] [MLIR][LLVM] Make DISubprogramAttr cyclic (#106571) This commit implements LLVM_DIRecursiveTypeAttrInterface for the DISubprogramAttr to ensure cyclic subprograms can be imported properly. In the process multiple shortcuts around the recently introduced DIImportedEntityAttr can be removed. 
--- mlir/include/mlir-c/Dialect/LLVM.h | 26 ++++++++++++-------- mlir/lib/CAPI/Dialect/LLVM.cpp | 39 +++++++++++++++++++----------- 2 files changed, 41 insertions(+), 24 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index 5eb96a86e..38bd4d2f3 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -234,10 +234,13 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIBasicTypeAttrGet( MlirContext ctx, unsigned int tag, MlirAttribute name, uint64_t sizeInBits, MlirLLVMTypeEncoding encoding); +/// Creates a self-referencing LLVM DICompositeType attribute. +MlirAttribute mlirLLVMDICompositeTypeAttrGetRecSelf(MlirAttribute recId); + /// Creates a LLVM DICompositeType attribute. MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDICompositeTypeAttrGet( - MlirContext ctx, unsigned int tag, MlirAttribute recId, MlirAttribute name, - MlirAttribute file, uint32_t line, MlirAttribute scope, + MlirContext ctx, MlirAttribute recId, bool isRecSelf, unsigned int tag, + MlirAttribute name, MlirAttribute file, uint32_t line, MlirAttribute scope, MlirAttribute baseType, int64_t flags, uint64_t sizeInBits, uint64_t alignInBits, intptr_t nElements, MlirAttribute const *elements, MlirAttribute dataLocation, MlirAttribute rank, MlirAttribute allocated, @@ -311,13 +314,16 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDILocalVariableAttrGet( MlirAttribute diFile, unsigned int line, unsigned int arg, unsigned int alignInBits, MlirAttribute diType, int64_t flags); +/// Creates a self-referencing LLVM DISubprogramAttr attribute. +MlirAttribute mlirLLVMDISubprogramAttrGetRecSelf(MlirAttribute recId); + /// Creates a LLVM DISubprogramAttr attribute. 
MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDISubprogramAttrGet( - MlirContext ctx, MlirAttribute id, MlirAttribute compileUnit, - MlirAttribute scope, MlirAttribute name, MlirAttribute linkageName, - MlirAttribute file, unsigned int line, unsigned int scopeLine, - uint64_t subprogramFlags, MlirAttribute type, intptr_t nRetainedNodes, - MlirAttribute const *retainedNodes); + MlirContext ctx, MlirAttribute recId, bool isRecSelf, MlirAttribute id, + MlirAttribute compileUnit, MlirAttribute scope, MlirAttribute name, + MlirAttribute linkageName, MlirAttribute file, unsigned int line, + unsigned int scopeLine, uint64_t subprogramFlags, MlirAttribute type, + intptr_t nRetainedNodes, MlirAttribute const *retainedNodes); /// Gets the scope from this DISubprogramAttr. MLIR_CAPI_EXPORTED MlirAttribute @@ -356,9 +362,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIModuleAttrGet( /// Creates a LLVM DIImportedEntityAttr attribute. MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIImportedEntityAttrGet( - MlirContext ctx, unsigned int tag, MlirAttribute entity, MlirAttribute file, - unsigned int line, MlirAttribute name, intptr_t nElements, - MlirAttribute const *elements); + MlirContext ctx, unsigned int tag, MlirAttribute scope, + MlirAttribute entity, MlirAttribute file, unsigned int line, + MlirAttribute name, intptr_t nElements, MlirAttribute const *elements); /// Gets the scope of this DIModuleAttr. 
MLIR_CAPI_EXPORTED MlirAttribute diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index 13341f0c4..03b536d7a 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -159,9 +159,14 @@ MlirAttribute mlirLLVMDIBasicTypeAttrGet(MlirContext ctx, unsigned int tag, unwrap(ctx), tag, cast(unwrap(name)), sizeInBits, encoding)); } +MlirAttribute mlirLLVMDICompositeTypeAttrGetRecSelf(MlirAttribute recId) { + return wrap( + DICompositeTypeAttr::getRecSelf(cast(unwrap(recId)))); +} + MlirAttribute mlirLLVMDICompositeTypeAttrGet( - MlirContext ctx, unsigned int tag, MlirAttribute recId, MlirAttribute name, - MlirAttribute file, uint32_t line, MlirAttribute scope, + MlirContext ctx, MlirAttribute recId, bool isRecSelf, unsigned int tag, + MlirAttribute name, MlirAttribute file, uint32_t line, MlirAttribute scope, MlirAttribute baseType, int64_t flags, uint64_t sizeInBits, uint64_t alignInBits, intptr_t nElements, MlirAttribute const *elements, MlirAttribute dataLocation, MlirAttribute rank, MlirAttribute allocated, @@ -170,7 +175,7 @@ MlirAttribute mlirLLVMDICompositeTypeAttrGet( elementsStorage.reserve(nElements); return wrap(DICompositeTypeAttr::get( - unwrap(ctx), tag, cast(unwrap(recId)), + unwrap(ctx), cast(unwrap(recId)), isRecSelf, tag, cast(unwrap(name)), cast(unwrap(file)), line, cast(unwrap(scope)), cast(unwrap(baseType)), DIFlags(flags), sizeInBits, alignInBits, @@ -289,16 +294,21 @@ MlirAttribute mlirLLVMDISubroutineTypeAttrGet(MlirContext ctx, [](Attribute a) { return cast(a); }))); } +MlirAttribute mlirLLVMDISubprogramAttrGetRecSelf(MlirAttribute recId) { + return wrap(DISubprogramAttr::getRecSelf(cast(unwrap(recId)))); +} + MlirAttribute mlirLLVMDISubprogramAttrGet( - MlirContext ctx, MlirAttribute id, MlirAttribute compileUnit, - MlirAttribute scope, MlirAttribute name, MlirAttribute linkageName, - MlirAttribute file, unsigned int line, unsigned int scopeLine, - uint64_t subprogramFlags, MlirAttribute type, 
intptr_t nRetainedNodes, - MlirAttribute const *retainedNodes) { + MlirContext ctx, MlirAttribute recId, bool isRecSelf, MlirAttribute id, + MlirAttribute compileUnit, MlirAttribute scope, MlirAttribute name, + MlirAttribute linkageName, MlirAttribute file, unsigned int line, + unsigned int scopeLine, uint64_t subprogramFlags, MlirAttribute type, + intptr_t nRetainedNodes, MlirAttribute const *retainedNodes) { SmallVector nodesStorage; nodesStorage.reserve(nRetainedNodes); return wrap(DISubprogramAttr::get( - unwrap(ctx), cast(unwrap(id)), + unwrap(ctx), cast(unwrap(recId)), isRecSelf, + cast(unwrap(id)), cast(unwrap(compileUnit)), cast(unwrap(scope)), cast(unwrap(name)), cast(unwrap(linkageName)), cast(unwrap(file)), @@ -353,14 +363,15 @@ MlirAttribute mlirLLVMDIModuleAttrGetScope(MlirAttribute diModule) { } MlirAttribute mlirLLVMDIImportedEntityAttrGet( - MlirContext ctx, unsigned int tag, MlirAttribute entity, MlirAttribute file, - unsigned int line, MlirAttribute name, intptr_t nElements, - MlirAttribute const *elements) { + MlirContext ctx, unsigned int tag, MlirAttribute scope, + MlirAttribute entity, MlirAttribute file, unsigned int line, + MlirAttribute name, intptr_t nElements, MlirAttribute const *elements) { SmallVector elementsStorage; elementsStorage.reserve(nElements); return wrap(DIImportedEntityAttr::get( - unwrap(ctx), tag, cast(unwrap(entity)), - cast(unwrap(file)), line, cast(unwrap(name)), + unwrap(ctx), tag, cast(unwrap(scope)), + cast(unwrap(entity)), cast(unwrap(file)), line, + cast(unwrap(name)), llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage), [](Attribute a) { return cast(a); }))); } From 371ad453a289edf86fb22bace1ec0be754591754 Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Sat, 31 Aug 2024 07:51:53 +0200 Subject: [PATCH 760/915] Revert "[MLIR][LLVM] Make DISubprogramAttr cyclic" (#106827) Reverts llvm/llvm-project#106571 This commit breaks the following build bot: 
https://lab.llvm.org/buildbot/#/builders/138/builds/2992 It looks like there is a missing dependency in this particular setup. --- mlir/include/mlir-c/Dialect/LLVM.h | 26 ++++++++------------ mlir/lib/CAPI/Dialect/LLVM.cpp | 39 +++++++++++------------------- 2 files changed, 24 insertions(+), 41 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index 38bd4d2f3..5eb96a86e 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -234,13 +234,10 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIBasicTypeAttrGet( MlirContext ctx, unsigned int tag, MlirAttribute name, uint64_t sizeInBits, MlirLLVMTypeEncoding encoding); -/// Creates a self-referencing LLVM DICompositeType attribute. -MlirAttribute mlirLLVMDICompositeTypeAttrGetRecSelf(MlirAttribute recId); - /// Creates a LLVM DICompositeType attribute. MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDICompositeTypeAttrGet( - MlirContext ctx, MlirAttribute recId, bool isRecSelf, unsigned int tag, - MlirAttribute name, MlirAttribute file, uint32_t line, MlirAttribute scope, + MlirContext ctx, unsigned int tag, MlirAttribute recId, MlirAttribute name, + MlirAttribute file, uint32_t line, MlirAttribute scope, MlirAttribute baseType, int64_t flags, uint64_t sizeInBits, uint64_t alignInBits, intptr_t nElements, MlirAttribute const *elements, MlirAttribute dataLocation, MlirAttribute rank, MlirAttribute allocated, @@ -314,16 +311,13 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDILocalVariableAttrGet( MlirAttribute diFile, unsigned int line, unsigned int arg, unsigned int alignInBits, MlirAttribute diType, int64_t flags); -/// Creates a self-referencing LLVM DISubprogramAttr attribute. -MlirAttribute mlirLLVMDISubprogramAttrGetRecSelf(MlirAttribute recId); - /// Creates a LLVM DISubprogramAttr attribute. 
MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDISubprogramAttrGet( - MlirContext ctx, MlirAttribute recId, bool isRecSelf, MlirAttribute id, - MlirAttribute compileUnit, MlirAttribute scope, MlirAttribute name, - MlirAttribute linkageName, MlirAttribute file, unsigned int line, - unsigned int scopeLine, uint64_t subprogramFlags, MlirAttribute type, - intptr_t nRetainedNodes, MlirAttribute const *retainedNodes); + MlirContext ctx, MlirAttribute id, MlirAttribute compileUnit, + MlirAttribute scope, MlirAttribute name, MlirAttribute linkageName, + MlirAttribute file, unsigned int line, unsigned int scopeLine, + uint64_t subprogramFlags, MlirAttribute type, intptr_t nRetainedNodes, + MlirAttribute const *retainedNodes); /// Gets the scope from this DISubprogramAttr. MLIR_CAPI_EXPORTED MlirAttribute @@ -362,9 +356,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIModuleAttrGet( /// Creates a LLVM DIImportedEntityAttr attribute. MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIImportedEntityAttrGet( - MlirContext ctx, unsigned int tag, MlirAttribute scope, - MlirAttribute entity, MlirAttribute file, unsigned int line, - MlirAttribute name, intptr_t nElements, MlirAttribute const *elements); + MlirContext ctx, unsigned int tag, MlirAttribute entity, MlirAttribute file, + unsigned int line, MlirAttribute name, intptr_t nElements, + MlirAttribute const *elements); /// Gets the scope of this DIModuleAttr. 
MLIR_CAPI_EXPORTED MlirAttribute diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index 03b536d7a..13341f0c4 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -159,14 +159,9 @@ MlirAttribute mlirLLVMDIBasicTypeAttrGet(MlirContext ctx, unsigned int tag, unwrap(ctx), tag, cast(unwrap(name)), sizeInBits, encoding)); } -MlirAttribute mlirLLVMDICompositeTypeAttrGetRecSelf(MlirAttribute recId) { - return wrap( - DICompositeTypeAttr::getRecSelf(cast(unwrap(recId)))); -} - MlirAttribute mlirLLVMDICompositeTypeAttrGet( - MlirContext ctx, MlirAttribute recId, bool isRecSelf, unsigned int tag, - MlirAttribute name, MlirAttribute file, uint32_t line, MlirAttribute scope, + MlirContext ctx, unsigned int tag, MlirAttribute recId, MlirAttribute name, + MlirAttribute file, uint32_t line, MlirAttribute scope, MlirAttribute baseType, int64_t flags, uint64_t sizeInBits, uint64_t alignInBits, intptr_t nElements, MlirAttribute const *elements, MlirAttribute dataLocation, MlirAttribute rank, MlirAttribute allocated, @@ -175,7 +170,7 @@ MlirAttribute mlirLLVMDICompositeTypeAttrGet( elementsStorage.reserve(nElements); return wrap(DICompositeTypeAttr::get( - unwrap(ctx), cast(unwrap(recId)), isRecSelf, tag, + unwrap(ctx), tag, cast(unwrap(recId)), cast(unwrap(name)), cast(unwrap(file)), line, cast(unwrap(scope)), cast(unwrap(baseType)), DIFlags(flags), sizeInBits, alignInBits, @@ -294,21 +289,16 @@ MlirAttribute mlirLLVMDISubroutineTypeAttrGet(MlirContext ctx, [](Attribute a) { return cast(a); }))); } -MlirAttribute mlirLLVMDISubprogramAttrGetRecSelf(MlirAttribute recId) { - return wrap(DISubprogramAttr::getRecSelf(cast(unwrap(recId)))); -} - MlirAttribute mlirLLVMDISubprogramAttrGet( - MlirContext ctx, MlirAttribute recId, bool isRecSelf, MlirAttribute id, - MlirAttribute compileUnit, MlirAttribute scope, MlirAttribute name, - MlirAttribute linkageName, MlirAttribute file, unsigned int line, - unsigned int scopeLine, uint64_t 
subprogramFlags, MlirAttribute type, - intptr_t nRetainedNodes, MlirAttribute const *retainedNodes) { + MlirContext ctx, MlirAttribute id, MlirAttribute compileUnit, + MlirAttribute scope, MlirAttribute name, MlirAttribute linkageName, + MlirAttribute file, unsigned int line, unsigned int scopeLine, + uint64_t subprogramFlags, MlirAttribute type, intptr_t nRetainedNodes, + MlirAttribute const *retainedNodes) { SmallVector nodesStorage; nodesStorage.reserve(nRetainedNodes); return wrap(DISubprogramAttr::get( - unwrap(ctx), cast(unwrap(recId)), isRecSelf, - cast(unwrap(id)), + unwrap(ctx), cast(unwrap(id)), cast(unwrap(compileUnit)), cast(unwrap(scope)), cast(unwrap(name)), cast(unwrap(linkageName)), cast(unwrap(file)), @@ -363,15 +353,14 @@ MlirAttribute mlirLLVMDIModuleAttrGetScope(MlirAttribute diModule) { } MlirAttribute mlirLLVMDIImportedEntityAttrGet( - MlirContext ctx, unsigned int tag, MlirAttribute scope, - MlirAttribute entity, MlirAttribute file, unsigned int line, - MlirAttribute name, intptr_t nElements, MlirAttribute const *elements) { + MlirContext ctx, unsigned int tag, MlirAttribute entity, MlirAttribute file, + unsigned int line, MlirAttribute name, intptr_t nElements, + MlirAttribute const *elements) { SmallVector elementsStorage; elementsStorage.reserve(nElements); return wrap(DIImportedEntityAttr::get( - unwrap(ctx), tag, cast(unwrap(scope)), - cast(unwrap(entity)), cast(unwrap(file)), line, - cast(unwrap(name)), + unwrap(ctx), tag, cast(unwrap(entity)), + cast(unwrap(file)), line, cast(unwrap(name)), llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage), [](Attribute a) { return cast(a); }))); } From 98225dc68fa975b03d21029711919529de70da0b Mon Sep 17 00:00:00 2001 From: Kasper Nielsen Date: Sat, 31 Aug 2024 09:17:33 +0200 Subject: [PATCH 761/915] [mlir][python] Fix how the mlir variadic Python accessor `_ods_equally_sized_accessor` is used (#101132) (#106003) As reported in https://github.com/llvm/llvm-project/issues/101132, this 
fixes two bugs: 1. When accessing variadic operands inside an operation, it must be accessed as `self.operation.operands` instead of `operation.operands` 2. The implementation of the `equally_sized_accessor` function is doing wrong arithmetics when calculating the resulting index and group sizes. I have added a test for the `equally_sized_accessor` function, which did not have a test previously. --- mlir/python/mlir/dialects/_ods_common.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 1e7e8244e..d40d936cd 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -51,13 +51,14 @@ def segmented_accessor(elements, raw_segments, idx): def equally_sized_accessor( - elements, n_variadic, n_preceding_simple, n_preceding_variadic + elements, n_simple, n_variadic, n_preceding_simple, n_preceding_variadic ): """ Returns a starting position and a number of elements per variadic group assuming equally-sized groups and the given numbers of preceding groups. elements: a sequential container. + n_simple: the number of non-variadic groups in the container. n_variadic: the number of variadic groups in the container. n_preceding_simple: the number of non-variadic groups preceding the current group. @@ -65,7 +66,7 @@ def equally_sized_accessor( group. """ - total_variadic_length = len(elements) - n_variadic + 1 + total_variadic_length = len(elements) - n_simple # This should be enforced by the C++-side trait verifier. assert total_variadic_length % n_variadic == 0 From f26ada9b29cdbee7b69ebfbec1df308ea3f779cf Mon Sep 17 00:00:00 2001 From: Tobias Gysi Date: Mon, 2 Sep 2024 12:26:15 +0200 Subject: [PATCH 762/915] Reapply "[MLIR][LLVM] Make DISubprogramAttr cyclic" (#106571) with fixes (#106947) This reverts commit 371ad45, restoring commit 3897925, with fixes that ensure the CAPI declarations are exported properly. 
This commit implements LLVM_DIRecursiveTypeAttrInterface for the DISubprogramAttr to ensure cyclic subprograms can be imported properly. In the process multiple shortcuts around the recently introduced DIImportedEntityAttr can be removed. --- mlir/include/mlir-c/Dialect/LLVM.h | 28 +++++++++++++-------- mlir/lib/CAPI/Dialect/LLVM.cpp | 39 +++++++++++++++++++----------- 2 files changed, 43 insertions(+), 24 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index 5eb96a86e..d6062bed5 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -234,10 +234,14 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIBasicTypeAttrGet( MlirContext ctx, unsigned int tag, MlirAttribute name, uint64_t sizeInBits, MlirLLVMTypeEncoding encoding); +/// Creates a self-referencing LLVM DICompositeType attribute. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMDICompositeTypeAttrGetRecSelf(MlirAttribute recId); + /// Creates a LLVM DICompositeType attribute. MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDICompositeTypeAttrGet( - MlirContext ctx, unsigned int tag, MlirAttribute recId, MlirAttribute name, - MlirAttribute file, uint32_t line, MlirAttribute scope, + MlirContext ctx, MlirAttribute recId, bool isRecSelf, unsigned int tag, + MlirAttribute name, MlirAttribute file, uint32_t line, MlirAttribute scope, MlirAttribute baseType, int64_t flags, uint64_t sizeInBits, uint64_t alignInBits, intptr_t nElements, MlirAttribute const *elements, MlirAttribute dataLocation, MlirAttribute rank, MlirAttribute allocated, @@ -311,13 +315,17 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDILocalVariableAttrGet( MlirAttribute diFile, unsigned int line, unsigned int arg, unsigned int alignInBits, MlirAttribute diType, int64_t flags); +/// Creates a self-referencing LLVM DISubprogramAttr attribute. +MLIR_CAPI_EXPORTED MlirAttribute +mlirLLVMDISubprogramAttrGetRecSelf(MlirAttribute recId); + /// Creates a LLVM DISubprogramAttr attribute. 
MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDISubprogramAttrGet( - MlirContext ctx, MlirAttribute id, MlirAttribute compileUnit, - MlirAttribute scope, MlirAttribute name, MlirAttribute linkageName, - MlirAttribute file, unsigned int line, unsigned int scopeLine, - uint64_t subprogramFlags, MlirAttribute type, intptr_t nRetainedNodes, - MlirAttribute const *retainedNodes); + MlirContext ctx, MlirAttribute recId, bool isRecSelf, MlirAttribute id, + MlirAttribute compileUnit, MlirAttribute scope, MlirAttribute name, + MlirAttribute linkageName, MlirAttribute file, unsigned int line, + unsigned int scopeLine, uint64_t subprogramFlags, MlirAttribute type, + intptr_t nRetainedNodes, MlirAttribute const *retainedNodes); /// Gets the scope from this DISubprogramAttr. MLIR_CAPI_EXPORTED MlirAttribute @@ -356,9 +364,9 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIModuleAttrGet( /// Creates a LLVM DIImportedEntityAttr attribute. MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIImportedEntityAttrGet( - MlirContext ctx, unsigned int tag, MlirAttribute entity, MlirAttribute file, - unsigned int line, MlirAttribute name, intptr_t nElements, - MlirAttribute const *elements); + MlirContext ctx, unsigned int tag, MlirAttribute scope, + MlirAttribute entity, MlirAttribute file, unsigned int line, + MlirAttribute name, intptr_t nElements, MlirAttribute const *elements); /// Gets the scope of this DIModuleAttr. 
MLIR_CAPI_EXPORTED MlirAttribute diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index 13341f0c4..03b536d7a 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -159,9 +159,14 @@ MlirAttribute mlirLLVMDIBasicTypeAttrGet(MlirContext ctx, unsigned int tag, unwrap(ctx), tag, cast(unwrap(name)), sizeInBits, encoding)); } +MlirAttribute mlirLLVMDICompositeTypeAttrGetRecSelf(MlirAttribute recId) { + return wrap( + DICompositeTypeAttr::getRecSelf(cast(unwrap(recId)))); +} + MlirAttribute mlirLLVMDICompositeTypeAttrGet( - MlirContext ctx, unsigned int tag, MlirAttribute recId, MlirAttribute name, - MlirAttribute file, uint32_t line, MlirAttribute scope, + MlirContext ctx, MlirAttribute recId, bool isRecSelf, unsigned int tag, + MlirAttribute name, MlirAttribute file, uint32_t line, MlirAttribute scope, MlirAttribute baseType, int64_t flags, uint64_t sizeInBits, uint64_t alignInBits, intptr_t nElements, MlirAttribute const *elements, MlirAttribute dataLocation, MlirAttribute rank, MlirAttribute allocated, @@ -170,7 +175,7 @@ MlirAttribute mlirLLVMDICompositeTypeAttrGet( elementsStorage.reserve(nElements); return wrap(DICompositeTypeAttr::get( - unwrap(ctx), tag, cast(unwrap(recId)), + unwrap(ctx), cast(unwrap(recId)), isRecSelf, tag, cast(unwrap(name)), cast(unwrap(file)), line, cast(unwrap(scope)), cast(unwrap(baseType)), DIFlags(flags), sizeInBits, alignInBits, @@ -289,16 +294,21 @@ MlirAttribute mlirLLVMDISubroutineTypeAttrGet(MlirContext ctx, [](Attribute a) { return cast(a); }))); } +MlirAttribute mlirLLVMDISubprogramAttrGetRecSelf(MlirAttribute recId) { + return wrap(DISubprogramAttr::getRecSelf(cast(unwrap(recId)))); +} + MlirAttribute mlirLLVMDISubprogramAttrGet( - MlirContext ctx, MlirAttribute id, MlirAttribute compileUnit, - MlirAttribute scope, MlirAttribute name, MlirAttribute linkageName, - MlirAttribute file, unsigned int line, unsigned int scopeLine, - uint64_t subprogramFlags, MlirAttribute type, 
intptr_t nRetainedNodes, - MlirAttribute const *retainedNodes) { + MlirContext ctx, MlirAttribute recId, bool isRecSelf, MlirAttribute id, + MlirAttribute compileUnit, MlirAttribute scope, MlirAttribute name, + MlirAttribute linkageName, MlirAttribute file, unsigned int line, + unsigned int scopeLine, uint64_t subprogramFlags, MlirAttribute type, + intptr_t nRetainedNodes, MlirAttribute const *retainedNodes) { SmallVector nodesStorage; nodesStorage.reserve(nRetainedNodes); return wrap(DISubprogramAttr::get( - unwrap(ctx), cast(unwrap(id)), + unwrap(ctx), cast(unwrap(recId)), isRecSelf, + cast(unwrap(id)), cast(unwrap(compileUnit)), cast(unwrap(scope)), cast(unwrap(name)), cast(unwrap(linkageName)), cast(unwrap(file)), @@ -353,14 +363,15 @@ MlirAttribute mlirLLVMDIModuleAttrGetScope(MlirAttribute diModule) { } MlirAttribute mlirLLVMDIImportedEntityAttrGet( - MlirContext ctx, unsigned int tag, MlirAttribute entity, MlirAttribute file, - unsigned int line, MlirAttribute name, intptr_t nElements, - MlirAttribute const *elements) { + MlirContext ctx, unsigned int tag, MlirAttribute scope, + MlirAttribute entity, MlirAttribute file, unsigned int line, + MlirAttribute name, intptr_t nElements, MlirAttribute const *elements) { SmallVector elementsStorage; elementsStorage.reserve(nElements); return wrap(DIImportedEntityAttr::get( - unwrap(ctx), tag, cast(unwrap(entity)), - cast(unwrap(file)), line, cast(unwrap(name)), + unwrap(ctx), tag, cast(unwrap(scope)), + cast(unwrap(entity)), cast(unwrap(file)), line, + cast(unwrap(name)), llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage), [](Attribute a) { return cast(a); }))); } From d7b0d723c0d45902e831953f1ef89ef6706464ae Mon Sep 17 00:00:00 2001 From: Matt Hofmann Date: Thu, 5 Sep 2024 00:12:03 -0400 Subject: [PATCH 763/915] [MLIR][Python] Fix detached operation coming from `IfOp` constructor (#107286) Without this fix, `scf.if` operations would be created without a parent. 
Since `scf.if` operations often have no results, this caused silent bugs where the generated code was straight-up missing the operation. --- mlir/python/mlir/dialects/scf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py index 7025f6e0f..2d0047b76 100644 --- a/mlir/python/mlir/dialects/scf.py +++ b/mlir/python/mlir/dialects/scf.py @@ -87,7 +87,7 @@ def __init__(self, cond, results_=None, *, hasElse=False, loc=None, ip=None): operands.append(cond) results = [] results.extend(results_) - super().__init__(results, cond) + super().__init__(results, cond, loc=loc, ip=ip) self.regions[0].blocks.append(*[]) if hasElse: self.regions[1].blocks.append(*[]) From bec4b3a14add5f296755677ee94b684ba6c8643d Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Tue, 10 Sep 2024 10:41:05 +0200 Subject: [PATCH 764/915] [MLIR] Add f6E3M2FN type (#105573) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds `f6E3M2FN` type to mlir. `f6E3M2FN` type is proposed in [OpenCompute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). It defines a 6-bit floating point number with bit layout S1E3M2. Unlike IEEE-754 types, there are no infinity or NaN values. 
```c f6E3M2FN - Exponent bias: 3 - Maximum stored exponent value: 7 (binary 111) - Maximum unbiased exponent value: 7 - 3 = 4 - Minimum stored exponent value: 1 (binary 001) - Minimum unbiased exponent value: 1 − 3 = −2 - Has Positive and Negative zero - Doesn't have infinity - Doesn't have NaNs Additional details: - Zeros (+/-): S.000.00 - Max normal number: S.111.11 = ±2^(4) x (1 + 0.75) = ±28 - Min normal number: S.001.00 = ±2^(-2) = ±0.25 - Max subnormal number: S.000.11 = ±2^(-2) x 0.75 = ±0.1875 - Min subnormal number: S.000.01 = ±2^(-2) x 0.25 = ±0.0625 ``` Related PRs: - [PR-94735](https://github.com/llvm/llvm-project/pull/94735) [APFloat] Add APFloat support for FP6 data types - [PR-97118](https://github.com/llvm/llvm-project/pull/97118) [MLIR] Add f8E4M3 type - was used as a template for this PR --- mlir/include/mlir-c/BuiltinTypes.h | 10 ++++++++++ mlir/lib/Bindings/Python/IRTypes.cpp | 22 ++++++++++++++++++++++ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 12 ++++++++++++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 14 ++++++++++++++ mlir/python/mlir/extras/types.py | 2 ++ 5 files changed, 60 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index d698bf476..24531baec 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -79,6 +79,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type); /// Returns the bitwidth of a floating-point type. MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type); +/// Returns the typeID of an Float6E3M2FN type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E3M2FNTypeGetTypeID(void); + +/// Checks whether the given type is an f6E3M2FN type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E3M2FN(MlirType type); + +/// Creates an f6E3M2FN type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx); + /// Returns the typeID of an Float8E5M2 type. 
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E5M2TypeGetTypeID(void); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 2ee1d89c3..1cb429d9c 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -124,6 +124,27 @@ class PyFloatType : public PyConcreteType { } }; +/// Floating Point Type subclass - Float6E3M2FNType. +class PyFloat6E3M2FNType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat6E3M2FNTypeGetTypeID; + static constexpr const char *pyClassName = "Float6E3M2FNType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat6E3M2FNTypeGet(context->get()); + return PyFloat6E3M2FNType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float6_e3m2fn type."); + } +}; + /// Floating Point Type subclass - Float8E4M3FNType. 
class PyFloat8E4M3FNType : public PyConcreteType { @@ -880,6 +901,7 @@ void mlir::python::populateIRTypes(py::module &m) { PyIntegerType::bind(m); PyFloatType::bind(m); PyIndexType::bind(m); + PyFloat6E3M2FNType::bind(m); PyFloat8E4M3FNType::bind(m); PyFloat8E5M2Type::bind(m); PyFloat8E4M3Type::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 2aa2e922f..254650d66 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -85,6 +85,18 @@ unsigned mlirFloatTypeGetWidth(MlirType type) { return llvm::cast(unwrap(type)).getWidth(); } +MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() { + return wrap(Float6E3M2FNType::getTypeID()); +} + +bool mlirTypeIsAFloat6E3M2FN(MlirType type) { + return unwrap(type).isFloat6E3M2FN(); +} + +MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx) { + return wrap(FloatType::getFloat6E3M2FN(unwrap(ctx))); +} + MlirTypeID mlirFloat8E5M2TypeGetTypeID() { return wrap(Float8E5M2Type::getTypeID()); } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 4a2d0e977..7b4fac727 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -120,6 +120,7 @@ __all__ = [ "F32Type", "F64Type", "FlatSymbolRefAttr", + "Float6E3M2FNType", "Float8E3M4Type", "Float8E4M3B11FNUZType", "Float8E4M3FNType", @@ -1539,6 +1540,19 @@ class FlatSymbolRefAttr(Attribute): Returns the value of the FlatSymbolRef attribute as a string """ +class Float6E3M2FNType(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Optional[Context] = None) -> Float6E3M2FNType: + """ + Create a float6_e3m2fn type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... 
+ class Float8E3M4Type(FloatType): static_typeid: ClassVar[TypeID] @staticmethod diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py index fe7c3e25d..0c6ece91d 100644 --- a/mlir/python/mlir/extras/types.py +++ b/mlir/python/mlir/extras/types.py @@ -12,6 +12,7 @@ F16Type, F32Type, F64Type, + Float6E3M2FNType, Float8E3M4Type, Float8E4M3B11FNUZType, Float8E4M3FNType, @@ -74,6 +75,7 @@ def ui(width): f8E4M3FN = lambda: Float8E4M3FNType.get() f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get() f8E3M4 = lambda: Float8E3M4Type.get() +f6E3M2FN = lambda: Float6E3M2FNType.get() none = lambda: NoneType.get() From 5127d61689b54a4d25e597fdcd437ada174cc9c2 Mon Sep 17 00:00:00 2001 From: Amy Wang Date: Wed, 11 Sep 2024 07:37:35 -0400 Subject: [PATCH 765/915] [MLIR][Python] Python binding support for IntegerSet attribute (#107640) Support IntegerSet attribute python binding. --- mlir/include/mlir-c/BuiltinAttributes.h | 9 +++++++++ mlir/lib/Bindings/Python/IRAttributes.cpp | 22 +++++++++++++++++++++- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 9 +++++++++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 16 ++++++++++++++++ mlir/python/mlir/ir.py | 5 +++++ 5 files changed, 60 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 231eb83b5..7c8c84e55 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -16,6 +16,7 @@ #include "mlir-c/AffineMap.h" #include "mlir-c/IR.h" +#include "mlir-c/IntegerSet.h" #include "mlir-c/Support.h" #ifdef __cplusplus @@ -177,6 +178,14 @@ MLIR_CAPI_EXPORTED bool mlirBoolAttrGetValue(MlirAttribute attr); /// Checks whether the given attribute is an integer set attribute. MLIR_CAPI_EXPORTED bool mlirAttributeIsAIntegerSet(MlirAttribute attr); +/// Creates an integer set attribute wrapping the given set. The attribute +/// belongs to the same context as the integer set. 
+MLIR_CAPI_EXPORTED MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set); + +/// Returns the integer set wrapped in the given integer set attribute. +MLIR_CAPI_EXPORTED MlirIntegerSet +mlirIntegerSetAttrGetValue(MlirAttribute attr); + /// Returns the typeID of an IntegerSet attribute. MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerSetAttrGetTypeID(void); diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index b4049bd79..bfdd4a520 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -147,6 +147,26 @@ class PyAffineMapAttribute : public PyConcreteAttribute { } }; +class PyIntegerSetAttribute + : public PyConcreteAttribute { +public: + static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAIntegerSet; + static constexpr const char *pyClassName = "IntegerSetAttr"; + using PyConcreteAttribute::PyConcreteAttribute; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirIntegerSetAttrGetTypeID; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyIntegerSet &integerSet) { + MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get()); + return PyIntegerSetAttribute(integerSet.getContext(), attr); + }, + py::arg("integer_set"), "Gets an attribute wrapping an IntegerSet."); + } +}; + template static T pyTryCast(py::handle object) { try { @@ -1426,7 +1446,6 @@ py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { void mlir::python::populateIRAttributes(py::module &m) { PyAffineMapAttribute::bind(m); - PyDenseBoolArrayAttribute::bind(m); PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m); PyDenseI8ArrayAttribute::bind(m); @@ -1466,6 +1485,7 @@ void mlir::python::populateIRAttributes(py::module &m) { PyOpaqueAttribute::bind(m); PyFloatAttribute::bind(m); PyIntegerAttribute::bind(m); + PyIntegerSetAttribute::bind(m); PyStringAttribute::bind(m); PyTypeAttribute::bind(m); PyGlobals::get().registerTypeCaster( diff 
--git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 726af8846..11d1ade55 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -10,6 +10,7 @@ #include "mlir-c/Support.h" #include "mlir/CAPI/AffineMap.h" #include "mlir/CAPI/IR.h" +#include "mlir/CAPI/IntegerSet.h" #include "mlir/CAPI/Support.h" #include "mlir/IR/AsmState.h" #include "mlir/IR/Attributes.h" @@ -192,6 +193,14 @@ MlirTypeID mlirIntegerSetAttrGetTypeID(void) { return wrap(IntegerSetAttr::getTypeID()); } +MlirAttribute mlirIntegerSetAttrGet(MlirIntegerSet set) { + return wrap(IntegerSetAttr::get(unwrap(set))); +} + +MlirIntegerSet mlirIntegerSetAttrGetValue(MlirAttribute attr) { + return wrap(llvm::cast(unwrap(attr)).getValue()); +} + //===----------------------------------------------------------------------===// // Opaque attribute. //===----------------------------------------------------------------------===// diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 7b4fac727..a3d3a9261 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -138,6 +138,7 @@ __all__ = [ "InsertionPoint", "IntegerAttr", "IntegerSet", + "IntegerSetAttr", "IntegerSetConstraint", "IntegerSetConstraintList", "IntegerType", @@ -1905,6 +1906,21 @@ class IntegerSet: @property def n_symbols(self) -> int: ... +class IntegerSetAttr(Attribute): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(integer_set) -> IntegerSetAttr: + """ + Gets an attribute wrapping an IntegerSet. + """ + @staticmethod + def isinstance(other: Attribute) -> bool: ... + def __init__(self, cast_from_attr: Attribute) -> None: ... + @property + def type(self) -> Type: ... + @property + def typeid(self) -> TypeID: ... + class IntegerSetConstraint: def __init__(self, *args, **kwargs) -> None: ... 
@property diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index a9ac765fe..9a6ce4620 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -22,6 +22,11 @@ def _affineMapAttr(x, context): return AffineMapAttr.get(x) +@register_attribute_builder("IntegerSetAttr") +def _integerSetAttr(x, context): + return IntegerSetAttr.get(x) + + @register_attribute_builder("BoolAttr") def _boolAttr(x, context): return BoolAttr.get(x, context=context) From 5e836381e018bd211d5ceba56b42915ab10794e0 Mon Sep 17 00:00:00 2001 From: JOE1994 Date: Sun, 15 Sep 2024 21:12:09 -0400 Subject: [PATCH 766/915] [mlir] Nits on uses of llvm::raw_string_ostream (NFC) * Strip calls to raw_string_ostream::flush(), which is essentially a no-op * Strip unneeded calls to raw_string_ostream::str(), to avoid excess indirection. --- mlir/lib/Bindings/Python/IRAttributes.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index bfdd4a520..ead81a76c 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -708,7 +708,7 @@ class PyDenseElementsAttribute llvm::raw_string_ostream os(message); os << "Expected a static ShapedType for the shaped_type parameter: " << py::repr(py::cast(*explicitType)); - throw py::value_error(os.str()); + throw py::value_error(message); } shapedType = *explicitType; } else { @@ -732,7 +732,7 @@ class PyDenseElementsAttribute os << "All attributes must be of the same type and match " << "the type parameter: expected=" << py::repr(py::cast(shapedType)) << ", but got=" << py::repr(py::cast(attrType)); - throw py::value_error(os.str()); + throw py::value_error(message); } } From 43bfc4098faefac1886b8b796182cea09f841a2e Mon Sep 17 00:00:00 2001 From: JOE1994 Date: Sun, 15 Sep 2024 22:08:32 -0400 Subject: [PATCH 767/915] Revert "[mlir] Nits on uses of llvm::raw_string_ostream (NFC)" This reverts commit 
5e836381e018bd211d5ceba56b42915ab10794e0. "FAIL: MLIR::completion.test" on multiple buildbots. --- mlir/lib/Bindings/Python/IRAttributes.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index ead81a76c..bfdd4a520 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -708,7 +708,7 @@ class PyDenseElementsAttribute llvm::raw_string_ostream os(message); os << "Expected a static ShapedType for the shaped_type parameter: " << py::repr(py::cast(*explicitType)); - throw py::value_error(message); + throw py::value_error(os.str()); } shapedType = *explicitType; } else { @@ -732,7 +732,7 @@ class PyDenseElementsAttribute os << "All attributes must be of the same type and match " << "the type parameter: expected=" << py::repr(py::cast(shapedType)) << ", but got=" << py::repr(py::cast(attrType)); - throw py::value_error(message); + throw py::value_error(os.str()); } } From f05b24fc72be22fc354f78699b5800a7bc22d775 Mon Sep 17 00:00:00 2001 From: JOE1994 Date: Sun, 15 Sep 2024 21:12:09 -0400 Subject: [PATCH 768/915] [mlir] Reland 5e836381e018bd211d5ceba56b42915ab10794e0 with update (NFC) Excluded updates to mlir/lib/AsmParser/Parser.cpp , which caused LIT failure "FAIL: MLIR::completion.test" on multiple buildbots. 
--- mlir/lib/Bindings/Python/IRAttributes.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index bfdd4a520..ead81a76c 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -708,7 +708,7 @@ class PyDenseElementsAttribute llvm::raw_string_ostream os(message); os << "Expected a static ShapedType for the shaped_type parameter: " << py::repr(py::cast(*explicitType)); - throw py::value_error(os.str()); + throw py::value_error(message); } shapedType = *explicitType; } else { @@ -732,7 +732,7 @@ class PyDenseElementsAttribute os << "All attributes must be of the same type and match " << "the type parameter: expected=" << py::repr(py::cast(shapedType)) << ", but got=" << py::repr(py::cast(attrType)); - throw py::value_error(os.str()); + throw py::value_error(message); } } From e5597e38afa8b2c51f0028844ef6dc4f5bf434eb Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Mon, 16 Sep 2024 21:09:27 +0200 Subject: [PATCH 769/915] [MLIR] Add f6E2M3FN type (#107999) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds `f6E2M3FN` type to mlir. `f6E2M3FN` type is proposed in [OpenCompute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). It defines a 6-bit floating point number with bit layout S1E2M3. Unlike IEEE-754 types, there are no infinity or NaN values. 
```c f6E2M3FN - Exponent bias: 1 - Maximum stored exponent value: 3 (binary 11) - Maximum unbiased exponent value: 3 - 1 = 2 - Minimum stored exponent value: 1 (binary 01) - Minimum unbiased exponent value: 1 − 1 = 0 - Has Positive and Negative zero - Doesn't have infinity - Doesn't have NaNs Additional details: - Zeros (+/-): S.00.000 - Max normal number: S.11.111 = ±2^(2) x (1 + 0.875) = ±7.5 - Min normal number: S.01.000 = ±2^(0) = ±1.0 - Max subnormal number: S.00.111 = ±2^(0) x 0.875 = ±0.875 - Min subnormal number: S.00.001 = ±2^(0) x 0.125 = ±0.125 ``` Related PRs: - [PR-94735](https://github.com/llvm/llvm-project/pull/94735) [APFloat] Add APFloat support for FP6 data types - [PR-105573](https://github.com/llvm/llvm-project/pull/105573) [MLIR] Add f6E3M2FN type - was used as a template for this PR --- mlir/include/mlir-c/BuiltinTypes.h | 10 ++++++++++ mlir/lib/Bindings/Python/IRTypes.cpp | 22 ++++++++++++++++++++++ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 12 ++++++++++++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 14 ++++++++++++++ mlir/python/mlir/extras/types.py | 2 ++ 5 files changed, 60 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index 24531baec..cc6da482a 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -79,6 +79,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type); /// Returns the bitwidth of a floating-point type. MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type); +/// Returns the typeID of an Float6E2M3FN type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E2M3FNTypeGetTypeID(void); + +/// Checks whether the given type is an f6E2M3FN type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat6E2M3FN(MlirType type); + +/// Creates an f6E2M3FN type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx); + /// Returns the typeID of an Float6E3M2FN type. 
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E3M2FNTypeGetTypeID(void); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 1cb429d9c..6b64bc3c9 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -124,6 +124,27 @@ class PyFloatType : public PyConcreteType { } }; +/// Floating Point Type subclass - Float6E2M3FNType. +class PyFloat6E2M3FNType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat6E2M3FNTypeGetTypeID; + static constexpr const char *pyClassName = "Float6E2M3FNType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat6E2M3FNTypeGet(context->get()); + return PyFloat6E2M3FNType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float6_e2m3fn type."); + } +}; + /// Floating Point Type subclass - Float6E3M2FNType. 
class PyFloat6E3M2FNType : public PyConcreteType { @@ -901,6 +922,7 @@ void mlir::python::populateIRTypes(py::module &m) { PyIntegerType::bind(m); PyFloatType::bind(m); PyIndexType::bind(m); + PyFloat6E2M3FNType::bind(m); PyFloat6E3M2FNType::bind(m); PyFloat8E4M3FNType::bind(m); PyFloat8E5M2Type::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 254650d66..f943bf726 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -85,6 +85,18 @@ unsigned mlirFloatTypeGetWidth(MlirType type) { return llvm::cast(unwrap(type)).getWidth(); } +MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() { + return wrap(Float6E2M3FNType::getTypeID()); +} + +bool mlirTypeIsAFloat6E2M3FN(MlirType type) { + return unwrap(type).isFloat6E2M3FN(); +} + +MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx) { + return wrap(FloatType::getFloat6E2M3FN(unwrap(ctx))); +} + MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() { return wrap(Float6E3M2FNType::getTypeID()); } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index a3d3a9261..ea5c96dcb 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -120,6 +120,7 @@ __all__ = [ "F32Type", "F64Type", "FlatSymbolRefAttr", + "Float6E2M3FNType", "Float6E3M2FNType", "Float8E3M4Type", "Float8E4M3B11FNUZType", @@ -1541,6 +1542,19 @@ class FlatSymbolRefAttr(Attribute): Returns the value of the FlatSymbolRef attribute as a string """ +class Float6E2M3FNType(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Optional[Context] = None) -> Float6E2M3FNType: + """ + Create a float6_e2m3fn type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... 
+ class Float6E3M2FNType(FloatType): static_typeid: ClassVar[TypeID] @staticmethod diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py index 0c6ece91d..4be425f22 100644 --- a/mlir/python/mlir/extras/types.py +++ b/mlir/python/mlir/extras/types.py @@ -12,6 +12,7 @@ F16Type, F32Type, F64Type, + Float6E2M3FNType, Float6E3M2FNType, Float8E3M4Type, Float8E4M3B11FNUZType, @@ -75,6 +76,7 @@ def ui(width): f8E4M3FN = lambda: Float8E4M3FNType.get() f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get() f8E3M4 = lambda: Float8E3M4Type.get() +f6E2M3FN = lambda: Float6E2M3FNType.get() f6E3M2FN = lambda: Float6E3M2FNType.get() none = lambda: NoneType.get() From 4acefdd52e681ed0ce2d6559ae37fe711fb11d87 Mon Sep 17 00:00:00 2001 From: Bimo Date: Wed, 18 Sep 2024 11:54:16 +0800 Subject: [PATCH 770/915] [MLIR] [Python] align python ir printing with mlir-print-ir-after-all (#107522) When using the `enable_ir_printing` API from Python, it invokes IR printing with default args, printing the IR before each pass and printing IR after pass only if there have been changes. This PR attempts to align the `enable_ir_printing` API with the documentation --- mlir/include/mlir-c/Pass.h | 8 +++++--- mlir/lib/Bindings/Python/Pass.cpp | 13 ++++++++++--- mlir/lib/CAPI/IR/Pass.cpp | 17 +++++++++++++++-- .../mlir/_mlir_libs/_mlir/passmanager.pyi | 9 ++++++++- 4 files changed, 38 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index 35db13830..2218ec0f4 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -74,9 +74,11 @@ mlirPassManagerGetAsOpPassManager(MlirPassManager passManager); MLIR_CAPI_EXPORTED MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op); -/// Enable mlir-print-ir-after-all. -MLIR_CAPI_EXPORTED void -mlirPassManagerEnableIRPrinting(MlirPassManager passManager); +/// Enable IR printing. 
+MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting( + MlirPassManager passManager, bool printBeforeAll, bool printAfterAll, + bool printModuleScope, bool printAfterOnlyOnChange, + bool printAfterOnlyOnFailure); /// Enable / disable verify-each. MLIR_CAPI_EXPORTED void diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index a68421b61..1d0e5ce21 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -74,10 +74,17 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { "Releases (leaks) the backing pass manager (testing)") .def( "enable_ir_printing", - [](PyPassManager &passManager) { - mlirPassManagerEnableIRPrinting(passManager.get()); + [](PyPassManager &passManager, bool printBeforeAll, + bool printAfterAll, bool printModuleScope, bool printAfterChange, + bool printAfterFailure) { + mlirPassManagerEnableIRPrinting( + passManager.get(), printBeforeAll, printAfterAll, + printModuleScope, printAfterChange, printAfterFailure); }, - "Enable mlir-print-ir-after-all.") + "print_before_all"_a = false, "print_after_all"_a = true, + "print_module_scope"_a = false, "print_after_change"_a = false, + "print_after_failure"_a = false, + "Enable IR printing, default as mlir-print-ir-after-all.") .def( "enable_verifier", [](PyPassManager &passManager, bool enable) { diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index d242baae9..a6c9fbd08 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -44,8 +44,21 @@ MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, return wrap(unwrap(passManager)->run(unwrap(op))); } -void mlirPassManagerEnableIRPrinting(MlirPassManager passManager) { - return unwrap(passManager)->enableIRPrinting(); +void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, + bool printBeforeAll, bool printAfterAll, + bool printModuleScope, + bool printAfterOnlyOnChange, + bool printAfterOnlyOnFailure) { + auto 
shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) { + return printBeforeAll; + }; + auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) { + return printAfterAll; + }; + return unwrap(passManager) + ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass, + printModuleScope, printAfterOnlyOnChange, + printAfterOnlyOnFailure); } void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) { diff --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi index c072d5e0f..5d115e822 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi @@ -16,7 +16,14 @@ class PassManager: def __init__(self, context: Optional[_ir.Context] = None) -> None: ... def _CAPICreate(self) -> object: ... def _testing_release(self) -> None: ... - def enable_ir_printing(self) -> None: ... + def enable_ir_printing( + self, + print_before_all: bool = False, + print_after_all: bool = True, + print_module_scope: bool = False, + print_after_change: bool = False, + print_after_failure: bool = False, + ) -> None: ... def enable_verifier(self, enable: bool) -> None: ... @staticmethod def parse(pipeline: str, context: Optional[_ir.Context] = None) -> PassManager: ... From 08b8078b84791ac2839f6e948093c4a7a69a18b8 Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Tue, 24 Sep 2024 08:22:48 +0200 Subject: [PATCH 771/915] [MLIR] Add f4E2M1FN type (#108877) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds `f4E2M1FN` type to mlir. `f4E2M1FN` type is proposed in [OpenCompute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). It defines a 4-bit floating point number with bit layout S1E2M1. Unlike IEEE-754 types, there are no infinity or NaN values. 
```c f4E2M1FN - Exponent bias: 1 - Maximum stored exponent value: 3 (binary 11) - Maximum unbiased exponent value: 3 - 1 = 2 - Minimum stored exponent value: 1 (binary 01) - Minimum unbiased exponent value: 1 − 1 = 0 - Has Positive and Negative zero - Doesn't have infinity - Doesn't have NaNs Additional details: - Zeros (+/-): S.00.0 - Max normal number: S.11.1 = ±2^(2) x (1 + 0.5) = ±6.0 - Min normal number: S.01.0 = ±2^(0) = ±1.0 - Min subnormal number: S.00.1 = ±2^(0) x 0.5 = ±0.5 ``` Related PRs: - [PR-95392](https://github.com/llvm/llvm-project/pull/95392) [APFloat] Add APFloat support for FP4 data type - [PR-105573](https://github.com/llvm/llvm-project/pull/105573) [MLIR] Add f6E3M2FN type - was used as a template for this PR - [PR-107999](https://github.com/llvm/llvm-project/pull/107999) [MLIR] Add f6E2M3FN type --- mlir/include/mlir-c/BuiltinTypes.h | 10 ++++++++++ mlir/lib/Bindings/Python/IRTypes.cpp | 22 ++++++++++++++++++++++ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 12 ++++++++++++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 14 ++++++++++++++ mlir/python/mlir/extras/types.py | 2 ++ 5 files changed, 60 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index cc6da482a..6dc25a56b 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -79,6 +79,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat(MlirType type); /// Returns the bitwidth of a floating-point type. MLIR_CAPI_EXPORTED unsigned mlirFloatTypeGetWidth(MlirType type); +/// Returns the typeID of an Float4E2M1FN type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat4E2M1FNTypeGetTypeID(void); + +/// Checks whether the given type is an f4E2M1FN type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat4E2M1FN(MlirType type); + +/// Creates an f4E2M1FN type in the given context. The type is owned by the +/// context. +MLIR_CAPI_EXPORTED MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx); + /// Returns the typeID of an Float6E2M3FN type. 
MLIR_CAPI_EXPORTED MlirTypeID mlirFloat6E2M3FNTypeGetTypeID(void); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 6b64bc3c9..5a369b5d4 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -124,6 +124,27 @@ class PyFloatType : public PyConcreteType { } }; +/// Floating Point Type subclass - Float4E2M1FNType. +class PyFloat4E2M1FNType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat4E2M1FNTypeGetTypeID; + static constexpr const char *pyClassName = "Float4E2M1FNType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat4E2M1FNTypeGet(context->get()); + return PyFloat4E2M1FNType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float4_e2m1fn type."); + } +}; + /// Floating Point Type subclass - Float6E2M3FNType. 
class PyFloat6E2M3FNType : public PyConcreteType { @@ -922,6 +943,7 @@ void mlir::python::populateIRTypes(py::module &m) { PyIntegerType::bind(m); PyFloatType::bind(m); PyIndexType::bind(m); + PyFloat4E2M1FNType::bind(m); PyFloat6E2M3FNType::bind(m); PyFloat6E3M2FNType::bind(m); PyFloat8E4M3FNType::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index f943bf726..efc1e857a 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -85,6 +85,18 @@ unsigned mlirFloatTypeGetWidth(MlirType type) { return llvm::cast(unwrap(type)).getWidth(); } +MlirTypeID mlirFloat4E2M1FNTypeGetTypeID() { + return wrap(Float4E2M1FNType::getTypeID()); +} + +bool mlirTypeIsAFloat4E2M1FN(MlirType type) { + return unwrap(type).isFloat4E2M1FN(); +} + +MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx) { + return wrap(FloatType::getFloat4E2M1FN(unwrap(ctx))); +} + MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() { return wrap(Float6E2M3FNType::getTypeID()); } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index ea5c96dcb..4d5b4cef9 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -120,6 +120,7 @@ __all__ = [ "F32Type", "F64Type", "FlatSymbolRefAttr", + "Float4E2M1FNType", "Float6E2M3FNType", "Float6E3M2FNType", "Float8E3M4Type", @@ -1542,6 +1543,19 @@ class FlatSymbolRefAttr(Attribute): Returns the value of the FlatSymbolRef attribute as a string """ +class Float4E2M1FNType(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Optional[Context] = None) -> Float4E2M1FNType: + """ + Create a float4_e2m1fn type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... 
+ class Float6E2M3FNType(FloatType): static_typeid: ClassVar[TypeID] @staticmethod diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py index 4be425f22..5b24a6d52 100644 --- a/mlir/python/mlir/extras/types.py +++ b/mlir/python/mlir/extras/types.py @@ -12,6 +12,7 @@ F16Type, F32Type, F64Type, + Float4E2M1FNType, Float6E2M3FNType, Float6E3M2FNType, Float8E3M4Type, @@ -76,6 +77,7 @@ def ui(width): f8E4M3FN = lambda: Float8E4M3FNType.get() f8E4M3B11FNUZ = lambda: Float8E4M3B11FNUZType.get() f8E3M4 = lambda: Float8E3M4Type.get() +f4E2M1FN = lambda: Float4E2M1FNType.get() f6E2M3FN = lambda: Float6E2M3FNType.get() f6E3M2FN = lambda: Float6E3M2FNType.get() From dc2ef3f4a2181c294163b42ab278df502098458f Mon Sep 17 00:00:00 2001 From: Rafael Ubal Date: Thu, 26 Sep 2024 14:09:28 -0400 Subject: [PATCH 772/915] [mlir] Improvements to the 'quant' dialect (#100667) Full revamp of the 'quant' dialect. This is an implementation for the RFC at https://discourse.llvm.org/t/rfc-improvements-in-the-quant-dialect/79942 --- mlir/lib/CAPI/Dialect/Quant.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp index 0a7181d8b..c94dbb569 100644 --- a/mlir/lib/CAPI/Dialect/Quant.cpp +++ b/mlir/lib/CAPI/Dialect/Quant.cpp @@ -8,12 +8,12 @@ #include "mlir-c/Dialect/Quant.h" #include "mlir/CAPI/Registration.h" -#include "mlir/Dialect/Quant/QuantOps.h" -#include "mlir/Dialect/Quant/QuantTypes.h" +#include "mlir/Dialect/Quant/IR/Quant.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" using namespace mlir; -MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantizationDialect) +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(quant, quant, quant::QuantDialect) //===---------------------------------------------------------------------===// // QuantizedType From 8947b704521c0146a44e2c61f6c8e7995d3c4454 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev <185856+superbobry@users.noreply.github.com> 
Date: Tue, 1 Oct 2024 12:49:18 +0100 Subject: [PATCH 773/915] A few tweaks to the MLIR .pyi files (#110488) --- .../python/mlir/_mlir_libs/_mlir/__init__.pyi | 3 +- .../mlir/_mlir_libs/_mlir/dialects/pdl.pyi | 9 +- .../mlir/_mlir_libs/_mlir/dialects/quant.pyi | 7 +- .../_mlir/dialects/transform/__init__.pyi | 5 +- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 377 +++++++++--------- .../mlir/_mlir_libs/_mlir/passmanager.pyi | 5 +- .../mlir/_mlir_libs/_mlirExecutionEngine.pyi | 2 +- 7 files changed, 197 insertions(+), 211 deletions(-) diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi index 93b978c75..42694747e 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi @@ -1,9 +1,8 @@ -from typing import List globals: "_Globals" class _Globals: - dialect_search_modules: List[str] + dialect_search_modules: list[str] def _register_dialect_impl(self, dialect_namespace: str, dialect_class: type) -> None: ... def _register_operation_impl(self, operation_name: str, operation_class: type) -> None: ... def append_dialect_search_prefix(self, module_name: str) -> None: ... diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi index 8ec944d19..d12c6839d 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/pdl.pyi @@ -2,7 +2,6 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Optional from mlir.ir import Type, Context @@ -26,7 +25,7 @@ class AttributeType(Type): def isinstance(type: Type) -> bool: ... @staticmethod - def get(context: Optional[Context] = None) -> AttributeType: ... + def get(context: Context | None = None) -> AttributeType: ... class OperationType(Type): @@ -34,7 +33,7 @@ class OperationType(Type): def isinstance(type: Type) -> bool: ... 
@staticmethod - def get(context: Optional[Context] = None) -> OperationType: ... + def get(context: Context | None = None) -> OperationType: ... class RangeType(Type): @@ -53,7 +52,7 @@ class TypeType(Type): def isinstance(type: Type) -> bool: ... @staticmethod - def get(context: Optional[Context] = None) -> TypeType: ... + def get(context: Context | None = None) -> TypeType: ... class ValueType(Type): @@ -61,4 +60,4 @@ class ValueType(Type): def isinstance(type: Type) -> bool: ... @staticmethod - def get(context: Optional[Context] = None) -> ValueType: ... + def get(context: Context | None = None) -> ValueType: ... diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi index c9c66d52b..a10bc693b 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi @@ -2,7 +2,6 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import List from mlir.ir import Type @@ -94,15 +93,15 @@ class UniformQuantizedPerAxisType(QuantizedType): @classmethod def get(cls, flags: int, storage_type: Type, expressed_type: Type, - scales: List[float], zero_points: List[int], quantized_dimension: int, + scales: list[float], zero_points: list[int], quantized_dimension: int, storage_type_min: int, storage_type_max: int): ... @property - def scales(self) -> List[float]: ... + def scales(self) -> list[float]: ... @property - def zero_points(self) -> List[float]: ... + def zero_points(self) -> list[float]: ... @property def quantized_dimension(self) -> int: ... 
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi index 2a2954173..a3f1b0910 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/transform/__init__.pyi @@ -2,7 +2,6 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from typing import Optional from mlir.ir import Type, Context @@ -12,7 +11,7 @@ class AnyOpType(Type): def isinstance(type: Type) -> bool: ... @staticmethod - def get(context: Optional[Context] = None) -> AnyOpType: ... + def get(context: Context | None = None) -> AnyOpType: ... class OperationType(Type): @@ -20,7 +19,7 @@ class OperationType(Type): def isinstance(type: Type) -> bool: ... @staticmethod - def get(operation_name: str, context: Optional[Context] = None) -> OperationType: ... + def get(operation_name: str, context: Context | None = None) -> OperationType: ... @property def operation_name(self) -> str: ... 
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 4d5b4cef9..41ed84e04 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -44,22 +44,9 @@ from __future__ import annotations import abc import collections +from collections.abc import Callable, Sequence import io -from typing import ( - Any, - Callable, - ClassVar, - Dict, - List, - Optional, - Sequence, - Tuple, - Type as _Type, - TypeVar, - Union, -) - -from typing import overload +from typing import Any, ClassVar, TypeVar, overload __all__ = [ "AffineAddExpr", @@ -210,14 +197,14 @@ class _OperationBase: def get_asm( self, binary: bool = False, - large_elements_limit: Optional[int] = None, + large_elements_limit: int | None = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, use_local_scope: bool = False, assume_verified: bool = False, skip_regions: bool = False, - ) -> Union[io.BytesIO, io.StringIO]: + ) -> io.BytesIO | io.StringIO: """ Gets the assembly form of the operation with all options available. @@ -242,7 +229,7 @@ class _OperationBase: def print( self, state: AsmState, - file: Optional[Any] = None, + file: Any | None = None, binary: bool = False, ) -> None: """ @@ -256,13 +243,13 @@ class _OperationBase: @overload def print( self, - large_elements_limit: Optional[int] = None, + large_elements_limit: int | None = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, use_local_scope: bool = False, assume_verified: bool = False, - file: Optional[Any] = None, + file: Any | None = None, binary: bool = False, skip_regions: bool = False, ) -> None: @@ -296,7 +283,7 @@ class _OperationBase: """ Verify the operation. Raises MLIRError if verification fails, and returns true otherwise. 
""" - def write_bytecode(self, file: Any, desired_version: Optional[int] = None) -> None: + def write_bytecode(self, file: Any, desired_version: int | None = None) -> None: """ Write the bytecode form of the operation to a file like object. @@ -325,7 +312,7 @@ class _OperationBase: @property def operands(self) -> OpOperandList: ... @property - def parent(self) -> Optional[_OperationBase]: ... + def parent(self) -> _OperationBase | None: ... @property def regions(self) -> RegionSequence: ... @property @@ -380,13 +367,13 @@ class AffineExpr: """ @staticmethod def get_constant( - value: int, context: Optional[Context] = None + value: int, context: Context | None = None ) -> AffineConstantExpr: """ Gets a constant affine expression with the given value. """ @staticmethod - def get_dim(position: int, context: Optional[Context] = None) -> AffineDimExpr: + def get_dim(position: int, context: Context | None = None) -> AffineDimExpr: """ Gets an affine expression of a dimension at the given position. """ @@ -446,7 +433,7 @@ class AffineExpr: """ @staticmethod def get_symbol( - position: int, context: Optional[Context] = None + position: int, context: Context | None = None ) -> AffineSymbolExpr: """ Gets an affine expression of a symbol at the given position. @@ -489,7 +476,7 @@ class AffineExpr: class Attribute: @staticmethod - def parse(asm: str | bytes, context: Optional[Context] = None) -> Attribute: + def parse(asm: str | bytes, context: Context | None = None) -> Attribute: """ Parses an attribute from an assembly form. Raises an MLIRError on failure. """ @@ -530,7 +517,7 @@ class Attribute: class Type: @staticmethod - def parse(asm: str | bytes, context: Optional[Context] = None) -> Type: + def parse(asm: str | bytes, context: Context | None = None) -> Type: """ Parses the assembly form of a type. 
@@ -640,7 +627,7 @@ class AffineCeilDivExpr(AffineBinaryExpr): class AffineConstantExpr(AffineExpr): @staticmethod - def get(value: int, context: Optional[Context] = None) -> AffineConstantExpr: ... + def get(value: int, context: Context | None = None) -> AffineConstantExpr: ... @staticmethod def isinstance(other: AffineExpr) -> bool: ... def __init__(self, expr: AffineExpr) -> None: ... @@ -649,7 +636,7 @@ class AffineConstantExpr(AffineExpr): class AffineDimExpr(AffineExpr): @staticmethod - def get(position: int, context: Optional[Context] = None) -> AffineDimExpr: ... + def get(position: int, context: Context | None = None) -> AffineDimExpr: ... @staticmethod def isinstance(other: AffineExpr) -> bool: ... def __init__(self, expr: AffineExpr) -> None: ... @@ -657,7 +644,7 @@ class AffineDimExpr(AffineExpr): def position(self) -> int: ... class AffineExprList: - def __add__(self, arg0: AffineExprList) -> List[AffineExpr]: ... + def __add__(self, arg0: AffineExprList) -> list[AffineExpr]: ... class AffineFloorDivExpr(AffineBinaryExpr): @staticmethod @@ -669,43 +656,43 @@ class AffineFloorDivExpr(AffineBinaryExpr): class AffineMap: @staticmethod def compress_unused_symbols( - arg0: List, arg1: Optional[Context] - ) -> List[AffineMap]: ... + arg0: list, arg1: Context | None + ) -> list[AffineMap]: ... @staticmethod def get( dim_count: int, symbol_count: int, - exprs: List, - context: Optional[Context] = None, + exprs: list, + context: Context | None = None, ) -> AffineMap: """ Gets a map with the given expressions as results. """ @staticmethod - def get_constant(value: int, context: Optional[Context] = None) -> AffineMap: + def get_constant(value: int, context: Context | None = None) -> AffineMap: """ Gets an affine map with a single constant result """ @staticmethod - def get_empty(context: Optional[Context] = None) -> AffineMap: + def get_empty(context: Context | None = None) -> AffineMap: """ Gets an empty affine map. 
""" @staticmethod - def get_identity(n_dims: int, context: Optional[Context] = None) -> AffineMap: + def get_identity(n_dims: int, context: Context | None = None) -> AffineMap: """ Gets an identity map with the given number of dimensions. """ @staticmethod def get_minor_identity( - n_dims: int, n_results: int, context: Optional[Context] = None + n_dims: int, n_results: int, context: Context | None = None ) -> AffineMap: """ Gets a minor identity map with the given number of dimensions and results. """ @staticmethod def get_permutation( - permutation: List[int], context: Optional[Context] = None + permutation: list[int], context: Context | None = None ) -> AffineMap: """ Gets an affine map that permutes its inputs. @@ -722,7 +709,7 @@ class AffineMap: """ def get_major_submap(self, n_results: int) -> AffineMap: ... def get_minor_submap(self, n_results: int) -> AffineMap: ... - def get_submap(self, result_positions: List[int]) -> AffineMap: ... + def get_submap(self, result_positions: list[int]) -> AffineMap: ... def replace( self, expr: AffineExpr, @@ -748,7 +735,7 @@ class AffineMap: @property def n_symbols(self) -> int: ... @property - def results(self) -> "AffineMapExprList": ... + def results(self) -> AffineMapExprList: ... class AffineMapAttr(Attribute): static_typeid: ClassVar[TypeID] @@ -781,7 +768,7 @@ class AffineMulExpr(AffineBinaryExpr): class AffineSymbolExpr(AffineExpr): @staticmethod - def get(position: int, context: Optional[Context] = None) -> AffineSymbolExpr: ... + def get(position: int, context: Context | None = None) -> AffineSymbolExpr: ... @staticmethod def isinstance(other: AffineExpr) -> bool: ... def __init__(self, expr: AffineExpr) -> None: ... 
@@ -791,13 +778,13 @@ class AffineSymbolExpr(AffineExpr): class ArrayAttr(Attribute): static_typeid: ClassVar[TypeID] @staticmethod - def get(attributes: List, context: Optional[Context] = None) -> ArrayAttr: + def get(attributes: list, context: Context | None = None) -> ArrayAttr: """ Gets a uniqued Array attribute """ @staticmethod def isinstance(other: Attribute) -> bool: ... - def __add__(self, arg0: List) -> ArrayAttr: ... + def __add__(self, arg0: list) -> ArrayAttr: ... def __getitem__(self, arg0: int) -> Attribute: ... def __init__(self, cast_from_attr: Attribute) -> None: ... def __iter__( @@ -835,7 +822,7 @@ class AttrBuilder: class BF16Type(Type): static_typeid: ClassVar[TypeID] @staticmethod - def get(context: Optional[Context] = None) -> BF16Type: + def get(context: Context | None = None) -> BF16Type: """ Create a bf16 type. """ @@ -849,8 +836,8 @@ class Block: @staticmethod def create_at_start( parent: Region, - arg_types: List[Type], - arg_locs: Optional[Sequence] = None, + arg_types: list[Type], + arg_locs: Sequence | None = None, ) -> Block: """ Creates and returns a new Block at the beginning of the given region (with given argument types and locations). @@ -876,11 +863,11 @@ class Block: """ Append this block to a region, transferring ownership if necessary """ - def create_after(self, *args, arg_locs: Optional[Sequence] = None) -> Block: + def create_after(self, *args, arg_locs: Sequence | None = None) -> Block: """ Creates and returns a new Block after this block (with given argument types and locations). """ - def create_before(self, *args, arg_locs: Optional[Sequence] = None) -> Block: + def create_before(self, *args, arg_locs: Sequence | None = None) -> Block: """ Creates and returns a new Block before this block (with given argument types and locations). """ @@ -924,9 +911,9 @@ class BlockArgumentList: @overload def __getitem__(self, arg0: slice) -> BlockArgumentList: ... def __len__(self) -> int: ... 
- def __add__(self, arg0: BlockArgumentList) -> List[BlockArgument]: ... + def __add__(self, arg0: BlockArgumentList) -> list[BlockArgument]: ... @property - def types(self) -> List[Type]: ... + def types(self) -> list[Type]: ... class BlockIterator: def __iter__(self) -> BlockIterator: ... @@ -936,7 +923,7 @@ class BlockList: def __getitem__(self, arg0: int) -> Block: ... def __iter__(self) -> BlockIterator: ... def __len__(self) -> int: ... - def append(self, *args, arg_locs: Optional[Sequence] = None) -> Block: + def append(self, *args, arg_locs: Sequence | None = None) -> Block: """ Appends a new block, with argument types as positional args. @@ -946,7 +933,7 @@ class BlockList: class BoolAttr(Attribute): @staticmethod - def get(value: bool, context: Optional[Context] = None) -> BoolAttr: + def get(value: bool, context: Context | None = None) -> BoolAttr: """ Gets an uniqued bool attribute """ @@ -1000,7 +987,7 @@ class Context: def _get_context_again(self) -> Context: ... def _get_live_module_count(self) -> int: ... def _get_live_operation_count(self) -> int: ... - def _get_live_operation_objects(self) -> List[Operation]: ... + def _get_live_operation_objects(self) -> list[Operation]: ... def append_dialect_registry(self, registry: DialectRegistry) -> None: ... def attach_diagnostic_handler( self, callback: Callable[[Diagnostic], bool] @@ -1031,14 +1018,14 @@ class Context: class DenseBoolArrayAttr(Attribute): @staticmethod def get( - values: Sequence[bool], context: Optional[Context] = None + values: Sequence[bool], context: Context | None = None ) -> DenseBoolArrayAttr: """ Gets a uniqued dense array attribute """ @staticmethod def isinstance(other: Attribute) -> bool: ... - def __add__(self, arg0: List) -> DenseBoolArrayAttr: ... + def __add__(self, arg0: list) -> DenseBoolArrayAttr: ... def __getitem__(self, arg0: int) -> bool: ... def __init__(self, cast_from_attr: Attribute) -> None: ... 
def __iter__( @@ -1061,9 +1048,9 @@ class DenseElementsAttr(Attribute): def get( array: Buffer, signless: bool = True, - type: Optional[Type] = None, - shape: Optional[List[int]] = None, - context: Optional[Context] = None, + type: Type | None = None, + shape: list[int] | None = None, + context: Context | None = None, ) -> DenseElementsAttr: """ Gets a DenseElementsAttr from a Python buffer or array. @@ -1128,14 +1115,14 @@ class DenseElementsAttr(Attribute): class DenseF32ArrayAttr(Attribute): @staticmethod def get( - values: Sequence[float], context: Optional[Context] = None + values: Sequence[float], context: Context | None = None ) -> DenseF32ArrayAttr: """ Gets a uniqued dense array attribute """ @staticmethod def isinstance(other: Attribute) -> bool: ... - def __add__(self, arg0: List) -> DenseF32ArrayAttr: ... + def __add__(self, arg0: list) -> DenseF32ArrayAttr: ... def __getitem__(self, arg0: int) -> float: ... def __init__(self, cast_from_attr: Attribute) -> None: ... def __iter__( @@ -1156,14 +1143,14 @@ class DenseF32ArrayIterator: class DenseF64ArrayAttr(Attribute): @staticmethod def get( - values: Sequence[float], context: Optional[Context] = None + values: Sequence[float], context: Context | None = None ) -> DenseF64ArrayAttr: """ Gets a uniqued dense array attribute """ @staticmethod def isinstance(other: Attribute) -> bool: ... - def __add__(self, arg0: List) -> DenseF64ArrayAttr: ... + def __add__(self, arg0: list) -> DenseF64ArrayAttr: ... def __getitem__(self, arg0: int) -> float: ... def __init__(self, cast_from_attr: Attribute) -> None: ... def __iter__( @@ -1186,9 +1173,9 @@ class DenseFPElementsAttr(DenseElementsAttr): def get( array: Buffer, signless: bool = True, - type: Optional[Type] = None, - shape: Optional[List[int]] = None, - context: Optional[Context] = None, + type: Type | None = None, + shape: list[int] | None = None, + context: Context | None = None, ) -> DenseFPElementsAttr: ... 
@staticmethod def isinstance(other: Attribute) -> bool: ... @@ -1203,13 +1190,13 @@ class DenseFPElementsAttr(DenseElementsAttr): class DenseI16ArrayAttr(Attribute): @staticmethod - def get(values: Sequence[int], context: Optional[Context] = None) -> DenseI16ArrayAttr: + def get(values: Sequence[int], context: Context | None = None) -> DenseI16ArrayAttr: """ Gets a uniqued dense array attribute """ @staticmethod def isinstance(other: Attribute) -> bool: ... - def __add__(self, arg0: List) -> DenseI16ArrayAttr: ... + def __add__(self, arg0: list) -> DenseI16ArrayAttr: ... def __getitem__(self, arg0: int) -> int: ... def __init__(self, cast_from_attr: Attribute) -> None: ... def __iter__( @@ -1229,13 +1216,13 @@ class DenseI16ArrayIterator: class DenseI32ArrayAttr(Attribute): @staticmethod - def get(values: Sequence[int], context: Optional[Context] = None) -> DenseI32ArrayAttr: + def get(values: Sequence[int], context: Context | None = None) -> DenseI32ArrayAttr: """ Gets a uniqued dense array attribute """ @staticmethod def isinstance(other: Attribute) -> bool: ... - def __add__(self, arg0: List) -> DenseI32ArrayAttr: ... + def __add__(self, arg0: list) -> DenseI32ArrayAttr: ... def __getitem__(self, arg0: int) -> int: ... def __init__(self, cast_from_attr: Attribute) -> None: ... def __iter__( @@ -1255,13 +1242,13 @@ class DenseI32ArrayIterator: class DenseI64ArrayAttr(Attribute): @staticmethod - def get(values: Sequence[int], context: Optional[Context] = None) -> DenseI64ArrayAttr: + def get(values: Sequence[int], context: Context | None = None) -> DenseI64ArrayAttr: """ Gets a uniqued dense array attribute """ @staticmethod def isinstance(other: Attribute) -> bool: ... - def __add__(self, arg0: List) -> DenseI64ArrayAttr: ... + def __add__(self, arg0: list) -> DenseI64ArrayAttr: ... def __getitem__(self, arg0: int) -> int: ... def __init__(self, cast_from_attr: Attribute) -> None: ... 
def __iter__( @@ -1281,13 +1268,13 @@ class DenseI64ArrayIterator: class DenseI8ArrayAttr(Attribute): @staticmethod - def get(values: Sequence[int], context: Optional[Context] = None) -> DenseI8ArrayAttr: + def get(values: Sequence[int], context: Context | None = None) -> DenseI8ArrayAttr: """ Gets a uniqued dense array attribute """ @staticmethod def isinstance(other: Attribute) -> bool: ... - def __add__(self, arg0: List) -> DenseI8ArrayAttr: ... + def __add__(self, arg0: list) -> DenseI8ArrayAttr: ... def __getitem__(self, arg0: int) -> int: ... def __init__(self, cast_from_attr: Attribute) -> None: ... def __iter__( @@ -1310,9 +1297,9 @@ class DenseIntElementsAttr(DenseElementsAttr): def get( array: Buffer, signless: bool = True, - type: Optional[Type] = None, - shape: Optional[List[int]] = None, - context: Optional[Context] = None, + type: Type | None = None, + shape: list[int] | None = None, + context: Context | None = None, ) -> DenseIntElementsAttr: ... @staticmethod def isinstance(other: Attribute) -> bool: ... @@ -1331,9 +1318,9 @@ class DenseResourceElementsAttr(Attribute): array: Buffer, name: str, type: Type, - alignment: Optional[int] = None, + alignment: int | None = None, is_mutable: bool = False, - context: Optional[Context] = None, + context: Context | None = None, ) -> DenseResourceElementsAttr: """ Gets a DenseResourceElementsAttr from a Python buffer or array. @@ -1376,7 +1363,7 @@ class Diagnostic: @property def message(self) -> str: ... @property - def notes(self) -> Tuple[Diagnostic]: ... + def notes(self) -> tuple[Diagnostic]: ... @property def severity(self) -> DiagnosticSeverity: ... @@ -1396,7 +1383,7 @@ class DiagnosticInfo: @property def message(self) -> str: ... @property - def notes(self) -> List[DiagnosticInfo]: ... + def notes(self) -> list[DiagnosticInfo]: ... @property def severity(self) -> DiagnosticSeverity: ... 
@@ -1418,7 +1405,7 @@ class DiagnosticSeverity: REMARK: ClassVar[DiagnosticSeverity] # value = WARNING: ClassVar[DiagnosticSeverity] # value = __members__: ClassVar[ - Dict[str, DiagnosticSeverity] + dict[str, DiagnosticSeverity] ] # value = {'ERROR': , 'WARNING': , 'NOTE': , 'REMARK': } def __eq__(self, other: Any) -> bool: ... def __getstate__(self) -> int: ... @@ -1455,7 +1442,7 @@ class Dialects: class DictAttr(Attribute): static_typeid: ClassVar[TypeID] @staticmethod - def get(value: Dict = {}, context: Optional[Context] = None) -> DictAttr: + def get(value: dict = {}, context: Context | None = None) -> DictAttr: """ Gets an uniqued Dict attribute """ @@ -1486,7 +1473,7 @@ class FloatType(Type): class F16Type(FloatType): static_typeid: ClassVar[TypeID] @staticmethod - def get(context: Optional[Context] = None) -> F16Type: + def get(context: Context | None = None) -> F16Type: """ Create a f16 type. """ @@ -1499,7 +1486,7 @@ class F16Type(FloatType): class F32Type(FloatType): static_typeid: ClassVar[TypeID] @staticmethod - def get(context: Optional[Context] = None) -> F32Type: + def get(context: Context | None = None) -> F32Type: """ Create a f32 type. """ @@ -1512,7 +1499,7 @@ class F32Type(FloatType): class F64Type(FloatType): static_typeid: ClassVar[TypeID] @staticmethod - def get(context: Optional[Context] = None) -> F64Type: + def get(context: Context | None = None) -> F64Type: """ Create a f64 type. 
""" @@ -1524,7 +1511,7 @@ class F64Type(FloatType): class FlatSymbolRefAttr(Attribute): @staticmethod - def get(value: str, context: Optional[Context] = None) -> FlatSymbolRefAttr: + def get(value: str, context: Context | None = None) -> FlatSymbolRefAttr: """ Gets a uniqued FlatSymbolRef attribute """ @@ -1546,7 +1533,7 @@ class FlatSymbolRefAttr(Attribute): class Float4E2M1FNType(FloatType): static_typeid: ClassVar[TypeID] @staticmethod - def get(context: Optional[Context] = None) -> Float4E2M1FNType: + def get(context: Context | None = None) -> Float4E2M1FNType: """ Create a float4_e2m1fn type. """ @@ -1559,7 +1546,7 @@ class Float4E2M1FNType(FloatType): class Float6E2M3FNType(FloatType): static_typeid: ClassVar[TypeID] @staticmethod - def get(context: Optional[Context] = None) -> Float6E2M3FNType: + def get(context: Context | None = None) -> Float6E2M3FNType: """ Create a float6_e2m3fn type. """ @@ -1572,7 +1559,7 @@ class Float6E2M3FNType(FloatType): class Float6E3M2FNType(FloatType): static_typeid: ClassVar[TypeID] @staticmethod - def get(context: Optional[Context] = None) -> Float6E3M2FNType: + def get(context: Context | None = None) -> Float6E3M2FNType: """ Create a float6_e3m2fn type. """ @@ -1585,7 +1572,7 @@ class Float6E3M2FNType(FloatType): class Float8E3M4Type(FloatType): static_typeid: ClassVar[TypeID] @staticmethod - def get(context: Optional[Context] = None) -> Float8E3M4Type: + def get(context: Context | None = None) -> Float8E3M4Type: """ Create a float8_e3m4 type. """ @@ -1598,7 +1585,7 @@ class Float8E3M4Type(FloatType): class Float8E4M3B11FNUZType(FloatType): static_typeid: ClassVar[TypeID] @staticmethod - def get(context: Optional[Context] = None) -> Float8E4M3B11FNUZType: + def get(context: Context | None = None) -> Float8E4M3B11FNUZType: """ Create a float8_e4m3b11fnuz type. 
""" @@ -1611,7 +1598,7 @@ class Float8E4M3B11FNUZType(FloatType): class Float8E4M3FNType(FloatType): static_typeid: ClassVar[TypeID] @staticmethod - def get(context: Optional[Context] = None) -> Float8E4M3FNType: + def get(context: Context | None = None) -> Float8E4M3FNType: """ Create a float8_e4m3fn type. """ @@ -1624,7 +1611,7 @@ class Float8E4M3FNType(FloatType): class Float8E4M3FNUZType(FloatType): static_typeid: ClassVar[TypeID] @staticmethod - def get(context: Optional[Context] = None) -> Float8E4M3FNUZType: + def get(context: Context | None = None) -> Float8E4M3FNUZType: """ Create a float8_e4m3fnuz type. """ @@ -1637,7 +1624,7 @@ class Float8E4M3FNUZType(FloatType): class Float8E4M3Type(FloatType): static_typeid: ClassVar[TypeID] @staticmethod - def get(context: Optional[Context] = None) -> Float8E4M3Type: + def get(context: Context | None = None) -> Float8E4M3Type: """ Create a float8_e4m3 type. """ @@ -1650,7 +1637,7 @@ class Float8E4M3Type(FloatType): class Float8E5M2FNUZType(FloatType): static_typeid: ClassVar[TypeID] @staticmethod - def get(context: Optional[Context] = None) -> Float8E5M2FNUZType: + def get(context: Context | None = None) -> Float8E5M2FNUZType: """ Create a float8_e5m2fnuz type. """ @@ -1663,7 +1650,7 @@ class Float8E5M2FNUZType(FloatType): class Float8E5M2Type(FloatType): static_typeid: ClassVar[TypeID] @staticmethod - def get(context: Optional[Context] = None) -> Float8E5M2Type: + def get(context: Context | None = None) -> Float8E5M2Type: """ Create a float8_e5m2 type. 
""" @@ -1676,17 +1663,17 @@ class Float8E5M2Type(FloatType): class FloatAttr(Attribute): static_typeid: ClassVar[TypeID] @staticmethod - def get(type: Type, value: float, loc: Optional[Location] = None) -> FloatAttr: + def get(type: Type, value: float, loc: Location | None = None) -> FloatAttr: """ Gets an uniqued float point attribute associated to a type """ @staticmethod - def get_f32(value: float, context: Optional[Context] = None) -> FloatAttr: + def get_f32(value: float, context: Context | None = None) -> FloatAttr: """ Gets an uniqued float point attribute associated to a f32 type """ @staticmethod - def get_f64(value: float, context: Optional[Context] = None) -> FloatAttr: + def get_f64(value: float, context: Context | None = None) -> FloatAttr: """ Gets an uniqued float point attribute associated to a f64 type """ @@ -1710,7 +1697,7 @@ class FloatAttr(Attribute): class FloatTF32Type(FloatType): static_typeid: ClassVar[TypeID] @staticmethod - def get(context: Optional[Context] = None) -> FloatTF32Type: + def get(context: Context | None = None) -> FloatTF32Type: """ Create a tf32 type. """ @@ -1724,7 +1711,7 @@ class FunctionType(Type): static_typeid: ClassVar[TypeID] @staticmethod def get( - inputs: List[Type], results: List[Type], context: Optional[Context] = None + inputs: list[Type], results: list[Type], context: Context | None = None ) -> FunctionType: """ Gets a FunctionType from a List of input and result types @@ -1733,12 +1720,12 @@ class FunctionType(Type): def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... @property - def inputs(self) -> List: + def inputs(self) -> list: """ Returns the List of input types in the FunctionType. """ @property - def results(self) -> List: + def results(self) -> list: """ Returns the List of result types in the FunctionType. 
""" @@ -1748,7 +1735,7 @@ class FunctionType(Type): class IndexType(Type): static_typeid: ClassVar[TypeID] @staticmethod - def get(context: Optional[Context] = None) -> IndexType: + def get(context: Context | None = None) -> IndexType: """ Create a index type. """ @@ -1759,7 +1746,7 @@ class IndexType(Type): def typeid(self) -> TypeID: ... class InferShapedTypeOpInterface: - def __init__(self, object: object, context: Optional[Context] = None) -> None: + def __init__(self, object: object, context: Context | None = None) -> None: """ Creates an interface from a given operation/opview object or from a subclass of OpView. Raises ValueError if the operation does not implement the @@ -1767,13 +1754,13 @@ class InferShapedTypeOpInterface: """ def inferReturnTypeComponents( self, - operands: Optional[List] = None, - attributes: Optional[Attribute] = None, + operands: list | None = None, + attributes: Attribute | None = None, properties=None, - regions: Optional[List[Region]] = None, - context: Optional[Context] = None, - loc: Optional[Location] = None, - ) -> List[ShapedTypeComponents]: + regions: list[Region] | None = None, + context: Context | None = None, + loc: Location | None = None, + ) -> list[ShapedTypeComponents]: """ Given the arguments required to build an operation, attempts to infer its return shaped type components. Raises ValueError on failure. @@ -1791,7 +1778,7 @@ class InferShapedTypeOpInterface: """ class InferTypeOpInterface: - def __init__(self, object: object, context: Optional[Context] = None) -> None: + def __init__(self, object: object, context: Context | None = None) -> None: """ Creates an interface from a given operation/opview object or from a subclass of OpView. 
Raises ValueError if the operation does not implement the @@ -1799,13 +1786,13 @@ class InferTypeOpInterface: """ def inferReturnTypes( self, - operands: Optional[List] = None, - attributes: Optional[Attribute] = None, + operands: list | None = None, + attributes: Attribute | None = None, properties=None, - regions: Optional[List[Region]] = None, - context: Optional[Context] = None, - loc: Optional[Location] = None, - ) -> List[Type]: + regions: list[Region] | None = None, + context: Context | None = None, + loc: Location | None = None, + ) -> list[Type]: """ Given the arguments required to build an operation, attempts to infer its return types. Raises ValueError on failure. @@ -1856,7 +1843,7 @@ class InsertionPoint: Returns the block that this InsertionPoint points to. """ @property - def ref_operation(self) -> Optional[_OperationBase]: + def ref_operation(self) -> _OperationBase | None: """ The reference operation before which new operations are inserted, or None if the insertion point is at the end of the block """ @@ -1890,13 +1877,13 @@ class IntegerSet: def get( num_dims: int, num_symbols: int, - exprs: List, - eq_flags: List[bool], - context: Optional[Context] = None, + exprs: list, + eq_flags: list[bool], + context: Context | None = None, ) -> IntegerSet: ... @staticmethod def get_empty( - num_dims: int, num_symbols: int, context: Optional[Context] = None + num_dims: int, num_symbols: int, context: Context | None = None ) -> IntegerSet: ... def _CAPICreate(self) -> IntegerSet: ... @overload @@ -1910,8 +1897,8 @@ class IntegerSet: """ def get_replaced( self, - dim_exprs: List, - symbol_exprs: List, + dim_exprs: list, + symbol_exprs: list, num_result_dims: int, num_result_symbols: int, ) -> IntegerSet: ... @@ -1958,7 +1945,7 @@ class IntegerSetConstraint: class IntegerSetConstraintList: def __init__(self, *args, **kwargs) -> None: ... - def __add__(self, arg0: IntegerSetConstraintList) -> List[IntegerSetConstraint]: ... 
+ def __add__(self, arg0: IntegerSetConstraintList) -> list[IntegerSetConstraint]: ... @overload def __getitem__(self, arg0: int) -> IntegerSetConstraint: ... @overload @@ -1968,17 +1955,17 @@ class IntegerSetConstraintList: class IntegerType(Type): static_typeid: ClassVar[TypeID] @staticmethod - def get_signed(width: int, context: Optional[Context] = None) -> IntegerType: + def get_signed(width: int, context: Context | None = None) -> IntegerType: """ Create a signed integer type """ @staticmethod - def get_signless(width: int, context: Optional[Context] = None) -> IntegerType: + def get_signless(width: int, context: Context | None = None) -> IntegerType: """ Create a signless integer type """ @staticmethod - def get_unsigned(width: int, context: Optional[Context] = None) -> IntegerType: + def get_unsigned(width: int, context: Context | None = None) -> IntegerType: """ Create an unsigned integer type """ @@ -2013,28 +2000,28 @@ class Location: __hash__: ClassVar[None] = None @staticmethod def callsite( - callee: Location, frames: Sequence[Location], context: Optional[Context] = None + callee: Location, frames: Sequence[Location], context: Context | None = None ) -> Location: """ Gets a Location representing a caller and callsite """ @staticmethod def file( - filename: str, line: int, col: int, context: Optional[Context] = None + filename: str, line: int, col: int, context: Context | None = None ) -> Location: """ Gets a Location representing a file, line and column """ @staticmethod - def from_attr(attribute: Attribute, context: Optional[Context] = None) -> Location: + def from_attr(attribute: Attribute, context: Context | None = None) -> Location: """ Gets a Location from a LocationAttr """ @staticmethod def fused( locations: Sequence[Location], - metadata: Optional[Attribute] = None, - context: Optional[Context] = None, + metadata: Attribute | None = None, + context: Context | None = None, ) -> Location: """ Gets a Location representing a fused location with 
optional metadata @@ -2042,14 +2029,14 @@ class Location: @staticmethod def name( name: str, - childLoc: Optional[Location] = None, - context: Optional[Context] = None, + childLoc: Location | None = None, + context: Context | None = None, ) -> Location: """ Gets a Location representing a named location with optional child location """ @staticmethod - def unknown(context: Optional[Context] = None) -> Location: + def unknown(context: Context | None = None) -> Location: """ Gets a Location representing an unknown location """ @@ -2081,11 +2068,11 @@ class MemRefType(ShapedType): static_typeid: ClassVar[TypeID] @staticmethod def get( - shape: List[int], + shape: list[int], element_type: Type, layout: Attribute = None, memory_space: Attribute = None, - loc: Optional[Location] = None, + loc: Location | None = None, ) -> MemRefType: """ Create a memref type @@ -2104,21 +2091,25 @@ class MemRefType(ShapedType): The layout of the MemRef type. """ @property - def memory_space(self) -> Optional[Attribute]: + def memory_space(self) -> Attribute | None: """ Returns the memory space of the given MemRef type. """ @property def typeid(self) -> TypeID: ... + def get_strides_and_offset(self) -> tuple[list[int], list[int]]: + """ + The strides and offset of the MemRef type. + """ class Module: @staticmethod - def create(loc: Optional[Location] = None) -> Module: + def create(loc: Location | None = None) -> Module: """ Creates an empty module """ @staticmethod - def parse(asm: str | bytes, context: Optional[Context] = None) -> Module: + def parse(asm: str | bytes, context: Context | None = None) -> Module: """ Parses a module's assembly format from a string. @@ -2159,7 +2150,7 @@ class Module: class MLIRError(Exception): def __init__( - self, message: str, error_diagnostics: List[DiagnosticInfo] + self, message: str, error_diagnostics: list[DiagnosticInfo] ) -> None: ... 
class NamedAttribute: @@ -2177,7 +2168,7 @@ class NamedAttribute: class NoneType(Type): static_typeid: ClassVar[TypeID] @staticmethod - def get(context: Optional[Context] = None) -> NoneType: + def get(context: Context | None = None) -> NoneType: """ Create a none type. """ @@ -2208,7 +2199,7 @@ class OpOperandIterator: def __next__(self) -> OpOperand: ... class OpOperandList: - def __add__(self, arg0: OpOperandList) -> List[Value]: ... + def __add__(self, arg0: OpOperandList) -> list[Value]: ... @overload def __getitem__(self, arg0: int) -> Value: ... @overload @@ -2228,7 +2219,7 @@ class OpResult(Value): def result_number(self) -> int: ... class OpResultList: - def __add__(self, arg0: OpResultList) -> List[OpResult]: ... + def __add__(self, arg0: OpResultList) -> list[OpResult]: ... @overload def __getitem__(self, arg0: int) -> OpResult: ... @overload @@ -2237,10 +2228,10 @@ class OpResultList: @property def owner(self) -> _OperationBase: ... @property - def types(self) -> List[Type]: ... + def types(self) -> list[Type]: ... class OpSuccessors: - def __add__(self, arg0: OpSuccessors) -> List[Block]: ... + def __add__(self, arg0: OpSuccessors) -> list[Block]: ... @overload def __getitem__(self, arg0: int) -> Block: ... @overload @@ -2255,25 +2246,25 @@ class OpView(_OperationBase): def __init__(self, operation: _OperationBase) -> None: ... 
@classmethod def build_generic( - cls: _Type[_TOperation], - results: Optional[Sequence[Type]] = None, - operands: Optional[Sequence[Value]] = None, - attributes: Optional[Dict[str, Attribute]] = None, - successors: Optional[Sequence[Block]] = None, - regions: Optional[int] = None, - loc: Optional[Location] = None, - ip: Optional[InsertionPoint] = None, + cls: type[_TOperation], + results: Sequence[Type] | None = None, + operands: Sequence[Value] | None = None, + attributes: dict[str, Attribute] | None = None, + successors: Sequence[Block] | None = None, + regions: int | None = None, + loc: Location | None = None, + ip: InsertionPoint | None = None, ) -> _TOperation: """ Builds a specific, generated OpView based on class level attributes. """ @classmethod def parse( - cls: _Type[_TOperation], + cls: type[_TOperation], source: str | bytes, *, source_name: str = "", - context: Optional[Context] = None, + context: Context | None = None, ) -> _TOperation: """ Parses a specific, generated OpView based on class level attributes @@ -2296,7 +2287,7 @@ class OpaqueAttr(Attribute): dialect_namespace: str, buffer: Buffer, type: Type, - context: Optional[Context] = None, + context: Context | None = None, ) -> OpaqueAttr: """ Gets an Opaque attribute. @@ -2323,7 +2314,7 @@ class OpaqueType(Type): static_typeid: ClassVar[TypeID] @staticmethod def get( - dialect_namespace: str, buffer: str, context: Optional[Context] = None + dialect_namespace: str, buffer: str, context: Context | None = None ) -> OpaqueType: """ Create an unregistered (opaque) dialect type. 
@@ -2349,13 +2340,13 @@ class Operation(_OperationBase): @staticmethod def create( name: str, - results: Optional[Sequence[Type]] = None, - operands: Optional[Sequence[Value]] = None, - attributes: Optional[Dict[str, Attribute]] = None, - successors: Optional[Sequence[Block]] = None, + results: Sequence[Type] | None = None, + operands: Sequence[Value] | None = None, + attributes: dict[str, Attribute] | None = None, + successors: Sequence[Block] | None = None, regions: int = 0, - loc: Optional[Location] = None, - ip: Optional[InsertionPoint] = None, + loc: Location | None = None, + ip: InsertionPoint | None = None, infer_type: bool = False, ) -> Operation: """ @@ -2378,7 +2369,7 @@ class Operation(_OperationBase): """ @staticmethod def parse( - source: str | bytes, *, source_name: str = "", context: Optional[Context] = None + source: str | bytes, *, source_name: str = "", context: Context | None = None ) -> Operation: """ Parses an operation. Supports both text assembly format and binary bytecode format. @@ -2409,10 +2400,10 @@ class RankedTensorType(ShapedType): static_typeid: ClassVar[TypeID] @staticmethod def get( - shape: List[int], + shape: list[int], element_type: Type, - encoding: Optional[Attribute] = None, - loc: Optional[Location] = None, + encoding: Attribute | None = None, + loc: Location | None = None, ) -> RankedTensorType: """ Create a ranked tensor type @@ -2421,7 +2412,7 @@ class RankedTensorType(ShapedType): def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... @property - def encoding(self) -> Optional[Attribute]: ... + def encoding(self) -> Attribute | None: ... @property def typeid(self) -> TypeID: ... @@ -2507,7 +2498,7 @@ class ShapedType(Type): Returns the rank of the given ranked shaped type. """ @property - def shape(self) -> List[int]: + def shape(self) -> list[int]: """ Returns the shape of the ranked shaped type as a List of integers. 
""" @@ -2525,14 +2516,14 @@ class ShapedTypeComponents: """ @staticmethod @overload - def get(shape: List, element_type: Type) -> ShapedTypeComponents: + def get(shape: list, element_type: Type) -> ShapedTypeComponents: """ Create a ranked shaped type components object. """ @staticmethod @overload def get( - shape: List, element_type: Type, attribute: Attribute + shape: list, element_type: Type, attribute: Attribute ) -> ShapedTypeComponents: """ Create a ranked shaped type components object with attribute. @@ -2553,7 +2544,7 @@ class ShapedTypeComponents: Returns the rank of the given ranked shaped type components. If the shaped type components does not have a rank, None is returned. """ @property - def shape(self) -> List[int]: + def shape(self) -> list[int]: """ Returns the shape of the ranked shaped type components as a List of integers. Returns none if the shaped type component does not have a rank. """ @@ -2562,14 +2553,14 @@ class StridedLayoutAttr(Attribute): static_typeid: ClassVar[TypeID] @staticmethod def get( - offset: int, strides: List[int], context: Optional[Context] = None + offset: int, strides: list[int], context: Context | None = None ) -> StridedLayoutAttr: """ Gets a strided layout attribute. """ @staticmethod def get_fully_dynamic( - rank: int, context: Optional[Context] = None + rank: int, context: Context | None = None ) -> StridedLayoutAttr: """ Gets a strided layout attribute with dynamic offset and strides of a given rank. 
@@ -2583,7 +2574,7 @@ class StridedLayoutAttr(Attribute): Returns the value of the float point attribute """ @property - def strides(self) -> List[int]: + def strides(self) -> list[int]: """ Returns the value of the float point attribute """ @@ -2595,7 +2586,7 @@ class StridedLayoutAttr(Attribute): class StringAttr(Attribute): static_typeid: ClassVar[TypeID] @staticmethod - def get(value: str | bytes, context: Optional[Context] = None) -> StringAttr: + def get(value: str | bytes, context: Context | None = None) -> StringAttr: """ Gets a uniqued string attribute """ @@ -2624,7 +2615,7 @@ class StringAttr(Attribute): class SymbolRefAttr(Attribute): @staticmethod - def get(symbols: List[str], context: Optional[Context] = None) -> Attribute: + def get(symbols: list[str], context: Context | None = None) -> Attribute: """ Gets a uniqued SymbolRef attribute from a List of symbol names """ @@ -2638,7 +2629,7 @@ class SymbolRefAttr(Attribute): @property def typeid(self) -> TypeID: ... @property - def value(self) -> List[str]: + def value(self) -> list[str]: """ Returns the value of the SymbolRef attribute as a List[str] """ @@ -2672,7 +2663,7 @@ class SymbolTable: class TupleType(Type): static_typeid: ClassVar[TypeID] @staticmethod - def get_tuple(elements: List[Type], context: Optional[Context] = None) -> TupleType: + def get_tuple(elements: list[Type], context: Context | None = None) -> TupleType: """ Create a Tuple type """ @@ -2694,7 +2685,7 @@ class TupleType(Type): class TypeAttr(Attribute): static_typeid: ClassVar[TypeID] @staticmethod - def get(value: Type, context: Optional[Context] = None) -> TypeAttr: + def get(value: Type, context: Context | None = None) -> TypeAttr: """ Gets a uniqued Type attribute """ @@ -2721,7 +2712,7 @@ class TypeID: class UnitAttr(Attribute): static_typeid: ClassVar[TypeID] @staticmethod - def get(context: Optional[Context] = None) -> UnitAttr: + def get(context: Context | None = None) -> UnitAttr: """ Create a Unit attribute. 
""" @@ -2737,7 +2728,7 @@ class UnrankedMemRefType(ShapedType): static_typeid: ClassVar[TypeID] @staticmethod def get( - element_type: Type, memory_space: Attribute, loc: Optional[Location] = None + element_type: Type, memory_space: Attribute, loc: Location | None = None ) -> UnrankedMemRefType: """ Create a unranked memref type @@ -2746,7 +2737,7 @@ class UnrankedMemRefType(ShapedType): def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... @property - def memory_space(self) -> Optional[Attribute]: + def memory_space(self) -> Attribute | None: """ Returns the memory space of the given Unranked MemRef type. """ @@ -2756,7 +2747,7 @@ class UnrankedMemRefType(ShapedType): class UnrankedTensorType(ShapedType): static_typeid: ClassVar[TypeID] @staticmethod - def get(element_type: Type, loc: Optional[Location] = None) -> UnrankedTensorType: + def get(element_type: Type, loc: Location | None = None) -> UnrankedTensorType: """ Create a unranked tensor type """ @@ -2770,12 +2761,12 @@ class VectorType(ShapedType): static_typeid: ClassVar[TypeID] @staticmethod def get( - shape: List[int], + shape: list[int], element_type: Type, *, - scalable: Optional[List] = None, - scalable_dims: Optional[List[int]] = None, - loc: Optional[Location] = None, + scalable: list | None = None, + scalable_dims: list[int] | None = None, + loc: Location | None = None, ) -> VectorType: """ Create a vector type @@ -2786,7 +2777,7 @@ class VectorType(ShapedType): @property def scalable(self) -> bool: ... @property - def scalable_dims(self) -> List[bool]: ... + def scalable_dims(self) -> list[bool]: ... @property def typeid(self) -> TypeID: ... 
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi index 5d115e822..229979ae3 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi @@ -4,7 +4,6 @@ # * Relative imports for cross-module references. # * Add __all__ -from typing import Any, Optional from . import ir as _ir @@ -13,7 +12,7 @@ __all__ = [ ] class PassManager: - def __init__(self, context: Optional[_ir.Context] = None) -> None: ... + def __init__(self, context: _ir.Context | None = None) -> None: ... def _CAPICreate(self) -> object: ... def _testing_release(self) -> None: ... def enable_ir_printing( @@ -26,7 +25,7 @@ class PassManager: ) -> None: ... def enable_verifier(self, enable: bool) -> None: ... @staticmethod - def parse(pipeline: str, context: Optional[_ir.Context] = None) -> PassManager: ... + def parse(pipeline: str, context: _ir.Context | None = None) -> PassManager: ... def run(self, module: _ir._OperationBase) -> None: ... @property def _CAPIPtr(self) -> object: ... diff --git a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi index 893dab8a4..58d453d2b 100644 --- a/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlirExecutionEngine.pyi @@ -4,7 +4,7 @@ # * Relative imports for cross-module references. 
# * Add __all__ -from typing import List, Sequence +from collections.abc import Sequence from ._mlir import ir as _ir From 37a5fdb6ed3a2a854bce6950188597f9e125ea6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sok=C3=B3=C5=82?= <8431159+mtsokol@users.noreply.github.com> Date: Tue, 1 Oct 2024 22:48:00 +0200 Subject: [PATCH 774/915] [MLIR][Python] Add `encoding` argument to `tensor.empty` Python function (#110656) Hi @xurui1995 @makslevental, I think in https://github.com/llvm/llvm-project/pull/103087 there's unintended regression where user can no longer create sparse tensors with `tensor.empty`. Previously I could pass: ```python out = tensor.empty(tensor_type, []) ``` where `tensor_type` contained `shape`, `dtype`, and `encoding`. With the latest ```python tensor.empty(sizes: Sequence[Union[int, Value]], element_type: Type, *, loc=None, ip=None) ``` it's no longer possible. I propose to add `encoding` argument which is passed to `RankedTensorType.get(static_sizes, element_type, encoding)` (I updated one of the tests to check it). --- mlir/python/mlir/dialects/tensor.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/tensor.py b/mlir/python/mlir/dialects/tensor.py index 0b30d1020..146b5f85d 100644 --- a/mlir/python/mlir/dialects/tensor.py +++ b/mlir/python/mlir/dialects/tensor.py @@ -1,6 +1,7 @@ # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. 
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from typing import Optional from ._tensor_ops_gen import * from ._tensor_ops_gen import _Dialect @@ -25,6 +26,7 @@ def __init__( sizes: Sequence[Union[int, Value]], element_type: Type, *, + encoding: Optional[Attribute] = None, loc=None, ip=None, ): @@ -40,7 +42,7 @@ def __init__( else: static_sizes.append(ShapedType.get_dynamic_size()) dynamic_sizes.append(s) - result_type = RankedTensorType.get(static_sizes, element_type) + result_type = RankedTensorType.get(static_sizes, element_type, encoding) super().__init__(result_type, dynamic_sizes, loc=loc, ip=ip) @@ -48,11 +50,14 @@ def empty( sizes: Sequence[Union[int, Value]], element_type: Type, *, + encoding: Optional[Attribute] = None, loc=None, ip=None, ) -> _ods_cext.ir.Value: return _get_op_result_or_op_results( - EmptyOp(sizes=sizes, element_type=element_type, loc=loc, ip=ip) + EmptyOp( + sizes=sizes, element_type=element_type, encoding=encoding, loc=loc, ip=ip + ) ) From 40dfa79483a56e584272a42f1bedc9b14ef779c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sok=C3=B3=C5=82?= <8431159+mtsokol@users.noreply.github.com> Date: Wed, 2 Oct 2024 18:07:55 +0200 Subject: [PATCH 775/915] [MLIR][sparse] Add `soa` property to `sparse_tensor` Python bindings (#109135) --- mlir/include/mlir-c/Dialect/SparseTensor.h | 1 + mlir/lib/Bindings/Python/DialectSparseTensor.cpp | 3 ++- mlir/lib/CAPI/Dialect/SparseTensor.cpp | 4 +++- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/SparseTensor.h b/mlir/include/mlir-c/Dialect/SparseTensor.h index 125469f57..c816c1b58 100644 --- a/mlir/include/mlir-c/Dialect/SparseTensor.h +++ b/mlir/include/mlir-c/Dialect/SparseTensor.h @@ -39,6 +39,7 @@ enum MlirSparseTensorLevelFormat { enum MlirSparseTensorLevelPropertyNondefault { MLIR_SPARSE_PROPERTY_NON_UNIQUE = 0x0001, MLIR_SPARSE_PROPERTY_NON_ORDERED = 0x0002, + MLIR_SPARSE_PROPERTY_SOA = 0x0004, }; 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index 584981cfe..a730bf500 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -33,7 +33,8 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { py::enum_(m, "LevelProperty", py::module_local()) .value("non_ordered", MLIR_SPARSE_PROPERTY_NON_ORDERED) - .value("non_unique", MLIR_SPARSE_PROPERTY_NON_UNIQUE); + .value("non_unique", MLIR_SPARSE_PROPERTY_NON_UNIQUE) + .value("soa", MLIR_SPARSE_PROPERTY_SOA); mlir_attribute_subclass(m, "EncodingAttr", mlirAttributeIsASparseTensorEncodingAttr) diff --git a/mlir/lib/CAPI/Dialect/SparseTensor.cpp b/mlir/lib/CAPI/Dialect/SparseTensor.cpp index f2a0ab33c..cf25b5263 100644 --- a/mlir/lib/CAPI/Dialect/SparseTensor.cpp +++ b/mlir/lib/CAPI/Dialect/SparseTensor.cpp @@ -36,7 +36,9 @@ static_assert( static_assert(static_cast(MLIR_SPARSE_PROPERTY_NON_ORDERED) == static_cast(LevelPropNonDefault::Nonordered) && static_cast(MLIR_SPARSE_PROPERTY_NON_UNIQUE) == - static_cast(LevelPropNonDefault::Nonunique), + static_cast(LevelPropNonDefault::Nonunique) && + static_cast(MLIR_SPARSE_PROPERTY_SOA) == + static_cast(LevelPropNonDefault::SoA), "MlirSparseTensorLevelProperty (C-API) and " "LevelPropertyNondefault (C++) mismatch"); From 94445714cb48308bdbf56d3017d62a253b826697 Mon Sep 17 00:00:00 2001 From: Sergey Kozub Date: Fri, 4 Oct 2024 09:23:12 +0200 Subject: [PATCH 776/915] [MLIR] Add f8E8M0FNU type (#111028) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds `f8E8M0FNU` type to MLIR. `f8E8M0FNU` type is proposed in [OpenCompute MX Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). It defines an 8-bit floating point number with bit layout S0E8M0. 
Unlike IEEE-754 types, there are no infinity, denormals, zeros or negative values. ```c f8E8M0FNU - Exponent bias: 127 - Maximum stored exponent value: 254 (binary 1111'1110) - Maximum unbiased exponent value: 254 - 127 = 127 - Minimum stored exponent value: 0 (binary 0000'0000) - Minimum unbiased exponent value: 0 − 127 = -127 - Doesn't have zero - Doesn't have infinity - NaN is encoded as binary 1111'1111 Additional details: - Zeros cannot be represented - Negative values cannot be represented - Mantissa is always 1 ``` Related PRs: - [PR-107127](https://github.com/llvm/llvm-project/pull/107127) [APFloat] Add APFloat support for E8M0 type - [PR-105573](https://github.com/llvm/llvm-project/pull/105573) [MLIR] Add f6E3M2FN type - was used as a template for this PR - [PR-107999](https://github.com/llvm/llvm-project/pull/107999) [MLIR] Add f6E2M3FN type - [PR-108877](https://github.com/llvm/llvm-project/pull/108877) [MLIR] Add f4E2M1FN type --- mlir/include/mlir-c/BuiltinTypes.h | 10 ++++++++++ mlir/lib/Bindings/Python/IRTypes.cpp | 22 ++++++++++++++++++++++ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 12 ++++++++++++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 14 ++++++++++++++ mlir/python/mlir/extras/types.py | 2 ++ 5 files changed, 60 insertions(+) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index 6dc25a56b..6875fab7b 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -179,6 +179,16 @@ MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E3M4(MlirType type); /// context. MLIR_CAPI_EXPORTED MlirType mlirFloat8E3M4TypeGet(MlirContext ctx); +/// Returns the typeID of an Float8E8M0FNU type. +MLIR_CAPI_EXPORTED MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID(void); + +/// Checks whether the given type is an f8E8M0FNU type. +MLIR_CAPI_EXPORTED bool mlirTypeIsAFloat8E8M0FNU(MlirType type); + +/// Creates an f8E8M0FNU type in the given context. The type is owned by the +/// context. 
+MLIR_CAPI_EXPORTED MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx); + /// Returns the typeID of an BFloat16 type. MLIR_CAPI_EXPORTED MlirTypeID mlirBFloat16TypeGetTypeID(void); diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 5a369b5d4..6f192bc4b 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -331,6 +331,27 @@ class PyFloat8E3M4Type : public PyConcreteType { } }; +/// Floating Point Type subclass - Float8E8M0FNUType. +class PyFloat8E8M0FNUType + : public PyConcreteType { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E8M0FNUTypeGetTypeID; + static constexpr const char *pyClassName = "Float8E8M0FNUType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E8M0FNUTypeGet(context->get()); + return PyFloat8E8M0FNUType(context->getRef(), t); + }, + py::arg("context") = py::none(), "Create a float8_e8m0fnu type."); + } +}; + /// Floating Point Type subclass - BF16Type. 
class PyBF16Type : public PyConcreteType { public: @@ -953,6 +974,7 @@ void mlir::python::populateIRTypes(py::module &m) { PyFloat8E4M3B11FNUZType::bind(m); PyFloat8E5M2FNUZType::bind(m); PyFloat8E3M4Type::bind(m); + PyFloat8E8M0FNUType::bind(m); PyBF16Type::bind(m); PyF16Type::bind(m); PyTF32Type::bind(m); diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index efc1e857a..252ff54af 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -205,6 +205,18 @@ MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) { return wrap(FloatType::getFloat8E3M4(unwrap(ctx))); } +MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID() { + return wrap(Float8E8M0FNUType::getTypeID()); +} + +bool mlirTypeIsAFloat8E8M0FNU(MlirType type) { + return unwrap(type).isFloat8E8M0FNU(); +} + +MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx) { + return wrap(FloatType::getFloat8E8M0FNU(unwrap(ctx))); +} + MlirTypeID mlirBFloat16TypeGetTypeID() { return wrap(BFloat16Type::getTypeID()); } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 41ed84e04..fb7efb8cd 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -117,6 +117,7 @@ __all__ = [ "Float8E4M3Type", "Float8E5M2FNUZType", "Float8E5M2Type", + "Float8E8M0FNUType", "FloatAttr", "FloatTF32Type", "FloatType", @@ -1660,6 +1661,19 @@ class Float8E5M2Type(FloatType): @property def typeid(self) -> TypeID: ... +class Float8E8M0FNUType(FloatType): + static_typeid: ClassVar[TypeID] + @staticmethod + def get(context: Context | None = None) -> Float8E8M0FNUType: + """ + Create a float8_e8m0fnu type. + """ + @staticmethod + def isinstance(other: Type) -> bool: ... + def __init__(self, cast_from_type: Type) -> None: ... + @property + def typeid(self) -> TypeID: ... 
+ class FloatAttr(Attribute): static_typeid: ClassVar[TypeID] @staticmethod diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py index 5b24a6d52..34eee1edb 100644 --- a/mlir/python/mlir/extras/types.py +++ b/mlir/python/mlir/extras/types.py @@ -20,6 +20,7 @@ Float8E4M3FNType, Float8E4M3Type, Float8E5M2Type, + Float8E8M0FNUType, FunctionType, IndexType, IntegerType, @@ -80,6 +81,7 @@ def ui(width): f4E2M1FN = lambda: Float4E2M1FNType.get() f6E2M3FN = lambda: Float6E2M3FNType.get() f6E3M2FN = lambda: Float6E3M2FNType.get() +f8E8M0FNU = lambda: Float8E8M0FNUType.get() none = lambda: NoneType.get() From 941d7790f5f123dbcdf2c2896b0f6e627bf0c557 Mon Sep 17 00:00:00 2001 From: Walter Erquinigo Date: Mon, 7 Oct 2024 17:51:08 -0400 Subject: [PATCH 777/915] [mlir][debuginfo] Add support for subprogram annotations (#110946) LLVM already supports `DW_TAG_LLVM_annotation` entries for subprograms, but this hasn't been surfaced to the LLVM dialect. I'm doing the minimal amount of work to support string-based annotations, which is useful for attaching metadata to functions, which is useful for debuggers to offer features beyond basic DWARF. As LLVM already supports this, this patch is not controversial. 
--- mlir/include/mlir-c/Dialect/LLVM.h | 7 ++++++- mlir/lib/CAPI/Dialect/LLVM.cpp | 16 +++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index d6062bed5..0e6434073 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -325,7 +325,12 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDISubprogramAttrGet( MlirAttribute compileUnit, MlirAttribute scope, MlirAttribute name, MlirAttribute linkageName, MlirAttribute file, unsigned int line, unsigned int scopeLine, uint64_t subprogramFlags, MlirAttribute type, - intptr_t nRetainedNodes, MlirAttribute const *retainedNodes); + intptr_t nRetainedNodes, MlirAttribute const *retainedNodes, + intptr_t nAnnotations, MlirAttribute const *annotations); + +/// Creates a LLVM DIAnnotation attribute. +MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMDIAnnotationAttrGet( + MlirContext ctx, MlirAttribute name, MlirAttribute value); /// Gets the scope from this DISubprogramAttr. 
MLIR_CAPI_EXPORTED MlirAttribute diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index 03b536d7a..c7082445d 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -303,9 +303,14 @@ MlirAttribute mlirLLVMDISubprogramAttrGet( MlirAttribute compileUnit, MlirAttribute scope, MlirAttribute name, MlirAttribute linkageName, MlirAttribute file, unsigned int line, unsigned int scopeLine, uint64_t subprogramFlags, MlirAttribute type, - intptr_t nRetainedNodes, MlirAttribute const *retainedNodes) { + intptr_t nRetainedNodes, MlirAttribute const *retainedNodes, + intptr_t nAnnotations, MlirAttribute const *annotations) { SmallVector nodesStorage; nodesStorage.reserve(nRetainedNodes); + + SmallVector annotationsStorage; + annotationsStorage.reserve(nAnnotations); + return wrap(DISubprogramAttr::get( unwrap(ctx), cast(unwrap(recId)), isRecSelf, cast(unwrap(id)), @@ -316,6 +321,9 @@ MlirAttribute mlirLLVMDISubprogramAttrGet( cast(unwrap(type)), llvm::map_to_vector( unwrapList(nRetainedNodes, retainedNodes, nodesStorage), + [](Attribute a) { return cast(a); }), + llvm::map_to_vector( + unwrapList(nAnnotations, annotations, annotationsStorage), [](Attribute a) { return cast(a); }))); } @@ -375,3 +383,9 @@ MlirAttribute mlirLLVMDIImportedEntityAttrGet( llvm::map_to_vector(unwrapList(nElements, elements, elementsStorage), [](Attribute a) { return cast(a); }))); } + +MlirAttribute mlirLLVMDIAnnotationAttrGet(MlirContext ctx, MlirAttribute name, + MlirAttribute value) { + return wrap(DIAnnotationAttr::get(unwrap(ctx), cast(unwrap(name)), + cast(unwrap(value)))); +} From 7f4712d3ff17cd26c8f87ad4bbac04da77705e53 Mon Sep 17 00:00:00 2001 From: Md Asghar Ahmad Shahid Date: Thu, 10 Oct 2024 21:30:58 +0530 Subject: [PATCH 778/915] [mlir][linalg] Introduce transpose semantic to 'linalg.matmul' ops. 
(#104783) The main goal of this patch is to extend the semantic of 'linalg.matmul' named op to include per operand transpose semantic while also laying out a way to move ops definition from OpDSL to tablegen. Hence, it is implemented in tablegen. Transpose semantic is as follows. By default 'linalg.matmul' behavior will remain as is. Transpose semantics can be appiled on per input operand by specifying the optional permutation attributes (namely 'permutationA' for 1st input and 'permutationB' for 2nd input) for each operand explicitly as needed. By default, no transpose is mandated for any of the input operand. Example: ``` %val = linalg.matmul ins(%arg0, %arg1 : memref<5x3xf32>, memref<5x7xf32>) outs(%arg2: memref<3x7xf32>) permutationA = [1, 0] permutationB = [0, 1] ``` --- .../dialects/linalg/opdsl/ops/core_named_ops.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index e4a6ec748..d5e79b4d3 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -383,23 +383,6 @@ def select( O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None]) -@linalg_structured_op -def matmul( - A=TensorDef(T1, S.M, S.K), - B=TensorDef(T2, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True), - cast=TypeFnAttrDef(default=TypeFn.cast_signed), -): - """Performs a matrix multiplication of two 2D inputs. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. 
- """ - domain(D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) - - @linalg_structured_op def quantized_matmul( A=TensorDef(T1, S.M, S.K), From 3935cd94e894d218a1541ab1d23d29ad3c07004e Mon Sep 17 00:00:00 2001 From: Emilio Cota Date: Fri, 11 Oct 2024 05:08:23 -0400 Subject: [PATCH 779/915] Revert "[mlir][linalg] Introduce transpose semantic to 'linalg.matmul' ops. (#104783)" This reverts commit 7f4712d3ff17cd26c8f87ad4bbac04da77705e53 and 99c8557, which is a fix-up on top of the former. I'm reverting because this commit broke two tests: mlir/test/python/integration/dialects/linalg/opsrun.py mlir/test/python/integration/dialects/transform.py See https://lab.llvm.org/buildbot/#/builders/138/builds/4872 I'm not familiar with the tests, so I'm leaving it to the original author to either remove or adapt the broken tests, as discussed here: https://github.com/llvm/llvm-project/pull/104783#issuecomment-2406390905 --- .../dialects/linalg/opdsl/ops/core_named_ops.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index d5e79b4d3..e4a6ec748 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -383,6 +383,23 @@ def select( O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None]) +@linalg_structured_op +def matmul( + A=TensorDef(T1, S.M, S.K), + B=TensorDef(T2, S.K, S.N), + C=TensorDef(U, S.M, S.N, output=True), + cast=TypeFnAttrDef(default=TypeFn.cast_signed), +): + """Performs a matrix multiplication of two 2D inputs. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. 
+ """ + domain(D.m, D.n, D.k) + implements(ContractionOpInterface) + C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) + + @linalg_structured_op def quantized_matmul( A=TensorDef(T1, S.M, S.K), From 4bc3cd82786ffedb332e0ed4eb76d854ceffad3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andrzej=20Warzy=C5=84ski?= Date: Tue, 15 Oct 2024 19:24:43 +0100 Subject: [PATCH 780/915] [mlir][td] Rename pack_paddings in structured.pad (#111036) The pack_paddings attribute in the structure.pad TD Op is used to set the `nofold` attribute in the generated tensor.pad Op. The current name is confusing and suggests that there's a relation with the tensor.pack Op. This patch renames it as `nofold_flags` to better match the actual usage. --- mlir/python/mlir/dialects/transform/structured.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py index 41051c0d5..f6111f516 100644 --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -377,7 +377,7 @@ def __init__( pad_to_multiple_of: Optional[Union[DynamicIndexList, ArrayAttr]] = None, padding_values: Optional[Union[ArrayAttr, Sequence[Attribute]]] = None, padding_dimensions: OptionalIntList = None, - pack_paddings: OptionalIntList = None, + nofold_flags: OptionalIntList = None, transpose_paddings: Optional[ Union[ArrayAttr, Sequence[Union[ArrayAttr, IntOrAttrList]]] ] = None, @@ -407,7 +407,7 @@ def __init__( padding_values=padding_values, padding_dimensions=padding_dimensions, static_pad_to_multiple_of=static_pad_to_multiple_of, - pack_paddings=pack_paddings, + nofold_flags=nofold_flags, transpose_paddings=transpose_paddings, copy_back_op=copy_back_op, loc=loc, From 017ceb2d858599027fc9403d749fad22b0b93cdf Mon Sep 17 00:00:00 2001 From: Frank Schlimbach Date: Fri, 18 Oct 2024 22:20:47 +0200 Subject: [PATCH 781/915] eliminating g++ warnings (#105520) 
Eliminating g++ warnings. Mostly declaring "[[maybe_unused]]", adding return statements where missing and fixing casts. @rengolin --------- Co-authored-by: Benjamin Maxwell Co-authored-by: Renato Golin --- mlir/lib/CAPI/IR/IR.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 5eb531b70..e7e6b11c8 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -736,6 +736,7 @@ static mlir::WalkResult unwrap(MlirWalkResult result) { case MlirWalkResultSkip: return mlir::WalkResult::skip(); } + llvm_unreachable("unknown result in WalkResult::unwrap"); } void mlirOperationWalk(MlirOperation op, MlirOperationWalkCallback callback, From dd7c7f3e86c1060e40edc3b115ec49bbfb4cf2fd Mon Sep 17 00:00:00 2001 From: Felix Schneider Date: Sat, 19 Oct 2024 18:25:27 +0200 Subject: [PATCH 782/915] [mlir][linalg] Add quantized conv2d operator with FCHW,NCHW order (#107740) This patch adds a quantized version of the `linalg.conv2d_nchw_fchw` Op. This is the "channel-first" ordering typically used by PyTorch and others. 
--- .../linalg/opdsl/ops/core_named_ops.py | 29 +++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index e4a6ec748..b45fecd0e 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -876,6 +876,35 @@ def conv_2d_nhwc_fhwc_q( ) * (TypeFn.cast_signed(U, K[D.f, D.kh, D.kw, D.c]) - TypeFn.cast_signed(U, KZp)) +@linalg_structured_op +def conv_2d_nchw_fchw_q( + I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), + K=TensorDef(T2, S.F, S.C, S.KH, S.KW), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.F, S.OH, S.OW, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D convolution with zero point offsets. + + Layout: + * Input: NCHW. + * Kernel: FCHW. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. This includes the zero + point offsets common to quantized operations. 
+ """ + implements(ConvolutionOpInterface) + domain(D.n, D.f, D.oh, D.ow, D.c, D.kh, D.kw) + O[D.n, D.f, D.oh, D.ow] += ( + TypeFn.cast_signed( + U, I[D.n, D.c, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW] + ) + - TypeFn.cast_signed(U, IZp) + ) * (TypeFn.cast_signed(U, K[D.f, D.c, D.kh, D.kw]) - TypeFn.cast_signed(U, KZp)) + @linalg_structured_op def conv_2d_nchw_fchw( I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW), From bd565ed12a7bf5eca61beaaccd0b85854e16d67b Mon Sep 17 00:00:00 2001 From: Marius Brehler Date: Thu, 31 Oct 2024 17:39:26 +0100 Subject: [PATCH 783/915] [mlir][python] Raise maximum allowed version (#114050) Raises the maximum allowed versions to more recent versions, which is a basic enabler to install them in a venv using Python 3.13. --- mlir/python/requirements.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index d1b5418cc..eeaac2746 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,4 +1,4 @@ -numpy>=1.19.5, <=1.26 -pybind11>=2.9.0, <=2.10.3 +numpy>=1.19.5, <=2.1.2 +pybind11>=2.9.0, <=2.13.6 PyYAML>=5.4.0, <=6.0.1 -ml_dtypes>=0.1.0, <=0.4.0 # provides several NumPy dtype extensions, including the bf16 +ml_dtypes>=0.1.0, <=0.5.0 # provides several NumPy dtype extensions, including the bf16 From 9db1b78e73bc00d55be30d3781ea210a2e350085 Mon Sep 17 00:00:00 2001 From: Kasper Nielsen Date: Sat, 2 Nov 2024 07:39:48 +0100 Subject: [PATCH 784/915] [MLIR,Python] Support converting boolean numpy arrays to and from mlir attributes (#113064) Currently it is unsupported to: 1. Convert a `MlirAttribute` with type `i1` to a numpy array 2. 
Convert a boolean numpy array to a `MlirAttribute` Currently the entire Python application violently crashes with a quite poor error message https://github.com/pybind/pybind11/issues/3336 The complication handling these conversions, is that `MlirAttribute` represent booleans as a bit-packed `i1` type, whereas numpy represents booleans as a byte array with 8 bit used per boolean. This PR proposes the following approach: 1. When converting a `i1` typed `MlirAttribute` to a numpy array, we can not directly use the underlying raw data backing the `MlirAttribute` as a buffer to Python, as done for other types. Instead, a copy of the data is generated using numpy's unpackbits function, and the result is send back to Python. 2. When constructing a `MlirAttribute` from a numpy array, first the python data is read as a `uint8_t` to get it converted to the endianess used internally in mlir. Then the booleans are bitpacked using numpy's bitpack function, and the bitpacked array is saved as the `MlirAttribute` representation. Please note that I am not sure if this approach is the desired solution. I'd appreciate any feedback. 
--- mlir/lib/Bindings/Python/IRAttributes.cpp | 278 ++++++++++++++-------- 1 file changed, 181 insertions(+), 97 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index ead81a76c..c8883c0d8 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -13,6 +13,7 @@ #include "IRModule.h" #include "PybindUtils.h" +#include #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/raw_ostream.h" @@ -757,103 +758,10 @@ class PyDenseElementsAttribute throw py::error_already_set(); } auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); - SmallVector shape; - if (explicitShape) { - shape.append(explicitShape->begin(), explicitShape->end()); - } else { - shape.append(view.shape, view.shape + view.ndim); - } - MlirAttribute encodingAttr = mlirAttributeGetNull(); MlirContext context = contextWrapper->get(); - - // Detect format codes that are suitable for bulk loading. This includes - // all byte aligned integer and floating point types up to 8 bytes. - // Notably, this excludes, bool (which needs to be bit-packed) and - // other exotics which do not have a direct representation in the buffer - // protocol (i.e. complex, etc). - std::optional bulkLoadElementType; - if (explicitType) { - bulkLoadElementType = *explicitType; - } else { - std::string_view format(view.format); - if (format == "f") { - // f32 - assert(view.itemsize == 4 && "mismatched array itemsize"); - bulkLoadElementType = mlirF32TypeGet(context); - } else if (format == "d") { - // f64 - assert(view.itemsize == 8 && "mismatched array itemsize"); - bulkLoadElementType = mlirF64TypeGet(context); - } else if (format == "e") { - // f16 - assert(view.itemsize == 2 && "mismatched array itemsize"); - bulkLoadElementType = mlirF16TypeGet(context); - } else if (isSignedIntegerFormat(format)) { - if (view.itemsize == 4) { - // i32 - bulkLoadElementType = signless - ? 
mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeSignedGet(context, 32); - } else if (view.itemsize == 8) { - // i64 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeSignedGet(context, 64); - } else if (view.itemsize == 1) { - // i8 - bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) - : mlirIntegerTypeSignedGet(context, 8); - } else if (view.itemsize == 2) { - // i16 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 16) - : mlirIntegerTypeSignedGet(context, 16); - } - } else if (isUnsignedIntegerFormat(format)) { - if (view.itemsize == 4) { - // unsigned i32 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeUnsignedGet(context, 32); - } else if (view.itemsize == 8) { - // unsigned i64 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeUnsignedGet(context, 64); - } else if (view.itemsize == 1) { - // i8 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 8) - : mlirIntegerTypeUnsignedGet(context, 8); - } else if (view.itemsize == 2) { - // i16 - bulkLoadElementType = signless - ? 
mlirIntegerTypeGet(context, 16) - : mlirIntegerTypeUnsignedGet(context, 16); - } - } - if (!bulkLoadElementType) { - throw std::invalid_argument( - std::string("unimplemented array format conversion from format: ") + - std::string(format)); - } - } - - MlirType shapedType; - if (mlirTypeIsAShaped(*bulkLoadElementType)) { - if (explicitShape) { - throw std::invalid_argument("Shape can only be specified explicitly " - "when the type is not a shaped type."); - } - shapedType = *bulkLoadElementType; - } else { - shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), - *bulkLoadElementType, encodingAttr); - } - size_t rawBufferSize = view.len; - MlirAttribute attr = - mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf); + MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType, + explicitShape, context); if (mlirAttributeIsNull(attr)) { throw std::invalid_argument( "DenseElementsAttr could not be constructed from the given buffer. " @@ -963,6 +871,13 @@ class PyDenseElementsAttribute // unsigned i16 return bufferInfo(shapedType); } + } else if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 1) { + // i1 / bool + // We can not send the buffer directly back to Python, because the i1 + // values are bitpacked within MLIR. We call numpy's unpackbits function + // to convert the bytes. + return getBooleanBufferFromBitpackedAttribute(); } // TODO: Currently crashes the program. 
@@ -1016,14 +931,183 @@ class PyDenseElementsAttribute code == 'q'; } + static MlirType + getShapedType(std::optional bulkLoadElementType, + std::optional> explicitShape, + Py_buffer &view) { + SmallVector shape; + if (explicitShape) { + shape.append(explicitShape->begin(), explicitShape->end()); + } else { + shape.append(view.shape, view.shape + view.ndim); + } + + if (mlirTypeIsAShaped(*bulkLoadElementType)) { + if (explicitShape) { + throw std::invalid_argument("Shape can only be specified explicitly " + "when the type is not a shaped type."); + } + return *bulkLoadElementType; + } else { + MlirAttribute encodingAttr = mlirAttributeGetNull(); + return mlirRankedTensorTypeGet(shape.size(), shape.data(), + *bulkLoadElementType, encodingAttr); + } + } + + static MlirAttribute getAttributeFromBuffer( + Py_buffer &view, bool signless, std::optional explicitType, + std::optional> explicitShape, MlirContext &context) { + // Detect format codes that are suitable for bulk loading. This includes + // all byte aligned integer and floating point types up to 8 bytes. + // Notably, this excludes exotics types which do not have a direct + // representation in the buffer protocol (i.e. complex, etc). 
+ std::optional bulkLoadElementType; + if (explicitType) { + bulkLoadElementType = *explicitType; + } else { + std::string_view format(view.format); + if (format == "f") { + // f32 + assert(view.itemsize == 4 && "mismatched array itemsize"); + bulkLoadElementType = mlirF32TypeGet(context); + } else if (format == "d") { + // f64 + assert(view.itemsize == 8 && "mismatched array itemsize"); + bulkLoadElementType = mlirF64TypeGet(context); + } else if (format == "e") { + // f16 + assert(view.itemsize == 2 && "mismatched array itemsize"); + bulkLoadElementType = mlirF16TypeGet(context); + } else if (format == "?") { + // i1 + // The i1 type needs to be bit-packed, so we will handle it seperately + return getBitpackedAttributeFromBooleanBuffer(view, explicitShape, + context); + } else if (isSignedIntegerFormat(format)) { + if (view.itemsize == 4) { + // i32 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeSignedGet(context, 32); + } else if (view.itemsize == 8) { + // i64 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeSignedGet(context, 64); + } else if (view.itemsize == 1) { + // i8 + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) + : mlirIntegerTypeSignedGet(context, 8); + } else if (view.itemsize == 2) { + // i16 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 16) + : mlirIntegerTypeSignedGet(context, 16); + } + } else if (isUnsignedIntegerFormat(format)) { + if (view.itemsize == 4) { + // unsigned i32 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeUnsignedGet(context, 32); + } else if (view.itemsize == 8) { + // unsigned i64 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeUnsignedGet(context, 64); + } else if (view.itemsize == 1) { + // i8 + bulkLoadElementType = signless + ? 
mlirIntegerTypeGet(context, 8) + : mlirIntegerTypeUnsignedGet(context, 8); + } else if (view.itemsize == 2) { + // i16 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 16) + : mlirIntegerTypeUnsignedGet(context, 16); + } + } + if (!bulkLoadElementType) { + throw std::invalid_argument( + std::string("unimplemented array format conversion from format: ") + + std::string(format)); + } + } + + MlirType type = getShapedType(bulkLoadElementType, explicitShape, view); + return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf); + } + + // There is a complication for boolean numpy arrays, as numpy represents them + // as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans + // per byte. + static MlirAttribute getBitpackedAttributeFromBooleanBuffer( + Py_buffer &view, std::optional> explicitShape, + MlirContext &context) { + if (llvm::endianness::native != llvm::endianness::little) { + // Given we have no good way of testing the behavior on big-endian systems + // we will throw + throw py::type_error("Constructing a bit-packed MLIR attribute is " + "unsupported on big-endian systems"); + } + + py::array_t unpackedArray(view.len, + static_cast(view.buf)); + + py::module numpy = py::module::import("numpy"); + py::object packbits_func = numpy.attr("packbits"); + py::object packed_booleans = + packbits_func(unpackedArray, "bitorder"_a = "little"); + py::buffer_info pythonBuffer = packed_booleans.cast().request(); + + MlirType bitpackedType = + getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view); + return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size, + pythonBuffer.ptr); + } + + // This does the opposite transformation of + // `getBitpackedAttributeFromBooleanBuffer` + py::buffer_info getBooleanBufferFromBitpackedAttribute() { + if (llvm::endianness::native != llvm::endianness::little) { + // Given we have no good way of testing the behavior on big-endian systems + // we will throw + throw 
py::type_error("Constructing a numpy array from a MLIR attribute " + "is unsupported on big-endian systems"); + } + + int64_t numBooleans = mlirElementsAttrGetNumElements(*this); + int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8); + uint8_t *bitpackedData = static_cast( + const_cast(mlirDenseElementsAttrGetRawData(*this))); + py::array_t packedArray(numBitpackedBytes, bitpackedData); + + py::module numpy = py::module::import("numpy"); + py::object unpackbits_func = numpy.attr("unpackbits"); + py::object unpacked_booleans = + unpackbits_func(packedArray, "bitorder"_a = "little"); + py::buffer_info pythonBuffer = + unpacked_booleans.cast().request(); + + MlirType shapedType = mlirAttributeGetType(*this); + return bufferInfo(shapedType, (bool *)pythonBuffer.ptr, "?"); + } + template py::buffer_info bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) { - intptr_t rank = mlirShapedTypeGetRank(shapedType); // Prepare the data for the buffer_info. - // Buffer is configured for read-only access below. + // Buffer is configured for read-only access inside the `bufferInfo` call. Type *data = static_cast( const_cast(mlirDenseElementsAttrGetRawData(*this))); + return bufferInfo(shapedType, data, explicitFormat); + } + + template + py::buffer_info bufferInfo(MlirType shapedType, Type *data, + const char *explicitFormat = nullptr) { + intptr_t rank = mlirShapedTypeGetRank(shapedType); // Prepare the shape for the buffer_info. SmallVector shape; for (intptr_t i = 0; i < rank; ++i) From 44db3dd2fcbc9c0aa43a9c9e3a6b94282a2c4d54 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 4 Nov 2024 20:19:18 -0500 Subject: [PATCH 785/915] [mlir:python] Change PyOperation::create to actually return a PyOperation. (#114542) In the tablegen-generated Python bindings, we typically see a pattern like: ``` class ConstantOp(_ods_ir.OpView): ... def __init__(self, value, *, loc=None, ip=None): ... 
super().__init__(self.build_generic(attributes=attributes, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)) ``` i.e., the generated code calls `OpView.__init__()` with the output of `build_generic`. The purpose of `OpView` is to wrap another operation object, and `OpView.__init__` can accept any `PyOperationBase` subclass, and presumably the intention is that `build_generic` returns a `PyOperation`, so the user ends up with a `PyOpView` wrapping a `PyOperation`. However, `PyOpView::buildGeneric` calls `PyOperation::create`, which does not just build a PyOperation, but it also calls `createOpView` to wrap that operation in a subclass of `PyOpView` and returns that view. But that's rather pointless: we called this code from the constructor of an `OpView` subclass, so we already have a view object ready to go; we don't need to build another one! If we change `PyOperation::create` to return the underlying `PyOperation`, rather than a view wrapper, we can save allocating a useless `PyOpView` object for each ODS-generated Python object. This saves approximately 1.5s of Python time in a JAX LLM benchmark that generates a mixture of upstream dialects and StableHLO. 
--- mlir/lib/Bindings/Python/IRCore.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index c12f75e7d..3562ff382 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1534,7 +1534,7 @@ py::object PyOperation::create(const std::string &name, PyOperation::createDetached(location->getContext(), operation); maybeInsertOperation(created, maybeIp); - return created->createOpView(); + return created.getObject(); } py::object PyOperation::clone(const py::object &maybeIp) { From dae317cfae17d974eca702c3ed8a378faa73d0a4 Mon Sep 17 00:00:00 2001 From: Dmitri Gribenko Date: Tue, 5 Nov 2024 15:48:13 +0100 Subject: [PATCH 786/915] Revert "[MLIR,Python] Support converting boolean numpy arrays to and from mlir attributes (#113064)" This reverts commit 9db1b78e73bc00d55be30d3781ea210a2e350085. There is an ASan issue here, see the discussion on https://github.com/llvm/llvm-project/pull/113064. 
--- mlir/lib/Bindings/Python/IRAttributes.cpp | 278 ++++++++-------------- 1 file changed, 97 insertions(+), 181 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index c8883c0d8..ead81a76c 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -13,7 +13,6 @@ #include "IRModule.h" #include "PybindUtils.h" -#include #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/raw_ostream.h" @@ -758,10 +757,103 @@ class PyDenseElementsAttribute throw py::error_already_set(); } auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); + SmallVector shape; + if (explicitShape) { + shape.append(explicitShape->begin(), explicitShape->end()); + } else { + shape.append(view.shape, view.shape + view.ndim); + } + MlirAttribute encodingAttr = mlirAttributeGetNull(); MlirContext context = contextWrapper->get(); - MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType, - explicitShape, context); + + // Detect format codes that are suitable for bulk loading. This includes + // all byte aligned integer and floating point types up to 8 bytes. + // Notably, this excludes, bool (which needs to be bit-packed) and + // other exotics which do not have a direct representation in the buffer + // protocol (i.e. complex, etc). 
+ std::optional bulkLoadElementType; + if (explicitType) { + bulkLoadElementType = *explicitType; + } else { + std::string_view format(view.format); + if (format == "f") { + // f32 + assert(view.itemsize == 4 && "mismatched array itemsize"); + bulkLoadElementType = mlirF32TypeGet(context); + } else if (format == "d") { + // f64 + assert(view.itemsize == 8 && "mismatched array itemsize"); + bulkLoadElementType = mlirF64TypeGet(context); + } else if (format == "e") { + // f16 + assert(view.itemsize == 2 && "mismatched array itemsize"); + bulkLoadElementType = mlirF16TypeGet(context); + } else if (isSignedIntegerFormat(format)) { + if (view.itemsize == 4) { + // i32 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeSignedGet(context, 32); + } else if (view.itemsize == 8) { + // i64 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeSignedGet(context, 64); + } else if (view.itemsize == 1) { + // i8 + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) + : mlirIntegerTypeSignedGet(context, 8); + } else if (view.itemsize == 2) { + // i16 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 16) + : mlirIntegerTypeSignedGet(context, 16); + } + } else if (isUnsignedIntegerFormat(format)) { + if (view.itemsize == 4) { + // unsigned i32 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeUnsignedGet(context, 32); + } else if (view.itemsize == 8) { + // unsigned i64 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeUnsignedGet(context, 64); + } else if (view.itemsize == 1) { + // i8 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 8) + : mlirIntegerTypeUnsignedGet(context, 8); + } else if (view.itemsize == 2) { + // i16 + bulkLoadElementType = signless + ? 
mlirIntegerTypeGet(context, 16) + : mlirIntegerTypeUnsignedGet(context, 16); + } + } + if (!bulkLoadElementType) { + throw std::invalid_argument( + std::string("unimplemented array format conversion from format: ") + + std::string(format)); + } + } + + MlirType shapedType; + if (mlirTypeIsAShaped(*bulkLoadElementType)) { + if (explicitShape) { + throw std::invalid_argument("Shape can only be specified explicitly " + "when the type is not a shaped type."); + } + shapedType = *bulkLoadElementType; + } else { + shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), + *bulkLoadElementType, encodingAttr); + } + size_t rawBufferSize = view.len; + MlirAttribute attr = + mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf); if (mlirAttributeIsNull(attr)) { throw std::invalid_argument( "DenseElementsAttr could not be constructed from the given buffer. " @@ -871,13 +963,6 @@ class PyDenseElementsAttribute // unsigned i16 return bufferInfo(shapedType); } - } else if (mlirTypeIsAInteger(elementType) && - mlirIntegerTypeGetWidth(elementType) == 1) { - // i1 / bool - // We can not send the buffer directly back to Python, because the i1 - // values are bitpacked within MLIR. We call numpy's unpackbits function - // to convert the bytes. - return getBooleanBufferFromBitpackedAttribute(); } // TODO: Currently crashes the program. 
@@ -931,183 +1016,14 @@ class PyDenseElementsAttribute code == 'q'; } - static MlirType - getShapedType(std::optional bulkLoadElementType, - std::optional> explicitShape, - Py_buffer &view) { - SmallVector shape; - if (explicitShape) { - shape.append(explicitShape->begin(), explicitShape->end()); - } else { - shape.append(view.shape, view.shape + view.ndim); - } - - if (mlirTypeIsAShaped(*bulkLoadElementType)) { - if (explicitShape) { - throw std::invalid_argument("Shape can only be specified explicitly " - "when the type is not a shaped type."); - } - return *bulkLoadElementType; - } else { - MlirAttribute encodingAttr = mlirAttributeGetNull(); - return mlirRankedTensorTypeGet(shape.size(), shape.data(), - *bulkLoadElementType, encodingAttr); - } - } - - static MlirAttribute getAttributeFromBuffer( - Py_buffer &view, bool signless, std::optional explicitType, - std::optional> explicitShape, MlirContext &context) { - // Detect format codes that are suitable for bulk loading. This includes - // all byte aligned integer and floating point types up to 8 bytes. - // Notably, this excludes exotics types which do not have a direct - // representation in the buffer protocol (i.e. complex, etc). 
- std::optional bulkLoadElementType; - if (explicitType) { - bulkLoadElementType = *explicitType; - } else { - std::string_view format(view.format); - if (format == "f") { - // f32 - assert(view.itemsize == 4 && "mismatched array itemsize"); - bulkLoadElementType = mlirF32TypeGet(context); - } else if (format == "d") { - // f64 - assert(view.itemsize == 8 && "mismatched array itemsize"); - bulkLoadElementType = mlirF64TypeGet(context); - } else if (format == "e") { - // f16 - assert(view.itemsize == 2 && "mismatched array itemsize"); - bulkLoadElementType = mlirF16TypeGet(context); - } else if (format == "?") { - // i1 - // The i1 type needs to be bit-packed, so we will handle it seperately - return getBitpackedAttributeFromBooleanBuffer(view, explicitShape, - context); - } else if (isSignedIntegerFormat(format)) { - if (view.itemsize == 4) { - // i32 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeSignedGet(context, 32); - } else if (view.itemsize == 8) { - // i64 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeSignedGet(context, 64); - } else if (view.itemsize == 1) { - // i8 - bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) - : mlirIntegerTypeSignedGet(context, 8); - } else if (view.itemsize == 2) { - // i16 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 16) - : mlirIntegerTypeSignedGet(context, 16); - } - } else if (isUnsignedIntegerFormat(format)) { - if (view.itemsize == 4) { - // unsigned i32 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeUnsignedGet(context, 32); - } else if (view.itemsize == 8) { - // unsigned i64 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeUnsignedGet(context, 64); - } else if (view.itemsize == 1) { - // i8 - bulkLoadElementType = signless - ? 
mlirIntegerTypeGet(context, 8) - : mlirIntegerTypeUnsignedGet(context, 8); - } else if (view.itemsize == 2) { - // i16 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 16) - : mlirIntegerTypeUnsignedGet(context, 16); - } - } - if (!bulkLoadElementType) { - throw std::invalid_argument( - std::string("unimplemented array format conversion from format: ") + - std::string(format)); - } - } - - MlirType type = getShapedType(bulkLoadElementType, explicitShape, view); - return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf); - } - - // There is a complication for boolean numpy arrays, as numpy represents them - // as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans - // per byte. - static MlirAttribute getBitpackedAttributeFromBooleanBuffer( - Py_buffer &view, std::optional> explicitShape, - MlirContext &context) { - if (llvm::endianness::native != llvm::endianness::little) { - // Given we have no good way of testing the behavior on big-endian systems - // we will throw - throw py::type_error("Constructing a bit-packed MLIR attribute is " - "unsupported on big-endian systems"); - } - - py::array_t unpackedArray(view.len, - static_cast(view.buf)); - - py::module numpy = py::module::import("numpy"); - py::object packbits_func = numpy.attr("packbits"); - py::object packed_booleans = - packbits_func(unpackedArray, "bitorder"_a = "little"); - py::buffer_info pythonBuffer = packed_booleans.cast().request(); - - MlirType bitpackedType = - getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view); - return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size, - pythonBuffer.ptr); - } - - // This does the opposite transformation of - // `getBitpackedAttributeFromBooleanBuffer` - py::buffer_info getBooleanBufferFromBitpackedAttribute() { - if (llvm::endianness::native != llvm::endianness::little) { - // Given we have no good way of testing the behavior on big-endian systems - // we will throw - throw 
py::type_error("Constructing a numpy array from a MLIR attribute " - "is unsupported on big-endian systems"); - } - - int64_t numBooleans = mlirElementsAttrGetNumElements(*this); - int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8); - uint8_t *bitpackedData = static_cast( - const_cast(mlirDenseElementsAttrGetRawData(*this))); - py::array_t packedArray(numBitpackedBytes, bitpackedData); - - py::module numpy = py::module::import("numpy"); - py::object unpackbits_func = numpy.attr("unpackbits"); - py::object unpacked_booleans = - unpackbits_func(packedArray, "bitorder"_a = "little"); - py::buffer_info pythonBuffer = - unpacked_booleans.cast().request(); - - MlirType shapedType = mlirAttributeGetType(*this); - return bufferInfo(shapedType, (bool *)pythonBuffer.ptr, "?"); - } - template py::buffer_info bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) { + intptr_t rank = mlirShapedTypeGetRank(shapedType); // Prepare the data for the buffer_info. - // Buffer is configured for read-only access inside the `bufferInfo` call. + // Buffer is configured for read-only access below. Type *data = static_cast( const_cast(mlirDenseElementsAttrGetRawData(*this))); - return bufferInfo(shapedType, data, explicitFormat); - } - - template - py::buffer_info bufferInfo(MlirType shapedType, Type *data, - const char *explicitFormat = nullptr) { - intptr_t rank = mlirShapedTypeGetRank(shapedType); // Prepare the shape for the buffer_info. SmallVector shape; for (intptr_t i = 0; i < rank; ++i) From d8f8d997de20842b86b441a8c3938dc765a5899a Mon Sep 17 00:00:00 2001 From: Md Asghar Ahmad Shahid Date: Thu, 7 Nov 2024 20:21:02 +0530 Subject: [PATCH 787/915] [MLIR][Linalg] Re-land linalg.matmul move to ODS. + Remove/update failing obsolete OpDSL tests. (#115319) The earlier PR(https://github.com/llvm/llvm-project/pull/104783) which introduces transpose and broadcast semantic to linalg.matmul was reverted due to two failing OpDSL test for linalg.matmul. 
Since linalg.matmul is now defined using TableGen ODS instead of Python-based OpDSL, these test started failing and needs to be removed/updated. This commit removes/updates the failing obsolete tests from below files. All other files were part of earlier PR and just cherry picked. "mlir/test/python/integration/dialects/linalg/opsrun.py" "mlir/test/python/integration/dialects/transform.py" --------- Co-authored-by: Renato Golin --- .../dialects/linalg/opdsl/ops/core_named_ops.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index b45fecd0e..5c1c984b1 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -383,23 +383,6 @@ def select( O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None]) -@linalg_structured_op -def matmul( - A=TensorDef(T1, S.M, S.K), - B=TensorDef(T2, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True), - cast=TypeFnAttrDef(default=TypeFn.cast_signed), -): - """Performs a matrix multiplication of two 2D inputs. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n]) - - @linalg_structured_op def quantized_matmul( A=TensorDef(T1, S.M, S.K), From 7dd461a04ecd81df65c6585e339a11603212c2a3 Mon Sep 17 00:00:00 2001 From: stefankoncarevic Date: Fri, 8 Nov 2024 18:23:17 +0100 Subject: [PATCH 788/915] [mlir][linalg] Add Grouped Convolution Ops: conv_2d_nhwgc_gfhwc and conv_2d_nhwgc_gfhwc_q (#108192) This patch adds two new ops: linalg::Conv2DNhwgcGfhwcOp and linalg::Conv2DNhwgcGfhwcQOp, and uses them to convert tosa group conv2d Ops. - Added linalg::Conv2DNhwgcGfhwcOp and linalg::Conv2DNhwgcGfhwcQOp. 
- Updated the conversion process to use these new ops for tosa group conv2d operations. --- .../linalg/opdsl/ops/core_named_ops.py | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 5c1c984b1..c95cd5eec 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -964,6 +964,67 @@ def conv_2d_ngchw_gfchw( ) * TypeFn.cast_signed(U, K[D.g, D.fg, D.c, D.kh, D.kw]) +@linalg_structured_op +def conv_2d_nhwgc_gfhwc( + I=TensorDef( + T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.G, S.C + ), + K=TensorDef(T2, S.G, S.FG, S.KH, S.KW, S.C), + O=TensorDef(U, S.N, S.OH, S.OW, S.G, S.FG, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D grouped convolution. + + Layout: + * Input: NHWGC. + * Kernel: GFHWC. + + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.g, D.fg, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.g, D.fg] += TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.g, D.c] + ) * TypeFn.cast_signed(U, K[D.g, D.fg, D.kh, D.kw, D.c]) + + +@linalg_structured_op +def conv_2d_nhwgc_gfhwc_q( + I=TensorDef( + T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.G, S.C + ), + K=TensorDef(T2, S.G, S.FG, S.KH, S.KW, S.C), + IZp=ScalarDef(I32), + KZp=ScalarDef(I32), + O=TensorDef(U, S.N, S.OH, S.OW, S.G, S.FG, output=True), + strides=IndexAttrDef(S.SH, S.SW, default=[1, 1]), + dilations=IndexAttrDef(S.DH, S.DW, default=[1, 1]), +): + """Performs 2-D grouped convolution with zero point offsets. + + Layout: + * Input: NHWGC. + * Kernel: GFHWC. 
+ + Numeric casting is performed on the operands to the inner multiply, promoting + them to the same data type as the accumulator/output. This includes the zero + point offsets common to quantized operations. + """ + implements(ConvolutionOpInterface) + domain(D.n, D.oh, D.ow, D.g, D.fg, D.kh, D.kw, D.c) + O[D.n, D.oh, D.ow, D.g, D.fg] += ( + TypeFn.cast_signed( + U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.g, D.c] + ) + - TypeFn.cast_signed(U, IZp) + ) * ( + TypeFn.cast_signed(U, K[D.g, D.fg, D.kh, D.kw, D.c]) + - TypeFn.cast_signed(U, KZp) + ) + + @linalg_structured_op def conv_2d_ngchw_gfchw_q( I=TensorDef( From 9f5cb2a30a577d162e12f407b435d5792fa7ae1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Mon, 11 Nov 2024 09:26:15 +0100 Subject: [PATCH 789/915] [mlir][python] Make types in register_(dialect|operation) more narrow. (#115307) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR makes the `pyClass`/`dialectClass` arguments of the pybind11 functions `register_dialect` and `register_operation` as well as their return types more narrow, concretely, a `py::type` instead of a `py::object`. As the name of the arguments indicate, they have to be called with a type instance (a "class"). The PR also updates the typing stubs of these functions (in the corresponding `.pyi` file), such that static type checkers are aware of the changed type. With the previous typing information, `pyright` raised errors on code generated by tablegen. 
Signed-off-by: Ingo Müller --- mlir/lib/Bindings/Python/MainModule.cpp | 6 +++--- mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 8da1ab16a..7c2702190 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -58,7 +58,7 @@ PYBIND11_MODULE(_mlir, m) { // Registration decorators. m.def( "register_dialect", - [](py::object pyClass) { + [](py::type pyClass) { std::string dialectNamespace = pyClass.attr("DIALECT_NAMESPACE").cast(); PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); @@ -68,9 +68,9 @@ PYBIND11_MODULE(_mlir, m) { "Class decorator for registering a custom Dialect wrapper"); m.def( "register_operation", - [](const py::object &dialectClass, bool replace) -> py::cpp_function { + [](const py::type &dialectClass, bool replace) -> py::cpp_function { return py::cpp_function( - [dialectClass, replace](py::object opClass) -> py::object { + [dialectClass, replace](py::type opClass) -> py::type { std::string operationName = opClass.attr("OPERATION_NAME").cast(); PyGlobals::get().registerOperationImpl(operationName, opClass, diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi index 42694747e..03449b70b 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi @@ -8,5 +8,5 @@ class _Globals: def append_dialect_search_prefix(self, module_name: str) -> None: ... def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ... -def register_dialect(dialect_class: type) -> object: ... -def register_operation(dialect_class: type, *, replace: bool = ...) -> object: ... +def register_dialect(dialect_class: type) -> type: ... +def register_operation(dialect_class: type, *, replace: bool = ...) -> type: ... 
From cf24b887c087c913eb4b5aea1165a7a52a2e53d9 Mon Sep 17 00:00:00 2001 From: Kasper Nielsen Date: Tue, 12 Nov 2024 22:23:10 -0800 Subject: [PATCH 790/915] [MLIR,Python] Support converting boolean numpy arrays to and from mlir attributes (unrevert) (#115481) This PR re-introduces the functionality of https://github.com/llvm/llvm-project/pull/113064, which was reverted in https://github.com/llvm/llvm-project/commit/dae317cfae17d974eca702c3ed8a378faa73d0a4 due to memory lifetime issues. Notice that I was not able to re-produce the ASan results myself, so I have not been able to verify that this PR really fixes the issue. --- Currently it is unsupported to: 1. Convert a MlirAttribute with type i1 to a numpy array 2. Convert a boolean numpy array to a MlirAttribute Currently the entire Python application violently crashes with a quite poor error message https://github.com/pybind/pybind11/issues/3336 The complication handling these conversions, is that MlirAttribute represent booleans as a bit-packed i1 type, whereas numpy represents booleans as a byte array with 8 bit used per boolean. This PR proposes the following approach: 1. When converting a i1 typed MlirAttribute to a numpy array, we can not directly use the underlying raw data backing the MlirAttribute as a buffer to Python, as done for other types. Instead, a copy of the data is generated using numpy's unpackbits function, and the result is send back to Python. 2. When constructing a MlirAttribute from a numpy array, first the python data is read as a uint8_t to get it converted to the endianess used internally in mlir. Then the booleans are bitpacked using numpy's bitpack function, and the bitpacked array is saved as the MlirAttribute representation. 
--- mlir/lib/Bindings/Python/IRAttributes.cpp | 290 +++++++++++++++------- 1 file changed, 195 insertions(+), 95 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index ead81a76c..417c66b91 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -13,6 +13,7 @@ #include "IRModule.h" #include "PybindUtils.h" +#include #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/raw_ostream.h" @@ -757,103 +758,10 @@ class PyDenseElementsAttribute throw py::error_already_set(); } auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); - SmallVector shape; - if (explicitShape) { - shape.append(explicitShape->begin(), explicitShape->end()); - } else { - shape.append(view.shape, view.shape + view.ndim); - } - MlirAttribute encodingAttr = mlirAttributeGetNull(); MlirContext context = contextWrapper->get(); - - // Detect format codes that are suitable for bulk loading. This includes - // all byte aligned integer and floating point types up to 8 bytes. - // Notably, this excludes, bool (which needs to be bit-packed) and - // other exotics which do not have a direct representation in the buffer - // protocol (i.e. complex, etc). - std::optional bulkLoadElementType; - if (explicitType) { - bulkLoadElementType = *explicitType; - } else { - std::string_view format(view.format); - if (format == "f") { - // f32 - assert(view.itemsize == 4 && "mismatched array itemsize"); - bulkLoadElementType = mlirF32TypeGet(context); - } else if (format == "d") { - // f64 - assert(view.itemsize == 8 && "mismatched array itemsize"); - bulkLoadElementType = mlirF64TypeGet(context); - } else if (format == "e") { - // f16 - assert(view.itemsize == 2 && "mismatched array itemsize"); - bulkLoadElementType = mlirF16TypeGet(context); - } else if (isSignedIntegerFormat(format)) { - if (view.itemsize == 4) { - // i32 - bulkLoadElementType = signless - ? 
mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeSignedGet(context, 32); - } else if (view.itemsize == 8) { - // i64 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeSignedGet(context, 64); - } else if (view.itemsize == 1) { - // i8 - bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) - : mlirIntegerTypeSignedGet(context, 8); - } else if (view.itemsize == 2) { - // i16 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 16) - : mlirIntegerTypeSignedGet(context, 16); - } - } else if (isUnsignedIntegerFormat(format)) { - if (view.itemsize == 4) { - // unsigned i32 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 32) - : mlirIntegerTypeUnsignedGet(context, 32); - } else if (view.itemsize == 8) { - // unsigned i64 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 64) - : mlirIntegerTypeUnsignedGet(context, 64); - } else if (view.itemsize == 1) { - // i8 - bulkLoadElementType = signless - ? mlirIntegerTypeGet(context, 8) - : mlirIntegerTypeUnsignedGet(context, 8); - } else if (view.itemsize == 2) { - // i16 - bulkLoadElementType = signless - ? 
mlirIntegerTypeGet(context, 16) - : mlirIntegerTypeUnsignedGet(context, 16); - } - } - if (!bulkLoadElementType) { - throw std::invalid_argument( - std::string("unimplemented array format conversion from format: ") + - std::string(format)); - } - } - - MlirType shapedType; - if (mlirTypeIsAShaped(*bulkLoadElementType)) { - if (explicitShape) { - throw std::invalid_argument("Shape can only be specified explicitly " - "when the type is not a shaped type."); - } - shapedType = *bulkLoadElementType; - } else { - shapedType = mlirRankedTensorTypeGet(shape.size(), shape.data(), - *bulkLoadElementType, encodingAttr); - } - size_t rawBufferSize = view.len; - MlirAttribute attr = - mlirDenseElementsAttrRawBufferGet(shapedType, rawBufferSize, view.buf); + MlirAttribute attr = getAttributeFromBuffer(view, signless, explicitType, + explicitShape, context); if (mlirAttributeIsNull(attr)) { throw std::invalid_argument( "DenseElementsAttr could not be constructed from the given buffer. " @@ -963,6 +871,13 @@ class PyDenseElementsAttribute // unsigned i16 return bufferInfo(shapedType); } + } else if (mlirTypeIsAInteger(elementType) && + mlirIntegerTypeGetWidth(elementType) == 1) { + // i1 / bool + // We can not send the buffer directly back to Python, because the i1 + // values are bitpacked within MLIR. We call numpy's unpackbits function + // to convert the bytes. + return getBooleanBufferFromBitpackedAttribute(); } // TODO: Currently crashes the program. 
@@ -1016,6 +931,191 @@ class PyDenseElementsAttribute code == 'q'; } + static MlirType + getShapedType(std::optional bulkLoadElementType, + std::optional> explicitShape, + Py_buffer &view) { + SmallVector shape; + if (explicitShape) { + shape.append(explicitShape->begin(), explicitShape->end()); + } else { + shape.append(view.shape, view.shape + view.ndim); + } + + if (mlirTypeIsAShaped(*bulkLoadElementType)) { + if (explicitShape) { + throw std::invalid_argument("Shape can only be specified explicitly " + "when the type is not a shaped type."); + } + return *bulkLoadElementType; + } else { + MlirAttribute encodingAttr = mlirAttributeGetNull(); + return mlirRankedTensorTypeGet(shape.size(), shape.data(), + *bulkLoadElementType, encodingAttr); + } + } + + static MlirAttribute getAttributeFromBuffer( + Py_buffer &view, bool signless, std::optional explicitType, + std::optional> explicitShape, MlirContext &context) { + // Detect format codes that are suitable for bulk loading. This includes + // all byte aligned integer and floating point types up to 8 bytes. + // Notably, this excludes exotics types which do not have a direct + // representation in the buffer protocol (i.e. complex, etc). 
+ std::optional bulkLoadElementType; + if (explicitType) { + bulkLoadElementType = *explicitType; + } else { + std::string_view format(view.format); + if (format == "f") { + // f32 + assert(view.itemsize == 4 && "mismatched array itemsize"); + bulkLoadElementType = mlirF32TypeGet(context); + } else if (format == "d") { + // f64 + assert(view.itemsize == 8 && "mismatched array itemsize"); + bulkLoadElementType = mlirF64TypeGet(context); + } else if (format == "e") { + // f16 + assert(view.itemsize == 2 && "mismatched array itemsize"); + bulkLoadElementType = mlirF16TypeGet(context); + } else if (format == "?") { + // i1 + // The i1 type needs to be bit-packed, so we will handle it seperately + return getBitpackedAttributeFromBooleanBuffer(view, explicitShape, + context); + } else if (isSignedIntegerFormat(format)) { + if (view.itemsize == 4) { + // i32 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeSignedGet(context, 32); + } else if (view.itemsize == 8) { + // i64 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeSignedGet(context, 64); + } else if (view.itemsize == 1) { + // i8 + bulkLoadElementType = signless ? mlirIntegerTypeGet(context, 8) + : mlirIntegerTypeSignedGet(context, 8); + } else if (view.itemsize == 2) { + // i16 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 16) + : mlirIntegerTypeSignedGet(context, 16); + } + } else if (isUnsignedIntegerFormat(format)) { + if (view.itemsize == 4) { + // unsigned i32 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 32) + : mlirIntegerTypeUnsignedGet(context, 32); + } else if (view.itemsize == 8) { + // unsigned i64 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 64) + : mlirIntegerTypeUnsignedGet(context, 64); + } else if (view.itemsize == 1) { + // i8 + bulkLoadElementType = signless + ? 
mlirIntegerTypeGet(context, 8) + : mlirIntegerTypeUnsignedGet(context, 8); + } else if (view.itemsize == 2) { + // i16 + bulkLoadElementType = signless + ? mlirIntegerTypeGet(context, 16) + : mlirIntegerTypeUnsignedGet(context, 16); + } + } + if (!bulkLoadElementType) { + throw std::invalid_argument( + std::string("unimplemented array format conversion from format: ") + + std::string(format)); + } + } + + MlirType type = getShapedType(bulkLoadElementType, explicitShape, view); + return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf); + } + + // There is a complication for boolean numpy arrays, as numpy represents them + // as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans + // per byte. + static MlirAttribute getBitpackedAttributeFromBooleanBuffer( + Py_buffer &view, std::optional> explicitShape, + MlirContext &context) { + if (llvm::endianness::native != llvm::endianness::little) { + // Given we have no good way of testing the behavior on big-endian systems + // we will throw + throw py::type_error("Constructing a bit-packed MLIR attribute is " + "unsupported on big-endian systems"); + } + + py::array_t unpackedArray(view.len, + static_cast(view.buf)); + + py::module numpy = py::module::import("numpy"); + py::object packbitsFunc = numpy.attr("packbits"); + py::object packedBooleans = + packbitsFunc(unpackedArray, "bitorder"_a = "little"); + py::buffer_info pythonBuffer = packedBooleans.cast().request(); + + MlirType bitpackedType = + getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view); + assert(pythonBuffer.itemsize == 1 && "Packbits must return uint8"); + // Notice that `mlirDenseElementsAttrRawBufferGet` copies the memory of + // packedBooleans, hence the MlirAttribute will remain valid even when + // packedBooleans get reclaimed by the end of the function. 
+ return mlirDenseElementsAttrRawBufferGet(bitpackedType, pythonBuffer.size, + pythonBuffer.ptr); + } + + // This does the opposite transformation of + // `getBitpackedAttributeFromBooleanBuffer` + py::buffer_info getBooleanBufferFromBitpackedAttribute() { + if (llvm::endianness::native != llvm::endianness::little) { + // Given we have no good way of testing the behavior on big-endian systems + // we will throw + throw py::type_error("Constructing a numpy array from a MLIR attribute " + "is unsupported on big-endian systems"); + } + + int64_t numBooleans = mlirElementsAttrGetNumElements(*this); + int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8); + uint8_t *bitpackedData = static_cast( + const_cast(mlirDenseElementsAttrGetRawData(*this))); + py::array_t packedArray(numBitpackedBytes, bitpackedData); + + py::module numpy = py::module::import("numpy"); + py::object unpackbitsFunc = numpy.attr("unpackbits"); + py::object equalFunc = numpy.attr("equal"); + py::object reshapeFunc = numpy.attr("reshape"); + py::array unpackedBooleans = + unpackbitsFunc(packedArray, "bitorder"_a = "little"); + + // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array. + // We need to: + // 1. Slice away the padded bits + // 2. Make the boolean array have the correct shape + // 3. 
Convert the array to a boolean array + unpackedBooleans = unpackedBooleans[py::slice(0, numBooleans, 1)]; + unpackedBooleans = equalFunc(unpackedBooleans, 1); + + std::vector shape; + MlirType shapedType = mlirAttributeGetType(*this); + intptr_t rank = mlirShapedTypeGetRank(shapedType); + for (intptr_t i = 0; i < rank; ++i) { + shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); + } + unpackedBooleans = reshapeFunc(unpackedBooleans, shape); + + // Make sure the returned py::buffer_view claims ownership of the data in + // `pythonBuffer` so it remains valid when Python reads it + py::buffer pythonBuffer = unpackedBooleans.cast(); + return pythonBuffer.request(); + } + template py::buffer_info bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) { From bf8dadcfd15250ab89a949e385ac72dcf2fcc5d4 Mon Sep 17 00:00:00 2001 From: Amy Wang Date: Wed, 13 Nov 2024 16:27:46 -0500 Subject: [PATCH 791/915] [MLIR][Python] Python binding support for AffineIfOp (#108323) Fix the AffineIfOp's default builder such that it takes in an IntegerSetAttr. AffineIfOp has skipDefaultBuilders=1 which effectively skips the creation of the default AffineIfOp::builder on the C++ side. (AffineIfOp has two custom OpBuilder defined in the extraClassDeclaration.) However, on the python side, _affine_ops_gen.py shows that the default builder is being created, but it does not accept IntegerSet and thus is useless. This fix at line 411 makes the default python AffineIfOp builder take in an IntegerSet input and does not impact the C++ side of things. 
--- mlir/python/mlir/dialects/affine.py | 58 +++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/mlir/python/mlir/dialects/affine.py b/mlir/python/mlir/dialects/affine.py index 913cea611..7641d36e3 100644 --- a/mlir/python/mlir/dialects/affine.py +++ b/mlir/python/mlir/dialects/affine.py @@ -156,3 +156,61 @@ def for_( yield iv, iter_args[0] else: yield iv + + +@_ods_cext.register_operation(_Dialect, replace=True) +class AffineIfOp(AffineIfOp): + """Specialization for the Affine if op class.""" + + def __init__( + self, + cond: IntegerSet, + results_: Optional[Type] = None, + *, + cond_operands: Optional[_VariadicResultValueT] = None, + has_else: bool = False, + loc=None, + ip=None, + ): + """Creates an Affine `if` operation. + + - `cond` is the integer set used to determine which regions of code + will be executed. + - `results` are the list of types to be yielded by the operand. + - `cond_operands` is the list of arguments to substitute the + dimensions, then symbols in the `cond` integer set expression to + determine whether they are in the set. + - `has_else` determines whether the affine if operation has the else + branch. 
+ """ + if results_ is None: + results_ = [] + if cond_operands is None: + cond_operands = [] + + if cond.n_inputs != len(cond_operands): + raise ValueError( + f"expected {cond.n_inputs} condition operands, got {len(cond_operands)}" + ) + + operands = [] + operands.extend(cond_operands) + results = [] + results.extend(results_) + + super().__init__(results, cond_operands, cond) + self.regions[0].blocks.append(*[]) + if has_else: + self.regions[1].blocks.append(*[]) + + @property + def then_block(self) -> Block: + """Returns the then block of the if operation.""" + return self.regions[0].blocks[0] + + @property + def else_block(self) -> Optional[Block]: + """Returns the else block of the if operation.""" + if len(self.regions[1].blocks) == 0: + return None + return self.regions[1].blocks[0] From 456c26f7773da893c78dc06730b19bdf9586beb2 Mon Sep 17 00:00:00 2001 From: "Jinyun (Joey) Ye" Date: Fri, 15 Nov 2024 02:53:34 +0000 Subject: [PATCH 792/915] [MLIR][Transform] Consolidate result of structured.split into one list (#111171) Follow-up a review comment from https://github.com/llvm/llvm-project/pull/82792#discussion_r1604925239 as a separate PR: E.g.: ``` %0:2 = transform.structured.split ``` is changed to ``` %t = transform.structured.split %0:2 = transform.split_handle %t ``` --- mlir/python/mlir/dialects/transform/structured.py | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py index f6111f516..9121aa8e4 100644 --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -445,7 +445,6 @@ def __init__( dynamic_chunk_sizes = chunk_sizes super().__init__( - target.type, target.type, target, dimension=dimension, From 4fd069cb63ea78005633c360c1ddddf21b65c8b5 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 19 Nov 2024 11:00:35 +0900 Subject: [PATCH 793/915] [mlir][python] Add `T.tf32` and missing tests for 
`tf32` (#116725) --- mlir/python/mlir/extras/types.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/python/mlir/extras/types.py b/mlir/python/mlir/extras/types.py index 34eee1edb..b875d639e 100644 --- a/mlir/python/mlir/extras/types.py +++ b/mlir/python/mlir/extras/types.py @@ -21,6 +21,7 @@ Float8E4M3Type, Float8E5M2Type, Float8E8M0FNUType, + FloatTF32Type, FunctionType, IndexType, IntegerType, @@ -70,6 +71,7 @@ def ui(width): f16 = lambda: F16Type.get() f32 = lambda: F32Type.get() +tf32 = lambda: FloatTF32Type.get() f64 = lambda: F64Type.get() bf16 = lambda: BF16Type.get() From 58c4284486aad132b5f0853b5d0a527a1f942464 Mon Sep 17 00:00:00 2001 From: annuasd <97934297+annuasd@users.noreply.github.com> Date: Wed, 20 Nov 2024 05:24:39 +0800 Subject: [PATCH 794/915] [mlir][Bindings] Fix missing return value of functions and incorrect type hint in pyi. (#116731) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The zero points of UniformQuantizedPerAxisType should be List[int]. And there are two methods missing return value. Co-authored-by: 牛奕博 --- mlir/lib/Bindings/Python/DialectQuant.cpp | 2 ++ mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp index af9cdc7bd..9a871f2c1 100644 --- a/mlir/lib/Bindings/Python/DialectQuant.cpp +++ b/mlir/lib/Bindings/Python/DialectQuant.cpp @@ -250,6 +250,7 @@ static void populateDialectQuantSubmodule(const py::module &m) { double scale = mlirUniformQuantizedPerAxisTypeGetScale(type, i); scales.push_back(scale); } + return scales; }, "The scales designate the difference between the real values " "corresponding to consecutive quantized values differing by 1. 
The ith " @@ -265,6 +266,7 @@ static void populateDialectQuantSubmodule(const py::module &m) { mlirUniformQuantizedPerAxisTypeGetZeroPoint(type, i); zeroPoints.push_back(zeroPoint); } + return zeroPoints; }, "the storage values corresponding to the real value 0 in the affine " "equation. The ith zero point corresponds to the ith slice in the " diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi index a10bc693b..47168d49c 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi @@ -101,7 +101,7 @@ class UniformQuantizedPerAxisType(QuantizedType): def scales(self) -> list[float]: ... @property - def zero_points(self) -> list[float]: ... + def zero_points(self) -> list[int]: ... @property def quantized_dimension(self) -> int: ... From 3c5f4caaca83aa581b5f9c956a10f9c61fd4f17e Mon Sep 17 00:00:00 2001 From: Perry Gibson Date: Wed, 20 Nov 2024 00:00:57 +0000 Subject: [PATCH 795/915] [mlir,python] Expose replaceAllUsesExcept to Python bindings (#115850) Problem originally described in [the forums here](https://discourse.llvm.org/t/mlir-python-expose-replaceallusesexcept/83068/1). Using the MLIR Python bindings, the method [`replaceAllUsesWith`](https://mlir.llvm.org/doxygen/classmlir_1_1Value.html#ac56b0fdb6246bcf7fa1805ba0eb71aa2) for `Value` is exposed, e.g., ```python orig_value.replace_all_uses_with( new_value ) ``` However, in my use-case I am separating a block into multiple blocks, so thus want to exclude certain Operations from having their Values replaced (since I want them to diverge). Within Value, we have [`replaceAllUsesExcept`](https://mlir.llvm.org/doxygen/classmlir_1_1Value.html#a9ec8d5c61f8a6aada4062f609372cce4), where we can pass the Operations which should be skipped. This is not currently exposed in the Python bindings: this PR fixes this. 
Adds `replace_all_uses_except`, which works with individual Operations, and lists of Operations. --- mlir/include/mlir-c/IR.h | 9 +++++++++ mlir/lib/Bindings/Python/IRCore.cpp | 29 +++++++++++++++++++++++++++++ mlir/lib/CAPI/IR/IR.cpp | 15 +++++++++++++++ 3 files changed, 53 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index b8a6f08b1..0a515bbea 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -956,6 +956,15 @@ MLIR_CAPI_EXPORTED MlirOpOperand mlirValueGetFirstUse(MlirValue value); MLIR_CAPI_EXPORTED void mlirValueReplaceAllUsesOfWith(MlirValue of, MlirValue with); +/// Replace all uses of 'of' value with 'with' value, updating anything in the +/// IR that uses 'of' to use 'with' instead, except if the user is listed in +/// 'exceptions'. The 'exceptions' parameter is an array of MlirOperation +/// pointers with a length of 'numExceptions'. +MLIR_CAPI_EXPORTED void +mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with, + intptr_t numExceptions, + MlirOperation *exceptions); + //===----------------------------------------------------------------------===// // OpOperand API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 3562ff382..3e96f8c60 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -178,6 +178,12 @@ static const char kValueReplaceAllUsesWithDocstring[] = the IR that uses 'self' to use the other value instead. )"; +static const char kValueReplaceAllUsesExceptDocstring[] = + R"("Replace all uses of this value with the 'with' value, except for those +in 'exceptions'. 'exceptions' can be either a single operation or a list of +operations. +)"; + //------------------------------------------------------------------------------ // Utilities. 
//------------------------------------------------------------------------------ @@ -3718,6 +3724,29 @@ void mlir::python::populateIRCore(py::module &m) { mlirValueReplaceAllUsesOfWith(self.get(), with.get()); }, kValueReplaceAllUsesWithDocstring) + .def( + "replace_all_uses_except", + [](MlirValue self, MlirValue with, PyOperation &exception) { + MlirOperation exceptedUser = exception.get(); + mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser); + }, + py::arg("with"), py::arg("exceptions"), + kValueReplaceAllUsesExceptDocstring) + .def( + "replace_all_uses_except", + [](MlirValue self, MlirValue with, py::list exceptions) { + // Convert Python list to a SmallVector of MlirOperations + llvm::SmallVector exceptionOps; + for (py::handle exception : exceptions) { + exceptionOps.push_back(exception.cast().get()); + } + + mlirValueReplaceAllUsesExcept( + self, with, static_cast(exceptionOps.size()), + exceptionOps.data()); + }, + py::arg("with"), py::arg("exceptions"), + kValueReplaceAllUsesExceptDocstring) .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyValue &self) { return self.maybeDownCast(); }); PyBlockArgument::bind(m); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index e7e6b11c8..24dc88540 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -28,6 +28,7 @@ #include "mlir/IR/Visitors.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Parser/Parser.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/ThreadPool.h" #include @@ -1009,6 +1010,20 @@ void mlirValueReplaceAllUsesOfWith(MlirValue oldValue, MlirValue newValue) { unwrap(oldValue).replaceAllUsesWith(unwrap(newValue)); } +void mlirValueReplaceAllUsesExcept(MlirValue oldValue, MlirValue newValue, + intptr_t numExceptions, + MlirOperation *exceptions) { + Value oldValueCpp = unwrap(oldValue); + Value newValueCpp = unwrap(newValue); + + llvm::SmallPtrSet exceptionSet; + for (intptr_t i = 0; i < numExceptions; ++i) { + 
exceptionSet.insert(unwrap(exceptions[i])); + } + + oldValueCpp.replaceAllUsesExcept(newValueCpp, exceptionSet); +} + //===----------------------------------------------------------------------===// // OpOperand API. //===----------------------------------------------------------------------===// From 1df490eea1624c097950fcbe837190d267feed82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Sat, 23 Nov 2024 12:48:59 +0100 Subject: [PATCH 796/915] [mlir][python] Update minimal version of pybind11 to 2.10. (#117314) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR updates the minimal required version of pybind11 from 2.9.0 to 2.10.0. New new version is almost 2.5 years old, which is half a year less than the previous version. This change is necessary to support the changes introduced in #115307, which does not compile with pybind11 v.2.9. Signed-off-by: Ingo Müller --- mlir/python/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index eeaac2746..272d06683 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,4 +1,4 @@ numpy>=1.19.5, <=2.1.2 -pybind11>=2.9.0, <=2.13.6 +pybind11>=2.10.0, <=2.13.6 PyYAML>=5.4.0, <=6.0.1 ml_dtypes>=0.1.0, <=0.5.0 # provides several NumPy dtype extensions, including the bf16 From b0abb0568bb446b678cd4ec77873ebca0b009bee Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Sat, 23 Nov 2024 20:17:25 +0100 Subject: [PATCH 797/915] [MLIR][Python] Add the `--mlir-print-ir-tree-dir` to the C and Python API (#117339) --- mlir/include/mlir-c/Pass.h | 5 ++++- mlir/lib/Bindings/Python/Pass.cpp | 11 +++++++++-- mlir/lib/CAPI/IR/Pass.cpp | 18 +++++++++++++----- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index 2218ec0f4..6019071cf 100644 --- a/mlir/include/mlir-c/Pass.h +++ 
b/mlir/include/mlir-c/Pass.h @@ -75,10 +75,13 @@ MLIR_CAPI_EXPORTED MlirLogicalResult mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op); /// Enable IR printing. +/// The treePrintingPath argument is an optional path to a directory +/// where the dumps will be produced. If it isn't provided then dumps +/// are produced to stderr. MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting( MlirPassManager passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterOnlyOnChange, - bool printAfterOnlyOnFailure); + bool printAfterOnlyOnFailure, MlirStringRef treePrintingPath); /// Enable / disable verify-each. MLIR_CAPI_EXPORTED void diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 1d0e5ce21..e8d28abe6 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -76,14 +76,21 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { "enable_ir_printing", [](PyPassManager &passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterChange, - bool printAfterFailure) { + bool printAfterFailure, + std::optional optionalTreePrintingPath) { + std::string treePrintingPath = ""; + if (optionalTreePrintingPath.has_value()) + treePrintingPath = optionalTreePrintingPath.value(); mlirPassManagerEnableIRPrinting( passManager.get(), printBeforeAll, printAfterAll, - printModuleScope, printAfterChange, printAfterFailure); + printModuleScope, printAfterChange, printAfterFailure, + mlirStringRefCreate(treePrintingPath.data(), + treePrintingPath.size())); }, "print_before_all"_a = false, "print_after_all"_a = true, "print_module_scope"_a = false, "print_after_change"_a = false, "print_after_failure"_a = false, + "tree_printing_dir_path"_a = py::none(), "Enable IR printing, default as mlir-print-ir-after-all.") .def( "enable_verifier", diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index a6c9fbd08..01151eafe 
100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -48,17 +48,25 @@ void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterOnlyOnChange, - bool printAfterOnlyOnFailure) { + bool printAfterOnlyOnFailure, + MlirStringRef treePrintingPath) { auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) { return printBeforeAll; }; auto shouldPrintAfterPass = [printAfterAll](Pass *, Operation *) { return printAfterAll; }; - return unwrap(passManager) - ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass, - printModuleScope, printAfterOnlyOnChange, - printAfterOnlyOnFailure); + if (unwrap(treePrintingPath).empty()) + return unwrap(passManager) + ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass, + printModuleScope, printAfterOnlyOnChange, + printAfterOnlyOnFailure); + + unwrap(passManager) + ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass, + printModuleScope, printAfterOnlyOnChange, + printAfterOnlyOnFailure, + unwrap(treePrintingPath)); } void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) { From d18af4f24022f7bcc4144cbd09b7e29fec2a0cca Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Mon, 25 Nov 2024 08:16:00 +0000 Subject: [PATCH 798/915] [mlir] Adjust code flagged by ClangTidyPerformance (NFC). We can allocate the size of the vector in advance. 
--- mlir/lib/Bindings/Python/IRAttributes.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 417c66b91..cc9532f4e 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -1102,11 +1102,11 @@ class PyDenseElementsAttribute unpackedBooleans = unpackedBooleans[py::slice(0, numBooleans, 1)]; unpackedBooleans = equalFunc(unpackedBooleans, 1); - std::vector shape; MlirType shapedType = mlirAttributeGetType(*this); intptr_t rank = mlirShapedTypeGetRank(shapedType); + std::vector shape(rank); for (intptr_t i = 0; i < rank; ++i) { - shape.push_back(mlirShapedTypeGetDimSize(shapedType, i)); + shape[i] = mlirShapedTypeGetDimSize(shapedType, i); } unpackedBooleans = reshapeFunc(unpackedBooleans, shape); From 9591f12738b465de011834d934e9d525a8b11fc4 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Mon, 25 Nov 2024 15:39:55 -0800 Subject: [PATCH 799/915] [mlir][py] Enable disabling loading all registered (#117643) There is a pending todo about always eagerly loading or not. Make this behavior optional and give the control to the user in a backwards compatible manner. This is made optional as there were arguments for both forms, kept it in form that is backwards compatible. --- mlir/python/mlir/_mlir_libs/__init__.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index 98dbbc6ad..c5cb22c6d 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -80,9 +80,16 @@ def _site_initialize(): logger = logging.getLogger(__name__) post_init_hooks = [] disable_multithreading = False + # This flag disables eagerly loading all dialects. 
Eagerly loading is often + # not the desired behavior (see + # https://github.com/llvm/llvm-project/issues/56037), and the logic is that + # if any module has this attribute set, then we don't load all (e.g., it's + # being used in a solution where the loading is controlled). + disable_load_all_available_dialects = False def process_initializer_module(module_name): nonlocal disable_multithreading + nonlocal disable_load_all_available_dialects try: m = importlib.import_module(f".{module_name}", __name__) except ModuleNotFoundError: @@ -107,6 +114,8 @@ def process_initializer_module(module_name): if bool(m.disable_multithreading): logger.debug("Disabling multi-threading for context") disable_multithreading = True + if hasattr(m, "disable_load_all_available_dialects"): + disable_load_all_available_dialects = True return True # If _mlirRegisterEverything is built, then include it as an initializer @@ -130,10 +139,8 @@ def __init__(self, *args, **kwargs): hook(self) if not disable_multithreading: self.enable_multithreading(True) - # TODO: There is some debate about whether we should eagerly load - # all dialects. It is being done here in order to preserve existing - # behavior. See: https://github.com/llvm/llvm-project/issues/56037 - self.load_all_available_dialects() + if not disable_load_all_available_dialects: + self.load_all_available_dialects() if init_module: logger.debug( "Registering translations from initializer %r", init_module From 49e93cb1fd3c80759925213cbfae3cce8d76b6d5 Mon Sep 17 00:00:00 2001 From: Perry Gibson Date: Mon, 2 Dec 2024 16:55:51 +0000 Subject: [PATCH 800/915] [mlir,python] Fix case when `FuncOp.arg_attrs` is not set (#117188) FuncOps can have `arg_attrs`, an array of dictionary attributes associated with their arguments. E.g., ```mlir func.func @main(%arg0: tensor<8xf32> {test.attr_name = "value"}, %arg1: tensor<8x16xf32>) ``` These are exposed via the MLIR Python bindings with `my_funcop.arg_attrs`. 
In this case, it would return `[{test.attr_name = "value"}, {}]`, i.e., `%arg1` has an empty `DictAttr`. However, if I try and access this property from a FuncOp with an empty `arg_attrs`, e.g., ```mlir func.func @main(%arg0: tensor<8xf32>, %arg1: tensor<8x16xf32>) ``` This raises the error: ```python return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^ KeyError: 'attempt to access a non-existent attribute' ``` This PR fixes this by returning the expected `[{}, {}]`. --- mlir/python/mlir/dialects/func.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/python/mlir/dialects/func.py b/mlir/python/mlir/dialects/func.py index 24fdcbcd8..1898fc156 100644 --- a/mlir/python/mlir/dialects/func.py +++ b/mlir/python/mlir/dialects/func.py @@ -105,6 +105,8 @@ def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None): @property def arg_attrs(self): + if ARGUMENT_ATTRIBUTE_NAME not in self.attributes: + return ArrayAttr.get([DictAttr.get({}) for _ in self.type.inputs]) return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME]) @arg_attrs.setter From 22b5045af2998cdd20545954fee34739e92740c5 Mon Sep 17 00:00:00 2001 From: Henrich Lauko Date: Tue, 3 Dec 2024 16:16:16 +0100 Subject: [PATCH 801/915] [mlir][llvm] Align linkage enum order with LLVM (NFC) (#118484) This change doesn't introduce any functional differences but aligns the implementation more closely with LLVM's representation. Previously, the code generated a lookup table to map MLIR enums to LLVM enums due to the lack of one-to-one correspondence. With this refactoring, the generated code now casts directly from one enum to another. 
--- mlir/include/mlir-c/Dialect/LLVM.h | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index 0e6434073..ed9b23c34 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -175,17 +175,17 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirLLVMComdatAttrGet(MlirContext ctx, MlirLLVMComdat comdat); enum MlirLLVMLinkage { - MlirLLVMLinkagePrivate = 0, - MlirLLVMLinkageInternal = 1, - MlirLLVMLinkageAvailableExternally = 2, - MlirLLVMLinkageLinkonce = 3, + MlirLLVMLinkageExternal = 0, + MlirLLVMLinkageAvailableExternally = 1, + MlirLLVMLinkageLinkonce = 2, + MlirLLVMLinkageLinkonceODR = 3, MlirLLVMLinkageWeak = 4, - MlirLLVMLinkageCommon = 5, + MlirLLVMLinkageWeakODR = 5, MlirLLVMLinkageAppending = 6, - MlirLLVMLinkageExternWeak = 7, - MlirLLVMLinkageLinkonceODR = 8, - MlirLLVMLinkageWeakODR = 9, - MlirLLVMLinkageExternal = 10, + MlirLLVMLinkageInternal = 7, + MlirLLVMLinkagePrivate = 8, + MlirLLVMLinkageExternWeak = 9, + MlirLLVMLinkageCommon = 10, }; typedef enum MlirLLVMLinkage MlirLLVMLinkage; From fdf9d22b50347370ad54eb8f9a800a51bd7b91a8 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 3 Dec 2024 12:13:34 -0500 Subject: [PATCH 802/915] [mlir python] Add nanobind support for standalone dialects. (#117922) This PR allows out-of-tree dialects to write Python dialect modules using nanobind instead of pybind11. It may make sense to migrate in-tree dialects and some of the ODS Python infrastructure to nanobind, but that is a topic for a future change. This PR makes the following changes: * adds nanobind to the CMake and Bazel build systems. We also add robin_map to the Bazel build, which is a dependency of nanobind. * adds a PYTHON_BINDING_LIBRARY option to various CMake functions, such as declare_mlir_python_extension, allowing users to select a Python binding library. 
* creates a fork of mlir/include/mlir/Bindings/Python/PybindAdaptors.h named NanobindAdaptors.h. This plays the same role, using nanobind instead of pybind11. * splits CollectDiagnosticsToStringScope out of PybindAdaptors.h and into a new header mlir/include/mlir/Bindings/Python/Diagnostics.h, since it is code that is no way related to pybind11 or for that matter, Python. * changed the standalone Python extension example to have both pybind11 and nanobind variants. * changed mlir/python/mlir/dialects/python_test.py to have both pybind11 and nanobind variants. Notes: * A slightly unfortunate thing that I needed to do in the CMake integration was to use FindPython in addition to FindPython3, since nanobind's CMake integration expects the Python_ names for variables. Perhaps there's a better way to do this. --- .../mlir/Bindings/Python/Diagnostics.h | 59 ++ .../mlir/Bindings/Python/NanobindAdaptors.h | 671 ++++++++++++++++++ .../mlir/Bindings/Python/PybindAdaptors.h | 43 +- mlir/lib/Bindings/Python/DialectLLVM.cpp | 4 +- .../Bindings/Python/TransformInterpreter.cpp | 7 +- mlir/python/CMakeLists.txt | 23 +- mlir/python/mlir/dialects/python_test.py | 17 +- mlir/python/requirements.txt | 1 + 8 files changed, 770 insertions(+), 55 deletions(-) create mode 100644 mlir/include/mlir/Bindings/Python/Diagnostics.h create mode 100644 mlir/include/mlir/Bindings/Python/NanobindAdaptors.h diff --git a/mlir/include/mlir/Bindings/Python/Diagnostics.h b/mlir/include/mlir/Bindings/Python/Diagnostics.h new file mode 100644 index 000000000..ea80e14dd --- /dev/null +++ b/mlir/include/mlir/Bindings/Python/Diagnostics.h @@ -0,0 +1,59 @@ +//===- Diagnostics.h - Helpers for diagnostics in Python bindings ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H +#define MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H + +#include +#include + +#include "mlir-c/Diagnostics.h" +#include "mlir-c/IR.h" +#include "llvm/ADT/StringRef.h" + +namespace mlir { +namespace python { + +/// RAII scope intercepting all diagnostics into a string. The message must be +/// checked before this goes out of scope. +class CollectDiagnosticsToStringScope { +public: + explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) { + handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage, + /*deleteUserData=*/nullptr); + } + ~CollectDiagnosticsToStringScope() { + assert(errorMessage.empty() && "unchecked error message"); + mlirContextDetachDiagnosticHandler(context, handlerID); + } + + [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); } + +private: + static MlirLogicalResult handler(MlirDiagnostic diag, void *data) { + auto printer = +[](MlirStringRef message, void *data) { + *static_cast(data) += + llvm::StringRef(message.data, message.length); + }; + MlirLocation loc = mlirDiagnosticGetLocation(diag); + *static_cast(data) += "at "; + mlirLocationPrint(loc, printer, data); + *static_cast(data) += ": "; + mlirDiagnosticPrint(diag, printer, data); + return mlirLogicalResultSuccess(); + } + + MlirContext context; + MlirDiagnosticHandlerID handlerID; + std::string errorMessage = ""; +}; + +} // namespace python +} // namespace mlir + +#endif // MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h new file mode 100644 index 000000000..5e01cebcb --- /dev/null +++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h @@ -0,0 +1,671 @@ +//===- NanobindAdaptors.h - Interop with MLIR APIs via nanobind -----------===// +// 
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This file contains adaptors for clients of the core MLIR Python APIs to +// interop via MLIR CAPI types, using nanobind. The facilities here do not +// depend on implementation details of the MLIR Python API and do not introduce +// C++-level dependencies with it (requiring only Python and CAPI-level +// dependencies). +// +// It is encouraged to be used both in-tree and out-of-tree. For in-tree use +// cases, it should be used for dialect implementations (versus relying on +// Pybind-based internals of the core libraries). +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H +#define MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H + +#include +#include + +#include + +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/Diagnostics.h" +#include "mlir-c/IR.h" +#include "llvm/ADT/Twine.h" + +// Raw CAPI type casters need to be declared before use, so always include them +// first. +namespace nanobind { +namespace detail { + +/// Helper to convert a presumed MLIR API object to a capsule, accepting either +/// an explicit Capsule (which can happen when two C APIs are communicating +/// directly via Python) or indirectly by querying the MLIR_PYTHON_CAPI_PTR_ATTR +/// attribute (through which supported MLIR Python API objects export their +/// contained API pointer as a capsule). Throws a type error if the object is +/// neither. This is intended to be used from type casters, which are invoked +/// with a raw handle (unowned). The returned object's lifetime may not extend +/// beyond the apiObject handle without explicitly having its refcount increased +/// (i.e. on return). 
+static nanobind::object mlirApiObjectToCapsule(nanobind::handle apiObject) { + if (PyCapsule_CheckExact(apiObject.ptr())) + return nanobind::borrow(apiObject); + if (!nanobind::hasattr(apiObject, MLIR_PYTHON_CAPI_PTR_ATTR)) { + std::string repr = nanobind::cast(nanobind::repr(apiObject)); + throw nanobind::type_error( + (llvm::Twine("Expected an MLIR object (got ") + repr + ").") + .str() + .c_str()); + } + return apiObject.attr(MLIR_PYTHON_CAPI_PTR_ATTR); +} + +// Note: Currently all of the following support cast from nanobind::object to +// the Mlir* C-API type, but only a few light-weight, context-bound ones +// implicitly cast the other way because the use case has not yet emerged and +// ownership is unclear. + +/// Casts object <-> MlirAffineMap. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToAffineMap(capsule.ptr()); + if (mlirAffineMapIsNull(value)) { + return false; + } + return !mlirAffineMapIsNull(value); + } + static handle from_cpp(MlirAffineMap v, rv_policy, + cleanup_list *cleanup) noexcept { + nanobind::object capsule = + nanobind::steal(mlirPythonAffineMapToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("AffineMap") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + +/// Casts object <-> MlirAttribute. 
+template <> +struct type_caster { + NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToAttribute(capsule.ptr()); + return !mlirAttributeIsNull(value); + } + static handle from_cpp(MlirAttribute v, rv_policy, + cleanup_list *cleanup) noexcept { + nanobind::object capsule = + nanobind::steal(mlirPythonAttributeToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Attribute") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() + .release(); + } +}; + +/// Casts object -> MlirBlock. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToBlock(capsule.ptr()); + return !mlirBlockIsNull(value); + } +}; + +/// Casts object -> MlirContext. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirContext, const_name("MlirContext")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + if (src.is_none()) { + // Gets the current thread-bound context. + // TODO: This raises an error of "No current context" currently. + // Update the implementation to pretty-print the helpful error that the + // core implementations print in this case. + src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Context") + .attr("current"); + } + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToContext(capsule.ptr()); + return !mlirContextIsNull(value); + } +}; + +/// Casts object <-> MlirDialectRegistry. 
+template <> +struct type_caster { + NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToDialectRegistry(capsule.ptr()); + return !mlirDialectRegistryIsNull(value); + } + static handle from_cpp(MlirDialectRegistry v, rv_policy, + cleanup_list *cleanup) noexcept { + nanobind::object capsule = nanobind::steal( + mlirPythonDialectRegistryToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("DialectRegistry") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + +/// Casts object <-> MlirLocation. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + if (src.is_none()) { + // Gets the current thread-bound context. + src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Location") + .attr("current"); + } + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToLocation(capsule.ptr()); + return !mlirLocationIsNull(value); + } + static handle from_cpp(MlirLocation v, rv_policy, + cleanup_list *cleanup) noexcept { + nanobind::object capsule = + nanobind::steal(mlirPythonLocationToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Location") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + +/// Casts object <-> MlirModule. 
+template <> +struct type_caster { + NB_TYPE_CASTER(MlirModule, const_name("MlirModule")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToModule(capsule.ptr()); + return !mlirModuleIsNull(value); + } + static handle from_cpp(MlirModule v, rv_policy, + cleanup_list *cleanup) noexcept { + nanobind::object capsule = + nanobind::steal(mlirPythonModuleToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Module") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object <-> MlirFrozenRewritePatternSet. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirFrozenRewritePatternSet, + const_name("MlirFrozenRewritePatternSet")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); + return value.ptr != nullptr; + } + static handle from_cpp(MlirFrozenRewritePatternSet v, rv_policy, handle) { + nanobind::object capsule = nanobind::steal( + mlirPythonFrozenRewritePatternSetToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("rewrite")) + .attr("FrozenRewritePatternSet") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object <-> MlirOperation. 
+template <> +struct type_caster { + NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToOperation(capsule.ptr()); + return !mlirOperationIsNull(value); + } + static handle from_cpp(MlirOperation v, rv_policy, + cleanup_list *cleanup) noexcept { + if (v.ptr == nullptr) + return nanobind::none(); + nanobind::object capsule = + nanobind::steal(mlirPythonOperationToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Operation") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object <-> MlirValue. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirValue, const_name("MlirValue")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToValue(capsule.ptr()); + return !mlirValueIsNull(value); + } + static handle from_cpp(MlirValue v, rv_policy, + cleanup_list *cleanup) noexcept { + if (v.ptr == nullptr) + return nanobind::none(); + nanobind::object capsule = + nanobind::steal(mlirPythonValueToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Value") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() + .release(); + }; +}; + +/// Casts object -> MlirPassManager. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToPassManager(capsule.ptr()); + return !mlirPassManagerIsNull(value); + } +}; + +/// Casts object <-> MlirTypeID. 
+template <> +struct type_caster { + NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToTypeID(capsule.ptr()); + return !mlirTypeIDIsNull(value); + } + static handle from_cpp(MlirTypeID v, rv_policy, + cleanup_list *cleanup) noexcept { + if (v.ptr == nullptr) + return nanobind::none(); + nanobind::object capsule = + nanobind::steal(mlirPythonTypeIDToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("TypeID") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object <-> MlirType. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirType, const_name("MlirType")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToType(capsule.ptr()); + return !mlirTypeIsNull(value); + } + static handle from_cpp(MlirType t, rv_policy, + cleanup_list *cleanup) noexcept { + nanobind::object capsule = + nanobind::steal(mlirPythonTypeToCapsule(t)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Type") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() + .release(); + } +}; + +} // namespace detail +} // namespace nanobind + +namespace mlir { +namespace python { +namespace nanobind_adaptors { + +/// Provides a facility like nanobind::class_ for defining a new class in a +/// scope, but this allows extension of an arbitrary Python class, defining +/// methods on it is a similar way. Classes defined in this way are very similar +/// to if defined in Python in the usual way but use nanobind machinery to +/// do it. These are not "real" nanobind classes but pure Python classes +/// with no relation to a concrete C++ class. 
+/// +/// Derived from a discussion upstream: +/// https://github.com/pybind/pybind11/issues/1193 +/// (plus a fair amount of extra curricular poking) +/// TODO: If this proves useful, see about including it in nanobind. +class pure_subclass { +public: + pure_subclass(nanobind::handle scope, const char *derivedClassName, + const nanobind::object &superClass) { + nanobind::object pyType = + nanobind::borrow((PyObject *)&PyType_Type); + nanobind::object metaclass = pyType(superClass); + nanobind::dict attributes; + + thisClass = metaclass(derivedClassName, nanobind::make_tuple(superClass), + attributes); + scope.attr(derivedClassName) = thisClass; + } + + template + pure_subclass &def(const char *name, Func &&f, const Extra &...extra) { + nanobind::object cf = nanobind::cpp_function( + std::forward(f), nanobind::name(name), nanobind::is_method(), + nanobind::scope(thisClass), extra...); + thisClass.attr(name) = cf; + return *this; + } + + template + pure_subclass &def_property_readonly(const char *name, Func &&f, + const Extra &...extra) { + nanobind::object cf = nanobind::cpp_function( + std::forward(f), nanobind::name(name), nanobind::is_method(), + nanobind::scope(thisClass), extra...); + auto builtinProperty = + nanobind::borrow((PyObject *)&PyProperty_Type); + thisClass.attr(name) = builtinProperty(cf); + return *this; + } + + template + pure_subclass &def_staticmethod(const char *name, Func &&f, + const Extra &...extra) { + static_assert(!std::is_member_function_pointer::value, + "def_staticmethod(...) called with a non-static member " + "function pointer"); + nanobind::object cf = nanobind::cpp_function( + std::forward(f), + nanobind::name(name), // nanobind::scope(thisClass), + extra...); + thisClass.attr(name) = cf; + return *this; + } + + template + pure_subclass &def_classmethod(const char *name, Func &&f, + const Extra &...extra) { + static_assert(!std::is_member_function_pointer::value, + "def_classmethod(...) 
called with a non-static member " + "function pointer"); + nanobind::object cf = nanobind::cpp_function( + std::forward(f), + nanobind::name(name), // nanobind::scope(thisClass), + extra...); + thisClass.attr(name) = + nanobind::borrow(PyClassMethod_New(cf.ptr())); + return *this; + } + + nanobind::object get_class() const { return thisClass; } + +protected: + nanobind::object superClass; + nanobind::object thisClass; +}; + +/// Creates a custom subclass of mlir.ir.Attribute, implementing a casting +/// constructor and type checking methods. +class mlir_attribute_subclass : public pure_subclass { +public: + using IsAFunctionTy = bool (*)(MlirAttribute); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + + /// Subclasses by looking up the super-class dynamically. + mlir_attribute_subclass(nanobind::handle scope, const char *attrClassName, + IsAFunctionTy isaFunction, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) + : mlir_attribute_subclass( + scope, attrClassName, isaFunction, + nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Attribute"), + getTypeIDFunction) {} + + /// Subclasses with a provided mlir.ir.Attribute super-class. This must + /// be used if the subclass is being defined in the same extension module + /// as the mlir.ir class (otherwise, it will trigger a recursive + /// initialization). + mlir_attribute_subclass(nanobind::handle scope, const char *typeClassName, + IsAFunctionTy isaFunction, + const nanobind::object &superCls, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) + : pure_subclass(scope, typeClassName, superCls) { + // Casting constructor. Note that it hard, if not impossible, to properly + // call chain to parent `__init__` in nanobind due to its special handling + // for init functions that don't have a fully constructed self-reference, + // which makes it impossible to forward it to `__init__` of a superclass. 
+ // Instead, provide a custom `__new__` and call that of a superclass, which + // eventually calls `__init__` of the superclass. Since attribute subclasses + // have no additional members, we can just return the instance thus created + // without amending it. + std::string captureTypeName( + typeClassName); // As string in case if typeClassName is not static. + nanobind::object newCf = nanobind::cpp_function( + [superCls, isaFunction, captureTypeName]( + nanobind::object cls, nanobind::object otherAttribute) { + MlirAttribute rawAttribute = + nanobind::cast(otherAttribute); + if (!isaFunction(rawAttribute)) { + auto origRepr = + nanobind::cast(nanobind::repr(otherAttribute)); + throw std::invalid_argument( + (llvm::Twine("Cannot cast attribute to ") + captureTypeName + + " (from " + origRepr + ")") + .str()); + } + nanobind::object self = superCls.attr("__new__")(cls, otherAttribute); + return self; + }, + nanobind::name("__new__"), nanobind::arg("cls"), + nanobind::arg("cast_from_attr")); + thisClass.attr("__new__") = newCf; + + // 'isinstance' method. + def_staticmethod( + "isinstance", + [isaFunction](MlirAttribute other) { return isaFunction(other); }, + nanobind::arg("other_attribute")); + def("__repr__", [superCls, captureTypeName](nanobind::object self) { + return nanobind::repr(superCls(self)) + .attr("replace")(superCls.attr("__name__"), captureTypeName); + }); + if (getTypeIDFunction) { + def_staticmethod("get_static_typeid", + [getTypeIDFunction]() { return getTypeIDFunction(); }); + nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( + getTypeIDFunction())(nanobind::cpp_function( + [thisClass = thisClass](const nanobind::object &mlirAttribute) { + return thisClass(mlirAttribute); + })); + } + } +}; + +/// Creates a custom subclass of mlir.ir.Type, implementing a casting +/// constructor and type checking methods. 
+class mlir_type_subclass : public pure_subclass { +public: + using IsAFunctionTy = bool (*)(MlirType); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + + /// Subclasses by looking up the super-class dynamically. + mlir_type_subclass(nanobind::handle scope, const char *typeClassName, + IsAFunctionTy isaFunction, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) + : mlir_type_subclass( + scope, typeClassName, isaFunction, + nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Type"), + getTypeIDFunction) {} + + /// Subclasses with a provided mlir.ir.Type super-class. This must + /// be used if the subclass is being defined in the same extension module + /// as the mlir.ir class (otherwise, it will trigger a recursive + /// initialization). + mlir_type_subclass(nanobind::handle scope, const char *typeClassName, + IsAFunctionTy isaFunction, + const nanobind::object &superCls, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) + : pure_subclass(scope, typeClassName, superCls) { + // Casting constructor. Note that it hard, if not impossible, to properly + // call chain to parent `__init__` in nanobind due to its special handling + // for init functions that don't have a fully constructed self-reference, + // which makes it impossible to forward it to `__init__` of a superclass. + // Instead, provide a custom `__new__` and call that of a superclass, which + // eventually calls `__init__` of the superclass. Since attribute subclasses + // have no additional members, we can just return the instance thus created + // without amending it. + std::string captureTypeName( + typeClassName); // As string in case if typeClassName is not static. 
+ nanobind::object newCf = nanobind::cpp_function( + [superCls, isaFunction, captureTypeName](nanobind::object cls, + nanobind::object otherType) { + MlirType rawType = nanobind::cast(otherType); + if (!isaFunction(rawType)) { + auto origRepr = + nanobind::cast(nanobind::repr(otherType)); + throw std::invalid_argument((llvm::Twine("Cannot cast type to ") + + captureTypeName + " (from " + + origRepr + ")") + .str()); + } + nanobind::object self = superCls.attr("__new__")(cls, otherType); + return self; + }, + nanobind::name("__new__"), nanobind::arg("cls"), + nanobind::arg("cast_from_type")); + thisClass.attr("__new__") = newCf; + + // 'isinstance' method. + def_staticmethod( + "isinstance", + [isaFunction](MlirType other) { return isaFunction(other); }, + nanobind::arg("other_type")); + def("__repr__", [superCls, captureTypeName](nanobind::object self) { + return nanobind::repr(superCls(self)) + .attr("replace")(superCls.attr("__name__"), captureTypeName); + }); + if (getTypeIDFunction) { + // 'get_static_typeid' method. + // This is modeled as a static method instead of a static property because + // `def_property_readonly_static` is not available in `pure_subclass` and + // we do not want to introduce the complexity that pybind uses to + // implement it. + def_staticmethod("get_static_typeid", + [getTypeIDFunction]() { return getTypeIDFunction(); }); + nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( + getTypeIDFunction())(nanobind::cpp_function( + [thisClass = thisClass](const nanobind::object &mlirType) { + return thisClass(mlirType); + })); + } + } +}; + +/// Creates a custom subclass of mlir.ir.Value, implementing a casting +/// constructor and type checking methods. +class mlir_value_subclass : public pure_subclass { +public: + using IsAFunctionTy = bool (*)(MlirValue); + + /// Subclasses by looking up the super-class dynamically. 
+ mlir_value_subclass(nanobind::handle scope, const char *valueClassName, + IsAFunctionTy isaFunction) + : mlir_value_subclass( + scope, valueClassName, isaFunction, + nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Value")) {} + + /// Subclasses with a provided mlir.ir.Value super-class. This must + /// be used if the subclass is being defined in the same extension module + /// as the mlir.ir class (otherwise, it will trigger a recursive + /// initialization). + mlir_value_subclass(nanobind::handle scope, const char *valueClassName, + IsAFunctionTy isaFunction, + const nanobind::object &superCls) + : pure_subclass(scope, valueClassName, superCls) { + // Casting constructor. Note that it hard, if not impossible, to properly + // call chain to parent `__init__` in nanobind due to its special handling + // for init functions that don't have a fully constructed self-reference, + // which makes it impossible to forward it to `__init__` of a superclass. + // Instead, provide a custom `__new__` and call that of a superclass, which + // eventually calls `__init__` of the superclass. Since attribute subclasses + // have no additional members, we can just return the instance thus created + // without amending it. + std::string captureValueName( + valueClassName); // As string in case if valueClassName is not static. 
+ nanobind::object newCf = nanobind::cpp_function( + [superCls, isaFunction, captureValueName](nanobind::object cls, + nanobind::object otherValue) { + MlirValue rawValue = nanobind::cast(otherValue); + if (!isaFunction(rawValue)) { + auto origRepr = + nanobind::cast(nanobind::repr(otherValue)); + throw std::invalid_argument((llvm::Twine("Cannot cast value to ") + + captureValueName + " (from " + + origRepr + ")") + .str()); + } + nanobind::object self = superCls.attr("__new__")(cls, otherValue); + return self; + }, + nanobind::name("__new__"), nanobind::arg("cls"), + nanobind::arg("cast_from_value")); + thisClass.attr("__new__") = newCf; + + // 'isinstance' method. + def_staticmethod( + "isinstance", + [isaFunction](MlirValue other) { return isaFunction(other); }, + nanobind::arg("other_value")); + } +}; + +} // namespace nanobind_adaptors + +/// RAII scope intercepting all diagnostics into a string. The message must be +/// checked before this goes out of scope. +class CollectDiagnosticsToStringScope { +public: + explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) { + handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage, + /*deleteUserData=*/nullptr); + } + ~CollectDiagnosticsToStringScope() { + assert(errorMessage.empty() && "unchecked error message"); + mlirContextDetachDiagnosticHandler(context, handlerID); + } + + [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); } + +private: + static MlirLogicalResult handler(MlirDiagnostic diag, void *data) { + auto printer = +[](MlirStringRef message, void *data) { + *static_cast(data) += + llvm::StringRef(message.data, message.length); + }; + MlirLocation loc = mlirDiagnosticGetLocation(diag); + *static_cast(data) += "at "; + mlirLocationPrint(loc, printer, data); + *static_cast(data) += ": "; + mlirDiagnosticPrint(diag, printer, data); + return mlirLogicalResultSuccess(); + } + + MlirContext context; + MlirDiagnosticHandlerID handlerID; + std::string 
errorMessage = ""; +}; + +} // namespace python +} // namespace mlir + +#endif // MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index df4b9bf71..c8233355d 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -1,4 +1,4 @@ -//===- PybindAdaptors.h - Adaptors for interop with MLIR APIs -------------===// +//===- PybindAdaptors.h - Interop with MLIR APIs via pybind11 -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,9 +6,10 @@ // //===----------------------------------------------------------------------===// // This file contains adaptors for clients of the core MLIR Python APIs to -// interop via MLIR CAPI types. The facilities here do not depend on -// implementation details of the MLIR Python API and do not introduce C++-level -// dependencies with it (requiring only Python and CAPI-level dependencies). +// interop via MLIR CAPI types, using pybind11. The facilities here do not +// depend on implementation details of the MLIR Python API and do not introduce +// C++-level dependencies with it (requiring only Python and CAPI-level +// dependencies). // // It is encouraged to be used both in-tree and out-of-tree. For in-tree use // cases, it should be used for dialect implementations (versus relying on @@ -611,40 +612,6 @@ class mlir_value_subclass : public pure_subclass { } // namespace adaptors -/// RAII scope intercepting all diagnostics into a string. The message must be -/// checked before this goes out of scope. 
-class CollectDiagnosticsToStringScope { -public: - explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) { - handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage, - /*deleteUserData=*/nullptr); - } - ~CollectDiagnosticsToStringScope() { - assert(errorMessage.empty() && "unchecked error message"); - mlirContextDetachDiagnosticHandler(context, handlerID); - } - - [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); } - -private: - static MlirLogicalResult handler(MlirDiagnostic diag, void *data) { - auto printer = +[](MlirStringRef message, void *data) { - *static_cast(data) += - llvm::StringRef(message.data, message.length); - }; - MlirLocation loc = mlirDiagnosticGetLocation(diag); - *static_cast(data) += "at "; - mlirLocationPrint(loc, printer, data); - *static_cast(data) += ": "; - mlirDiagnosticPrint(diag, printer, data); - return mlirLogicalResultSuccess(); - } - - MlirContext context; - MlirDiagnosticHandlerID handlerID; - std::string errorMessage = ""; -}; - } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp index 42a4c8c07..cccf1370b 100644 --- a/mlir/lib/Bindings/Python/DialectLLVM.cpp +++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp @@ -6,11 +6,13 @@ // //===----------------------------------------------------------------------===// +#include + #include "mlir-c/Dialect/LLVM.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Diagnostics.h" #include "mlir/Bindings/Python/PybindAdaptors.h" -#include namespace py = pybind11; using namespace llvm; diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp index f6b4532b1..0c8c0e0a9 100644 --- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp +++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp @@ -10,14 +10,15 @@ // 
//===----------------------------------------------------------------------===// +#include +#include + #include "mlir-c/Dialect/Transform/Interpreter.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Diagnostics.h" #include "mlir/Bindings/Python/PybindAdaptors.h" -#include -#include - namespace py = pybind11; namespace { diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 23187f256..e1b870b53 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -683,7 +683,9 @@ if(MLIR_INCLUDE_TESTS) MLIRPythonTestSources.Dialects.PythonTest ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" ADD_TO_PARENT MLIRPythonTestSources.Dialects - SOURCES dialects/python_test.py) + SOURCES + dialects/python_test.py + ) set(LLVM_TARGET_DEFINITIONS "${MLIR_MAIN_SRC_DIR}/test/python/python_test_ops.td") mlir_tablegen( @@ -697,12 +699,25 @@ if(MLIR_INCLUDE_TESTS) ADD_TO_PARENT MLIRPythonTestSources.Dialects.PythonTest SOURCES "dialects/_python_test_ops_gen.py") - declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtension - MODULE_NAME _mlirPythonTest + declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtensionPybind11 + MODULE_NAME _mlirPythonTestPybind11 + ADD_TO_PARENT MLIRPythonTestSources.Dialects + ROOT_DIR "${MLIR_SOURCE_DIR}/test/python/lib" + PYTHON_BINDINGS_LIBRARY pybind11 + SOURCES + PythonTestModulePybind11.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIPythonTestDialect + ) + declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtensionNanobind + MODULE_NAME _mlirPythonTestNanobind ADD_TO_PARENT MLIRPythonTestSources.Dialects ROOT_DIR "${MLIR_SOURCE_DIR}/test/python/lib" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES - PythonTestModule.cpp + PythonTestModuleNanobind.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py index b5baa80bc..9380896c8 100644 --- 
a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -3,15 +3,14 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._python_test_ops_gen import * -from .._mlir_libs._mlirPythonTest import ( - TestAttr, - TestType, - TestTensorValue, - TestIntegerRankedTensorType, -) -def register_python_test_dialect(registry): - from .._mlir_libs import _mlirPythonTest +def register_python_test_dialect(registry, use_nanobind): + if use_nanobind: + from .._mlir_libs import _mlirPythonTestNanobind - _mlirPythonTest.register_dialect(registry) + _mlirPythonTestNanobind.register_dialect(registry) + else: + from .._mlir_libs import _mlirPythonTestPybind11 + + _mlirPythonTestPybind11.register_dialect(registry) diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index 272d06683..ab8a91229 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,3 +1,4 @@ +nanobind>=2.0, <3.0 numpy>=1.19.5, <=2.1.2 pybind11>=2.10.0, <=2.13.6 PyYAML>=5.4.0, <=6.0.1 From 5f5b11f1076263a20341fac092dcc961ee846ebb Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Tue, 3 Dec 2024 11:26:33 -0600 Subject: [PATCH 803/915] Revert "[mlir python] Add nanobind support for standalone dialects." (#118517) Reverts llvm/llvm-project#117922 because deps aren't met on some of the post-commit build bots. 
--- .../mlir/Bindings/Python/Diagnostics.h | 59 -- .../mlir/Bindings/Python/NanobindAdaptors.h | 671 ------------------ .../mlir/Bindings/Python/PybindAdaptors.h | 43 +- mlir/lib/Bindings/Python/DialectLLVM.cpp | 4 +- .../Bindings/Python/TransformInterpreter.cpp | 7 +- mlir/python/CMakeLists.txt | 23 +- mlir/python/mlir/dialects/python_test.py | 17 +- mlir/python/requirements.txt | 1 - 8 files changed, 55 insertions(+), 770 deletions(-) delete mode 100644 mlir/include/mlir/Bindings/Python/Diagnostics.h delete mode 100644 mlir/include/mlir/Bindings/Python/NanobindAdaptors.h diff --git a/mlir/include/mlir/Bindings/Python/Diagnostics.h b/mlir/include/mlir/Bindings/Python/Diagnostics.h deleted file mode 100644 index ea80e14dd..000000000 --- a/mlir/include/mlir/Bindings/Python/Diagnostics.h +++ /dev/null @@ -1,59 +0,0 @@ -//===- Diagnostics.h - Helpers for diagnostics in Python bindings ---------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H -#define MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H - -#include -#include - -#include "mlir-c/Diagnostics.h" -#include "mlir-c/IR.h" -#include "llvm/ADT/StringRef.h" - -namespace mlir { -namespace python { - -/// RAII scope intercepting all diagnostics into a string. The message must be -/// checked before this goes out of scope. 
-class CollectDiagnosticsToStringScope { -public: - explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) { - handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage, - /*deleteUserData=*/nullptr); - } - ~CollectDiagnosticsToStringScope() { - assert(errorMessage.empty() && "unchecked error message"); - mlirContextDetachDiagnosticHandler(context, handlerID); - } - - [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); } - -private: - static MlirLogicalResult handler(MlirDiagnostic diag, void *data) { - auto printer = +[](MlirStringRef message, void *data) { - *static_cast(data) += - llvm::StringRef(message.data, message.length); - }; - MlirLocation loc = mlirDiagnosticGetLocation(diag); - *static_cast(data) += "at "; - mlirLocationPrint(loc, printer, data); - *static_cast(data) += ": "; - mlirDiagnosticPrint(diag, printer, data); - return mlirLogicalResultSuccess(); - } - - MlirContext context; - MlirDiagnosticHandlerID handlerID; - std::string errorMessage = ""; -}; - -} // namespace python -} // namespace mlir - -#endif // MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h deleted file mode 100644 index 5e01cebcb..000000000 --- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h +++ /dev/null @@ -1,671 +0,0 @@ -//===- NanobindAdaptors.h - Interop with MLIR APIs via nanobind -----------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// This file contains adaptors for clients of the core MLIR Python APIs to -// interop via MLIR CAPI types, using nanobind. 
The facilities here do not -// depend on implementation details of the MLIR Python API and do not introduce -// C++-level dependencies with it (requiring only Python and CAPI-level -// dependencies). -// -// It is encouraged to be used both in-tree and out-of-tree. For in-tree use -// cases, it should be used for dialect implementations (versus relying on -// Pybind-based internals of the core libraries). -//===----------------------------------------------------------------------===// - -#ifndef MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H -#define MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H - -#include -#include - -#include - -#include "mlir-c/Bindings/Python/Interop.h" -#include "mlir-c/Diagnostics.h" -#include "mlir-c/IR.h" -#include "llvm/ADT/Twine.h" - -// Raw CAPI type casters need to be declared before use, so always include them -// first. -namespace nanobind { -namespace detail { - -/// Helper to convert a presumed MLIR API object to a capsule, accepting either -/// an explicit Capsule (which can happen when two C APIs are communicating -/// directly via Python) or indirectly by querying the MLIR_PYTHON_CAPI_PTR_ATTR -/// attribute (through which supported MLIR Python API objects export their -/// contained API pointer as a capsule). Throws a type error if the object is -/// neither. This is intended to be used from type casters, which are invoked -/// with a raw handle (unowned). The returned object's lifetime may not extend -/// beyond the apiObject handle without explicitly having its refcount increased -/// (i.e. on return). 
-static nanobind::object mlirApiObjectToCapsule(nanobind::handle apiObject) { - if (PyCapsule_CheckExact(apiObject.ptr())) - return nanobind::borrow(apiObject); - if (!nanobind::hasattr(apiObject, MLIR_PYTHON_CAPI_PTR_ATTR)) { - std::string repr = nanobind::cast(nanobind::repr(apiObject)); - throw nanobind::type_error( - (llvm::Twine("Expected an MLIR object (got ") + repr + ").") - .str() - .c_str()); - } - return apiObject.attr(MLIR_PYTHON_CAPI_PTR_ATTR); -} - -// Note: Currently all of the following support cast from nanobind::object to -// the Mlir* C-API type, but only a few light-weight, context-bound ones -// implicitly cast the other way because the use case has not yet emerged and -// ownership is unclear. - -/// Casts object <-> MlirAffineMap. -template <> -struct type_caster { - NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap")); - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToAffineMap(capsule.ptr()); - if (mlirAffineMapIsNull(value)) { - return false; - } - return !mlirAffineMapIsNull(value); - } - static handle from_cpp(MlirAffineMap v, rv_policy, - cleanup_list *cleanup) noexcept { - nanobind::object capsule = - nanobind::steal(mlirPythonAffineMapToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("AffineMap") - .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) - .release(); - } -}; - -/// Casts object <-> MlirAttribute. 
-template <> -struct type_caster { - NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute")); - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToAttribute(capsule.ptr()); - return !mlirAttributeIsNull(value); - } - static handle from_cpp(MlirAttribute v, rv_policy, - cleanup_list *cleanup) noexcept { - nanobind::object capsule = - nanobind::steal(mlirPythonAttributeToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Attribute") - .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) - .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() - .release(); - } -}; - -/// Casts object -> MlirBlock. -template <> -struct type_caster { - NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock")); - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToBlock(capsule.ptr()); - return !mlirBlockIsNull(value); - } -}; - -/// Casts object -> MlirContext. -template <> -struct type_caster { - NB_TYPE_CASTER(MlirContext, const_name("MlirContext")); - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - if (src.is_none()) { - // Gets the current thread-bound context. - // TODO: This raises an error of "No current context" currently. - // Update the implementation to pretty-print the helpful error that the - // core implementations print in this case. - src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Context") - .attr("current"); - } - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToContext(capsule.ptr()); - return !mlirContextIsNull(value); - } -}; - -/// Casts object <-> MlirDialectRegistry. 
-template <> -struct type_caster { - NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry")); - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToDialectRegistry(capsule.ptr()); - return !mlirDialectRegistryIsNull(value); - } - static handle from_cpp(MlirDialectRegistry v, rv_policy, - cleanup_list *cleanup) noexcept { - nanobind::object capsule = nanobind::steal( - mlirPythonDialectRegistryToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("DialectRegistry") - .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) - .release(); - } -}; - -/// Casts object <-> MlirLocation. -template <> -struct type_caster { - NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation")); - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - if (src.is_none()) { - // Gets the current thread-bound context. - src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Location") - .attr("current"); - } - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToLocation(capsule.ptr()); - return !mlirLocationIsNull(value); - } - static handle from_cpp(MlirLocation v, rv_policy, - cleanup_list *cleanup) noexcept { - nanobind::object capsule = - nanobind::steal(mlirPythonLocationToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Location") - .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) - .release(); - } -}; - -/// Casts object <-> MlirModule. 
-template <> -struct type_caster { - NB_TYPE_CASTER(MlirModule, const_name("MlirModule")); - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToModule(capsule.ptr()); - return !mlirModuleIsNull(value); - } - static handle from_cpp(MlirModule v, rv_policy, - cleanup_list *cleanup) noexcept { - nanobind::object capsule = - nanobind::steal(mlirPythonModuleToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Module") - .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) - .release(); - }; -}; - -/// Casts object <-> MlirFrozenRewritePatternSet. -template <> -struct type_caster { - NB_TYPE_CASTER(MlirFrozenRewritePatternSet, - const_name("MlirFrozenRewritePatternSet")); - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); - return value.ptr != nullptr; - } - static handle from_cpp(MlirFrozenRewritePatternSet v, rv_policy, handle) { - nanobind::object capsule = nanobind::steal( - mlirPythonFrozenRewritePatternSetToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("rewrite")) - .attr("FrozenRewritePatternSet") - .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) - .release(); - }; -}; - -/// Casts object <-> MlirOperation. 
-template <> -struct type_caster { - NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation")); - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToOperation(capsule.ptr()); - return !mlirOperationIsNull(value); - } - static handle from_cpp(MlirOperation v, rv_policy, - cleanup_list *cleanup) noexcept { - if (v.ptr == nullptr) - return nanobind::none(); - nanobind::object capsule = - nanobind::steal(mlirPythonOperationToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Operation") - .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) - .release(); - }; -}; - -/// Casts object <-> MlirValue. -template <> -struct type_caster { - NB_TYPE_CASTER(MlirValue, const_name("MlirValue")); - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToValue(capsule.ptr()); - return !mlirValueIsNull(value); - } - static handle from_cpp(MlirValue v, rv_policy, - cleanup_list *cleanup) noexcept { - if (v.ptr == nullptr) - return nanobind::none(); - nanobind::object capsule = - nanobind::steal(mlirPythonValueToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Value") - .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) - .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() - .release(); - }; -}; - -/// Casts object -> MlirPassManager. -template <> -struct type_caster { - NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager")); - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToPassManager(capsule.ptr()); - return !mlirPassManagerIsNull(value); - } -}; - -/// Casts object <-> MlirTypeID. 
-template <> -struct type_caster { - NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID")); - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToTypeID(capsule.ptr()); - return !mlirTypeIDIsNull(value); - } - static handle from_cpp(MlirTypeID v, rv_policy, - cleanup_list *cleanup) noexcept { - if (v.ptr == nullptr) - return nanobind::none(); - nanobind::object capsule = - nanobind::steal(mlirPythonTypeIDToCapsule(v)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("TypeID") - .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) - .release(); - }; -}; - -/// Casts object <-> MlirType. -template <> -struct type_caster { - NB_TYPE_CASTER(MlirType, const_name("MlirType")); - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToType(capsule.ptr()); - return !mlirTypeIsNull(value); - } - static handle from_cpp(MlirType t, rv_policy, - cleanup_list *cleanup) noexcept { - nanobind::object capsule = - nanobind::steal(mlirPythonTypeToCapsule(t)); - return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Type") - .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) - .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() - .release(); - } -}; - -} // namespace detail -} // namespace nanobind - -namespace mlir { -namespace python { -namespace nanobind_adaptors { - -/// Provides a facility like nanobind::class_ for defining a new class in a -/// scope, but this allows extension of an arbitrary Python class, defining -/// methods on it is a similar way. Classes defined in this way are very similar -/// to if defined in Python in the usual way but use nanobind machinery to -/// do it. These are not "real" nanobind classes but pure Python classes -/// with no relation to a concrete C++ class. 
-/// -/// Derived from a discussion upstream: -/// https://github.com/pybind/pybind11/issues/1193 -/// (plus a fair amount of extra curricular poking) -/// TODO: If this proves useful, see about including it in nanobind. -class pure_subclass { -public: - pure_subclass(nanobind::handle scope, const char *derivedClassName, - const nanobind::object &superClass) { - nanobind::object pyType = - nanobind::borrow((PyObject *)&PyType_Type); - nanobind::object metaclass = pyType(superClass); - nanobind::dict attributes; - - thisClass = metaclass(derivedClassName, nanobind::make_tuple(superClass), - attributes); - scope.attr(derivedClassName) = thisClass; - } - - template - pure_subclass &def(const char *name, Func &&f, const Extra &...extra) { - nanobind::object cf = nanobind::cpp_function( - std::forward(f), nanobind::name(name), nanobind::is_method(), - nanobind::scope(thisClass), extra...); - thisClass.attr(name) = cf; - return *this; - } - - template - pure_subclass &def_property_readonly(const char *name, Func &&f, - const Extra &...extra) { - nanobind::object cf = nanobind::cpp_function( - std::forward(f), nanobind::name(name), nanobind::is_method(), - nanobind::scope(thisClass), extra...); - auto builtinProperty = - nanobind::borrow((PyObject *)&PyProperty_Type); - thisClass.attr(name) = builtinProperty(cf); - return *this; - } - - template - pure_subclass &def_staticmethod(const char *name, Func &&f, - const Extra &...extra) { - static_assert(!std::is_member_function_pointer::value, - "def_staticmethod(...) called with a non-static member " - "function pointer"); - nanobind::object cf = nanobind::cpp_function( - std::forward(f), - nanobind::name(name), // nanobind::scope(thisClass), - extra...); - thisClass.attr(name) = cf; - return *this; - } - - template - pure_subclass &def_classmethod(const char *name, Func &&f, - const Extra &...extra) { - static_assert(!std::is_member_function_pointer::value, - "def_classmethod(...) 
called with a non-static member " - "function pointer"); - nanobind::object cf = nanobind::cpp_function( - std::forward(f), - nanobind::name(name), // nanobind::scope(thisClass), - extra...); - thisClass.attr(name) = - nanobind::borrow(PyClassMethod_New(cf.ptr())); - return *this; - } - - nanobind::object get_class() const { return thisClass; } - -protected: - nanobind::object superClass; - nanobind::object thisClass; -}; - -/// Creates a custom subclass of mlir.ir.Attribute, implementing a casting -/// constructor and type checking methods. -class mlir_attribute_subclass : public pure_subclass { -public: - using IsAFunctionTy = bool (*)(MlirAttribute); - using GetTypeIDFunctionTy = MlirTypeID (*)(); - - /// Subclasses by looking up the super-class dynamically. - mlir_attribute_subclass(nanobind::handle scope, const char *attrClassName, - IsAFunctionTy isaFunction, - GetTypeIDFunctionTy getTypeIDFunction = nullptr) - : mlir_attribute_subclass( - scope, attrClassName, isaFunction, - nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Attribute"), - getTypeIDFunction) {} - - /// Subclasses with a provided mlir.ir.Attribute super-class. This must - /// be used if the subclass is being defined in the same extension module - /// as the mlir.ir class (otherwise, it will trigger a recursive - /// initialization). - mlir_attribute_subclass(nanobind::handle scope, const char *typeClassName, - IsAFunctionTy isaFunction, - const nanobind::object &superCls, - GetTypeIDFunctionTy getTypeIDFunction = nullptr) - : pure_subclass(scope, typeClassName, superCls) { - // Casting constructor. Note that it hard, if not impossible, to properly - // call chain to parent `__init__` in nanobind due to its special handling - // for init functions that don't have a fully constructed self-reference, - // which makes it impossible to forward it to `__init__` of a superclass. 
- // Instead, provide a custom `__new__` and call that of a superclass, which - // eventually calls `__init__` of the superclass. Since attribute subclasses - // have no additional members, we can just return the instance thus created - // without amending it. - std::string captureTypeName( - typeClassName); // As string in case if typeClassName is not static. - nanobind::object newCf = nanobind::cpp_function( - [superCls, isaFunction, captureTypeName]( - nanobind::object cls, nanobind::object otherAttribute) { - MlirAttribute rawAttribute = - nanobind::cast(otherAttribute); - if (!isaFunction(rawAttribute)) { - auto origRepr = - nanobind::cast(nanobind::repr(otherAttribute)); - throw std::invalid_argument( - (llvm::Twine("Cannot cast attribute to ") + captureTypeName + - " (from " + origRepr + ")") - .str()); - } - nanobind::object self = superCls.attr("__new__")(cls, otherAttribute); - return self; - }, - nanobind::name("__new__"), nanobind::arg("cls"), - nanobind::arg("cast_from_attr")); - thisClass.attr("__new__") = newCf; - - // 'isinstance' method. - def_staticmethod( - "isinstance", - [isaFunction](MlirAttribute other) { return isaFunction(other); }, - nanobind::arg("other_attribute")); - def("__repr__", [superCls, captureTypeName](nanobind::object self) { - return nanobind::repr(superCls(self)) - .attr("replace")(superCls.attr("__name__"), captureTypeName); - }); - if (getTypeIDFunction) { - def_staticmethod("get_static_typeid", - [getTypeIDFunction]() { return getTypeIDFunction(); }); - nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( - getTypeIDFunction())(nanobind::cpp_function( - [thisClass = thisClass](const nanobind::object &mlirAttribute) { - return thisClass(mlirAttribute); - })); - } - } -}; - -/// Creates a custom subclass of mlir.ir.Type, implementing a casting -/// constructor and type checking methods. 
-class mlir_type_subclass : public pure_subclass { -public: - using IsAFunctionTy = bool (*)(MlirType); - using GetTypeIDFunctionTy = MlirTypeID (*)(); - - /// Subclasses by looking up the super-class dynamically. - mlir_type_subclass(nanobind::handle scope, const char *typeClassName, - IsAFunctionTy isaFunction, - GetTypeIDFunctionTy getTypeIDFunction = nullptr) - : mlir_type_subclass( - scope, typeClassName, isaFunction, - nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Type"), - getTypeIDFunction) {} - - /// Subclasses with a provided mlir.ir.Type super-class. This must - /// be used if the subclass is being defined in the same extension module - /// as the mlir.ir class (otherwise, it will trigger a recursive - /// initialization). - mlir_type_subclass(nanobind::handle scope, const char *typeClassName, - IsAFunctionTy isaFunction, - const nanobind::object &superCls, - GetTypeIDFunctionTy getTypeIDFunction = nullptr) - : pure_subclass(scope, typeClassName, superCls) { - // Casting constructor. Note that it hard, if not impossible, to properly - // call chain to parent `__init__` in nanobind due to its special handling - // for init functions that don't have a fully constructed self-reference, - // which makes it impossible to forward it to `__init__` of a superclass. - // Instead, provide a custom `__new__` and call that of a superclass, which - // eventually calls `__init__` of the superclass. Since attribute subclasses - // have no additional members, we can just return the instance thus created - // without amending it. - std::string captureTypeName( - typeClassName); // As string in case if typeClassName is not static. 
- nanobind::object newCf = nanobind::cpp_function( - [superCls, isaFunction, captureTypeName](nanobind::object cls, - nanobind::object otherType) { - MlirType rawType = nanobind::cast(otherType); - if (!isaFunction(rawType)) { - auto origRepr = - nanobind::cast(nanobind::repr(otherType)); - throw std::invalid_argument((llvm::Twine("Cannot cast type to ") + - captureTypeName + " (from " + - origRepr + ")") - .str()); - } - nanobind::object self = superCls.attr("__new__")(cls, otherType); - return self; - }, - nanobind::name("__new__"), nanobind::arg("cls"), - nanobind::arg("cast_from_type")); - thisClass.attr("__new__") = newCf; - - // 'isinstance' method. - def_staticmethod( - "isinstance", - [isaFunction](MlirType other) { return isaFunction(other); }, - nanobind::arg("other_type")); - def("__repr__", [superCls, captureTypeName](nanobind::object self) { - return nanobind::repr(superCls(self)) - .attr("replace")(superCls.attr("__name__"), captureTypeName); - }); - if (getTypeIDFunction) { - // 'get_static_typeid' method. - // This is modeled as a static method instead of a static property because - // `def_property_readonly_static` is not available in `pure_subclass` and - // we do not want to introduce the complexity that pybind uses to - // implement it. - def_staticmethod("get_static_typeid", - [getTypeIDFunction]() { return getTypeIDFunction(); }); - nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( - getTypeIDFunction())(nanobind::cpp_function( - [thisClass = thisClass](const nanobind::object &mlirType) { - return thisClass(mlirType); - })); - } - } -}; - -/// Creates a custom subclass of mlir.ir.Value, implementing a casting -/// constructor and type checking methods. -class mlir_value_subclass : public pure_subclass { -public: - using IsAFunctionTy = bool (*)(MlirValue); - - /// Subclasses by looking up the super-class dynamically. 
- mlir_value_subclass(nanobind::handle scope, const char *valueClassName, - IsAFunctionTy isaFunction) - : mlir_value_subclass( - scope, valueClassName, isaFunction, - nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("Value")) {} - - /// Subclasses with a provided mlir.ir.Value super-class. This must - /// be used if the subclass is being defined in the same extension module - /// as the mlir.ir class (otherwise, it will trigger a recursive - /// initialization). - mlir_value_subclass(nanobind::handle scope, const char *valueClassName, - IsAFunctionTy isaFunction, - const nanobind::object &superCls) - : pure_subclass(scope, valueClassName, superCls) { - // Casting constructor. Note that it hard, if not impossible, to properly - // call chain to parent `__init__` in nanobind due to its special handling - // for init functions that don't have a fully constructed self-reference, - // which makes it impossible to forward it to `__init__` of a superclass. - // Instead, provide a custom `__new__` and call that of a superclass, which - // eventually calls `__init__` of the superclass. Since attribute subclasses - // have no additional members, we can just return the instance thus created - // without amending it. - std::string captureValueName( - valueClassName); // As string in case if valueClassName is not static. 
- nanobind::object newCf = nanobind::cpp_function( - [superCls, isaFunction, captureValueName](nanobind::object cls, - nanobind::object otherValue) { - MlirValue rawValue = nanobind::cast(otherValue); - if (!isaFunction(rawValue)) { - auto origRepr = - nanobind::cast(nanobind::repr(otherValue)); - throw std::invalid_argument((llvm::Twine("Cannot cast value to ") + - captureValueName + " (from " + - origRepr + ")") - .str()); - } - nanobind::object self = superCls.attr("__new__")(cls, otherValue); - return self; - }, - nanobind::name("__new__"), nanobind::arg("cls"), - nanobind::arg("cast_from_value")); - thisClass.attr("__new__") = newCf; - - // 'isinstance' method. - def_staticmethod( - "isinstance", - [isaFunction](MlirValue other) { return isaFunction(other); }, - nanobind::arg("other_value")); - } -}; - -} // namespace nanobind_adaptors - -/// RAII scope intercepting all diagnostics into a string. The message must be -/// checked before this goes out of scope. -class CollectDiagnosticsToStringScope { -public: - explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) { - handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage, - /*deleteUserData=*/nullptr); - } - ~CollectDiagnosticsToStringScope() { - assert(errorMessage.empty() && "unchecked error message"); - mlirContextDetachDiagnosticHandler(context, handlerID); - } - - [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); } - -private: - static MlirLogicalResult handler(MlirDiagnostic diag, void *data) { - auto printer = +[](MlirStringRef message, void *data) { - *static_cast(data) += - llvm::StringRef(message.data, message.length); - }; - MlirLocation loc = mlirDiagnosticGetLocation(diag); - *static_cast(data) += "at "; - mlirLocationPrint(loc, printer, data); - *static_cast(data) += ": "; - mlirDiagnosticPrint(diag, printer, data); - return mlirLogicalResultSuccess(); - } - - MlirContext context; - MlirDiagnosticHandlerID handlerID; - std::string 
errorMessage = ""; -}; - -} // namespace python -} // namespace mlir - -#endif // MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index c8233355d..df4b9bf71 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -1,4 +1,4 @@ -//===- PybindAdaptors.h - Interop with MLIR APIs via pybind11 -------------===// +//===- PybindAdaptors.h - Adaptors for interop with MLIR APIs -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,10 +6,9 @@ // //===----------------------------------------------------------------------===// // This file contains adaptors for clients of the core MLIR Python APIs to -// interop via MLIR CAPI types, using pybind11. The facilities here do not -// depend on implementation details of the MLIR Python API and do not introduce -// C++-level dependencies with it (requiring only Python and CAPI-level -// dependencies). +// interop via MLIR CAPI types. The facilities here do not depend on +// implementation details of the MLIR Python API and do not introduce C++-level +// dependencies with it (requiring only Python and CAPI-level dependencies). // // It is encouraged to be used both in-tree and out-of-tree. For in-tree use // cases, it should be used for dialect implementations (versus relying on @@ -612,6 +611,40 @@ class mlir_value_subclass : public pure_subclass { } // namespace adaptors +/// RAII scope intercepting all diagnostics into a string. The message must be +/// checked before this goes out of scope. 
+class CollectDiagnosticsToStringScope { +public: + explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) { + handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage, + /*deleteUserData=*/nullptr); + } + ~CollectDiagnosticsToStringScope() { + assert(errorMessage.empty() && "unchecked error message"); + mlirContextDetachDiagnosticHandler(context, handlerID); + } + + [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); } + +private: + static MlirLogicalResult handler(MlirDiagnostic diag, void *data) { + auto printer = +[](MlirStringRef message, void *data) { + *static_cast(data) += + llvm::StringRef(message.data, message.length); + }; + MlirLocation loc = mlirDiagnosticGetLocation(diag); + *static_cast(data) += "at "; + mlirLocationPrint(loc, printer, data); + *static_cast(data) += ": "; + mlirDiagnosticPrint(diag, printer, data); + return mlirLogicalResultSuccess(); + } + + MlirContext context; + MlirDiagnosticHandlerID handlerID; + std::string errorMessage = ""; +}; + } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp index cccf1370b..42a4c8c07 100644 --- a/mlir/lib/Bindings/Python/DialectLLVM.cpp +++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp @@ -6,13 +6,11 @@ // //===----------------------------------------------------------------------===// -#include - #include "mlir-c/Dialect/LLVM.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" -#include "mlir/Bindings/Python/Diagnostics.h" #include "mlir/Bindings/Python/PybindAdaptors.h" +#include namespace py = pybind11; using namespace llvm; diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp index 0c8c0e0a9..f6b4532b1 100644 --- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp +++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp @@ -10,15 +10,14 @@ // 
//===----------------------------------------------------------------------===// -#include -#include - #include "mlir-c/Dialect/Transform/Interpreter.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" -#include "mlir/Bindings/Python/Diagnostics.h" #include "mlir/Bindings/Python/PybindAdaptors.h" +#include +#include + namespace py = pybind11; namespace { diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index e1b870b53..23187f256 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -683,9 +683,7 @@ if(MLIR_INCLUDE_TESTS) MLIRPythonTestSources.Dialects.PythonTest ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" ADD_TO_PARENT MLIRPythonTestSources.Dialects - SOURCES - dialects/python_test.py - ) + SOURCES dialects/python_test.py) set(LLVM_TARGET_DEFINITIONS "${MLIR_MAIN_SRC_DIR}/test/python/python_test_ops.td") mlir_tablegen( @@ -699,25 +697,12 @@ if(MLIR_INCLUDE_TESTS) ADD_TO_PARENT MLIRPythonTestSources.Dialects.PythonTest SOURCES "dialects/_python_test_ops_gen.py") - declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtensionPybind11 - MODULE_NAME _mlirPythonTestPybind11 - ADD_TO_PARENT MLIRPythonTestSources.Dialects - ROOT_DIR "${MLIR_SOURCE_DIR}/test/python/lib" - PYTHON_BINDINGS_LIBRARY pybind11 - SOURCES - PythonTestModulePybind11.cpp - PRIVATE_LINK_LIBS - LLVMSupport - EMBED_CAPI_LINK_LIBS - MLIRCAPIPythonTestDialect - ) - declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtensionNanobind - MODULE_NAME _mlirPythonTestNanobind + declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtension + MODULE_NAME _mlirPythonTest ADD_TO_PARENT MLIRPythonTestSources.Dialects ROOT_DIR "${MLIR_SOURCE_DIR}/test/python/lib" - PYTHON_BINDINGS_LIBRARY nanobind SOURCES - PythonTestModuleNanobind.cpp + PythonTestModule.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py index 9380896c8..b5baa80bc 100644 --- 
a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -3,14 +3,15 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._python_test_ops_gen import * +from .._mlir_libs._mlirPythonTest import ( + TestAttr, + TestType, + TestTensorValue, + TestIntegerRankedTensorType, +) -def register_python_test_dialect(registry, use_nanobind): - if use_nanobind: - from .._mlir_libs import _mlirPythonTestNanobind +def register_python_test_dialect(registry): + from .._mlir_libs import _mlirPythonTest - _mlirPythonTestNanobind.register_dialect(registry) - else: - from .._mlir_libs import _mlirPythonTestPybind11 - - _mlirPythonTestPybind11.register_dialect(registry) + _mlirPythonTest.register_dialect(registry) diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index ab8a91229..272d06683 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,4 +1,3 @@ -nanobind>=2.0, <3.0 numpy>=1.19.5, <=2.1.2 pybind11>=2.10.0, <=2.13.6 PyYAML>=5.4.0, <=6.0.1 From d557b314abf522172a278e0f2815970befab03c3 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Thu, 5 Dec 2024 17:31:04 +0800 Subject: [PATCH 804/915] [MLIR][Python] enhance python ir printing with pringing flags (#117836) Close https://github.com/llvm/llvm-project/pull/65854 --- mlir/include/mlir-c/Pass.h | 3 ++- mlir/lib/Bindings/Python/Pass.cpp | 17 +++++++++++++++-- mlir/lib/CAPI/IR/Pass.cpp | 6 ++++-- .../mlir/_mlir_libs/_mlir/passmanager.pyi | 4 ++++ 4 files changed, 25 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index 6019071cf..8fd8e9956 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -81,7 +81,8 @@ mlirPassManagerRunOnOp(MlirPassManager passManager, MlirOperation op); MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting( MlirPassManager passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterOnlyOnChange, - bool 
printAfterOnlyOnFailure, MlirStringRef treePrintingPath); + bool printAfterOnlyOnFailure, MlirOpPrintingFlags flags, + MlirStringRef treePrintingPath); /// Enable / disable verify-each. MLIR_CAPI_EXPORTED void diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index e8d28abe6..e991deaae 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -76,20 +76,33 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { "enable_ir_printing", [](PyPassManager &passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterChange, - bool printAfterFailure, + bool printAfterFailure, std::optional largeElementsLimit, + bool enableDebugInfo, bool printGenericOpForm, std::optional optionalTreePrintingPath) { + MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); + if (largeElementsLimit) + mlirOpPrintingFlagsElideLargeElementsAttrs(flags, + *largeElementsLimit); + if (enableDebugInfo) + mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, + /*prettyForm=*/false); + if (printGenericOpForm) + mlirOpPrintingFlagsPrintGenericOpForm(flags); std::string treePrintingPath = ""; if (optionalTreePrintingPath.has_value()) treePrintingPath = optionalTreePrintingPath.value(); mlirPassManagerEnableIRPrinting( passManager.get(), printBeforeAll, printAfterAll, - printModuleScope, printAfterChange, printAfterFailure, + printModuleScope, printAfterChange, printAfterFailure, flags, mlirStringRefCreate(treePrintingPath.data(), treePrintingPath.size())); + mlirOpPrintingFlagsDestroy(flags); }, "print_before_all"_a = false, "print_after_all"_a = true, "print_module_scope"_a = false, "print_after_change"_a = false, "print_after_failure"_a = false, + "large_elements_limit"_a = py::none(), "enable_debug_info"_a = false, + "print_generic_op_form"_a = false, "tree_printing_dir_path"_a = py::none(), "Enable IR printing, default as mlir-print-ir-after-all.") .def( diff --git 
a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index 01151eafe..883b7e8bb 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -49,6 +49,7 @@ void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, bool printModuleScope, bool printAfterOnlyOnChange, bool printAfterOnlyOnFailure, + MlirOpPrintingFlags flags, MlirStringRef treePrintingPath) { auto shouldPrintBeforePass = [printBeforeAll](Pass *, Operation *) { return printBeforeAll; @@ -60,13 +61,14 @@ void mlirPassManagerEnableIRPrinting(MlirPassManager passManager, return unwrap(passManager) ->enableIRPrinting(shouldPrintBeforePass, shouldPrintAfterPass, printModuleScope, printAfterOnlyOnChange, - printAfterOnlyOnFailure); + printAfterOnlyOnFailure, /*out=*/llvm::errs(), + *unwrap(flags)); unwrap(passManager) ->enableIRPrintingToFileTree(shouldPrintBeforePass, shouldPrintAfterPass, printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure, - unwrap(treePrintingPath)); + unwrap(treePrintingPath), *unwrap(flags)); } void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) { diff --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi index 229979ae3..0d2eaffe1 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi @@ -22,6 +22,10 @@ class PassManager: print_module_scope: bool = False, print_after_change: bool = False, print_after_failure: bool = False, + large_elements_limit: int | None = None, + enable_debug_info: bool = False, + print_generic_op_form: bool = False, + tree_printing_dir_path: str | None = None, ) -> None: ... def enable_verifier(self, enable: bool) -> None: ... 
@staticmethod From 9ee3f7d9b7c523b422a0efac5014ca24b591c6a7 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 9 Dec 2024 16:37:43 -0500 Subject: [PATCH 805/915] Revert "Revert "[mlir python] Add nanobind support (#119232) Reverts revert #118517 after (hopefully) fixing builders (https://github.com/llvm/llvm-zorg/pull/328, https://github.com/llvm/llvm-zorg/pull/327) This reverts commit 5f5b11f1076263a20341fac092dcc961ee846ebb. --- .../mlir/Bindings/Python/Diagnostics.h | 59 ++ .../mlir/Bindings/Python/NanobindAdaptors.h | 671 ++++++++++++++++++ .../mlir/Bindings/Python/PybindAdaptors.h | 43 +- mlir/lib/Bindings/Python/DialectLLVM.cpp | 4 +- .../Bindings/Python/TransformInterpreter.cpp | 7 +- mlir/python/CMakeLists.txt | 23 +- mlir/python/mlir/dialects/python_test.py | 17 +- mlir/python/requirements.txt | 1 + 8 files changed, 770 insertions(+), 55 deletions(-) create mode 100644 mlir/include/mlir/Bindings/Python/Diagnostics.h create mode 100644 mlir/include/mlir/Bindings/Python/NanobindAdaptors.h diff --git a/mlir/include/mlir/Bindings/Python/Diagnostics.h b/mlir/include/mlir/Bindings/Python/Diagnostics.h new file mode 100644 index 000000000..ea80e14dd --- /dev/null +++ b/mlir/include/mlir/Bindings/Python/Diagnostics.h @@ -0,0 +1,59 @@ +//===- Diagnostics.h - Helpers for diagnostics in Python bindings ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H +#define MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H + +#include +#include + +#include "mlir-c/Diagnostics.h" +#include "mlir-c/IR.h" +#include "llvm/ADT/StringRef.h" + +namespace mlir { +namespace python { + +/// RAII scope intercepting all diagnostics into a string. 
The message must be +/// checked before this goes out of scope. +class CollectDiagnosticsToStringScope { +public: + explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) { + handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage, + /*deleteUserData=*/nullptr); + } + ~CollectDiagnosticsToStringScope() { + assert(errorMessage.empty() && "unchecked error message"); + mlirContextDetachDiagnosticHandler(context, handlerID); + } + + [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); } + +private: + static MlirLogicalResult handler(MlirDiagnostic diag, void *data) { + auto printer = +[](MlirStringRef message, void *data) { + *static_cast(data) += + llvm::StringRef(message.data, message.length); + }; + MlirLocation loc = mlirDiagnosticGetLocation(diag); + *static_cast(data) += "at "; + mlirLocationPrint(loc, printer, data); + *static_cast(data) += ": "; + mlirDiagnosticPrint(diag, printer, data); + return mlirLogicalResultSuccess(); + } + + MlirContext context; + MlirDiagnosticHandlerID handlerID; + std::string errorMessage = ""; +}; + +} // namespace python +} // namespace mlir + +#endif // MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h new file mode 100644 index 000000000..5e01cebcb --- /dev/null +++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h @@ -0,0 +1,671 @@ +//===- NanobindAdaptors.h - Interop with MLIR APIs via nanobind -----------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This file contains adaptors for clients of the core MLIR Python APIs to +// interop via MLIR CAPI types, using nanobind. 
The facilities here do not +// depend on implementation details of the MLIR Python API and do not introduce +// C++-level dependencies with it (requiring only Python and CAPI-level +// dependencies). +// +// It is encouraged to be used both in-tree and out-of-tree. For in-tree use +// cases, it should be used for dialect implementations (versus relying on +// Pybind-based internals of the core libraries). +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H +#define MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H + +#include +#include + +#include + +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/Diagnostics.h" +#include "mlir-c/IR.h" +#include "llvm/ADT/Twine.h" + +// Raw CAPI type casters need to be declared before use, so always include them +// first. +namespace nanobind { +namespace detail { + +/// Helper to convert a presumed MLIR API object to a capsule, accepting either +/// an explicit Capsule (which can happen when two C APIs are communicating +/// directly via Python) or indirectly by querying the MLIR_PYTHON_CAPI_PTR_ATTR +/// attribute (through which supported MLIR Python API objects export their +/// contained API pointer as a capsule). Throws a type error if the object is +/// neither. This is intended to be used from type casters, which are invoked +/// with a raw handle (unowned). The returned object's lifetime may not extend +/// beyond the apiObject handle without explicitly having its refcount increased +/// (i.e. on return). 
+static nanobind::object mlirApiObjectToCapsule(nanobind::handle apiObject) { + if (PyCapsule_CheckExact(apiObject.ptr())) + return nanobind::borrow(apiObject); + if (!nanobind::hasattr(apiObject, MLIR_PYTHON_CAPI_PTR_ATTR)) { + std::string repr = nanobind::cast(nanobind::repr(apiObject)); + throw nanobind::type_error( + (llvm::Twine("Expected an MLIR object (got ") + repr + ").") + .str() + .c_str()); + } + return apiObject.attr(MLIR_PYTHON_CAPI_PTR_ATTR); +} + +// Note: Currently all of the following support cast from nanobind::object to +// the Mlir* C-API type, but only a few light-weight, context-bound ones +// implicitly cast the other way because the use case has not yet emerged and +// ownership is unclear. + +/// Casts object <-> MlirAffineMap. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToAffineMap(capsule.ptr()); + if (mlirAffineMapIsNull(value)) { + return false; + } + return !mlirAffineMapIsNull(value); + } + static handle from_cpp(MlirAffineMap v, rv_policy, + cleanup_list *cleanup) noexcept { + nanobind::object capsule = + nanobind::steal(mlirPythonAffineMapToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("AffineMap") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + +/// Casts object <-> MlirAttribute. 
+template <> +struct type_caster { + NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToAttribute(capsule.ptr()); + return !mlirAttributeIsNull(value); + } + static handle from_cpp(MlirAttribute v, rv_policy, + cleanup_list *cleanup) noexcept { + nanobind::object capsule = + nanobind::steal(mlirPythonAttributeToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Attribute") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() + .release(); + } +}; + +/// Casts object -> MlirBlock. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToBlock(capsule.ptr()); + return !mlirBlockIsNull(value); + } +}; + +/// Casts object -> MlirContext. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirContext, const_name("MlirContext")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + if (src.is_none()) { + // Gets the current thread-bound context. + // TODO: This raises an error of "No current context" currently. + // Update the implementation to pretty-print the helpful error that the + // core implementations print in this case. + src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Context") + .attr("current"); + } + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToContext(capsule.ptr()); + return !mlirContextIsNull(value); + } +}; + +/// Casts object <-> MlirDialectRegistry. 
+template <> +struct type_caster { + NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToDialectRegistry(capsule.ptr()); + return !mlirDialectRegistryIsNull(value); + } + static handle from_cpp(MlirDialectRegistry v, rv_policy, + cleanup_list *cleanup) noexcept { + nanobind::object capsule = nanobind::steal( + mlirPythonDialectRegistryToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("DialectRegistry") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + +/// Casts object <-> MlirLocation. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + if (src.is_none()) { + // Gets the current thread-bound context. + src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Location") + .attr("current"); + } + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToLocation(capsule.ptr()); + return !mlirLocationIsNull(value); + } + static handle from_cpp(MlirLocation v, rv_policy, + cleanup_list *cleanup) noexcept { + nanobind::object capsule = + nanobind::steal(mlirPythonLocationToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Location") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + } +}; + +/// Casts object <-> MlirModule. 
+template <> +struct type_caster { + NB_TYPE_CASTER(MlirModule, const_name("MlirModule")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToModule(capsule.ptr()); + return !mlirModuleIsNull(value); + } + static handle from_cpp(MlirModule v, rv_policy, + cleanup_list *cleanup) noexcept { + nanobind::object capsule = + nanobind::steal(mlirPythonModuleToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Module") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object <-> MlirFrozenRewritePatternSet. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirFrozenRewritePatternSet, + const_name("MlirFrozenRewritePatternSet")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); + return value.ptr != nullptr; + } + static handle from_cpp(MlirFrozenRewritePatternSet v, rv_policy, handle) { + nanobind::object capsule = nanobind::steal( + mlirPythonFrozenRewritePatternSetToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("rewrite")) + .attr("FrozenRewritePatternSet") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object <-> MlirOperation. 
+template <> +struct type_caster { + NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToOperation(capsule.ptr()); + return !mlirOperationIsNull(value); + } + static handle from_cpp(MlirOperation v, rv_policy, + cleanup_list *cleanup) noexcept { + if (v.ptr == nullptr) + return nanobind::none(); + nanobind::object capsule = + nanobind::steal(mlirPythonOperationToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Operation") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object <-> MlirValue. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirValue, const_name("MlirValue")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToValue(capsule.ptr()); + return !mlirValueIsNull(value); + } + static handle from_cpp(MlirValue v, rv_policy, + cleanup_list *cleanup) noexcept { + if (v.ptr == nullptr) + return nanobind::none(); + nanobind::object capsule = + nanobind::steal(mlirPythonValueToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Value") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() + .release(); + }; +}; + +/// Casts object -> MlirPassManager. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToPassManager(capsule.ptr()); + return !mlirPassManagerIsNull(value); + } +}; + +/// Casts object <-> MlirTypeID. 
+template <> +struct type_caster { + NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToTypeID(capsule.ptr()); + return !mlirTypeIDIsNull(value); + } + static handle from_cpp(MlirTypeID v, rv_policy, + cleanup_list *cleanup) noexcept { + if (v.ptr == nullptr) + return nanobind::none(); + nanobind::object capsule = + nanobind::steal(mlirPythonTypeIDToCapsule(v)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("TypeID") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .release(); + }; +}; + +/// Casts object <-> MlirType. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirType, const_name("MlirType")); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + nanobind::object capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToType(capsule.ptr()); + return !mlirTypeIsNull(value); + } + static handle from_cpp(MlirType t, rv_policy, + cleanup_list *cleanup) noexcept { + nanobind::object capsule = + nanobind::steal(mlirPythonTypeToCapsule(t)); + return nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Type") + .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() + .release(); + } +}; + +} // namespace detail +} // namespace nanobind + +namespace mlir { +namespace python { +namespace nanobind_adaptors { + +/// Provides a facility like nanobind::class_ for defining a new class in a +/// scope, but this allows extension of an arbitrary Python class, defining +/// methods on it is a similar way. Classes defined in this way are very similar +/// to if defined in Python in the usual way but use nanobind machinery to +/// do it. These are not "real" nanobind classes but pure Python classes +/// with no relation to a concrete C++ class. 
+/// +/// Derived from a discussion upstream: +/// https://github.com/pybind/pybind11/issues/1193 +/// (plus a fair amount of extra curricular poking) +/// TODO: If this proves useful, see about including it in nanobind. +class pure_subclass { +public: + pure_subclass(nanobind::handle scope, const char *derivedClassName, + const nanobind::object &superClass) { + nanobind::object pyType = + nanobind::borrow((PyObject *)&PyType_Type); + nanobind::object metaclass = pyType(superClass); + nanobind::dict attributes; + + thisClass = metaclass(derivedClassName, nanobind::make_tuple(superClass), + attributes); + scope.attr(derivedClassName) = thisClass; + } + + template + pure_subclass &def(const char *name, Func &&f, const Extra &...extra) { + nanobind::object cf = nanobind::cpp_function( + std::forward(f), nanobind::name(name), nanobind::is_method(), + nanobind::scope(thisClass), extra...); + thisClass.attr(name) = cf; + return *this; + } + + template + pure_subclass &def_property_readonly(const char *name, Func &&f, + const Extra &...extra) { + nanobind::object cf = nanobind::cpp_function( + std::forward(f), nanobind::name(name), nanobind::is_method(), + nanobind::scope(thisClass), extra...); + auto builtinProperty = + nanobind::borrow((PyObject *)&PyProperty_Type); + thisClass.attr(name) = builtinProperty(cf); + return *this; + } + + template + pure_subclass &def_staticmethod(const char *name, Func &&f, + const Extra &...extra) { + static_assert(!std::is_member_function_pointer::value, + "def_staticmethod(...) called with a non-static member " + "function pointer"); + nanobind::object cf = nanobind::cpp_function( + std::forward(f), + nanobind::name(name), // nanobind::scope(thisClass), + extra...); + thisClass.attr(name) = cf; + return *this; + } + + template + pure_subclass &def_classmethod(const char *name, Func &&f, + const Extra &...extra) { + static_assert(!std::is_member_function_pointer::value, + "def_classmethod(...) 
called with a non-static member " + "function pointer"); + nanobind::object cf = nanobind::cpp_function( + std::forward(f), + nanobind::name(name), // nanobind::scope(thisClass), + extra...); + thisClass.attr(name) = + nanobind::borrow(PyClassMethod_New(cf.ptr())); + return *this; + } + + nanobind::object get_class() const { return thisClass; } + +protected: + nanobind::object superClass; + nanobind::object thisClass; +}; + +/// Creates a custom subclass of mlir.ir.Attribute, implementing a casting +/// constructor and type checking methods. +class mlir_attribute_subclass : public pure_subclass { +public: + using IsAFunctionTy = bool (*)(MlirAttribute); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + + /// Subclasses by looking up the super-class dynamically. + mlir_attribute_subclass(nanobind::handle scope, const char *attrClassName, + IsAFunctionTy isaFunction, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) + : mlir_attribute_subclass( + scope, attrClassName, isaFunction, + nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Attribute"), + getTypeIDFunction) {} + + /// Subclasses with a provided mlir.ir.Attribute super-class. This must + /// be used if the subclass is being defined in the same extension module + /// as the mlir.ir class (otherwise, it will trigger a recursive + /// initialization). + mlir_attribute_subclass(nanobind::handle scope, const char *typeClassName, + IsAFunctionTy isaFunction, + const nanobind::object &superCls, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) + : pure_subclass(scope, typeClassName, superCls) { + // Casting constructor. Note that it hard, if not impossible, to properly + // call chain to parent `__init__` in nanobind due to its special handling + // for init functions that don't have a fully constructed self-reference, + // which makes it impossible to forward it to `__init__` of a superclass. 
+ // Instead, provide a custom `__new__` and call that of a superclass, which + // eventually calls `__init__` of the superclass. Since attribute subclasses + // have no additional members, we can just return the instance thus created + // without amending it. + std::string captureTypeName( + typeClassName); // As string in case if typeClassName is not static. + nanobind::object newCf = nanobind::cpp_function( + [superCls, isaFunction, captureTypeName]( + nanobind::object cls, nanobind::object otherAttribute) { + MlirAttribute rawAttribute = + nanobind::cast(otherAttribute); + if (!isaFunction(rawAttribute)) { + auto origRepr = + nanobind::cast(nanobind::repr(otherAttribute)); + throw std::invalid_argument( + (llvm::Twine("Cannot cast attribute to ") + captureTypeName + + " (from " + origRepr + ")") + .str()); + } + nanobind::object self = superCls.attr("__new__")(cls, otherAttribute); + return self; + }, + nanobind::name("__new__"), nanobind::arg("cls"), + nanobind::arg("cast_from_attr")); + thisClass.attr("__new__") = newCf; + + // 'isinstance' method. + def_staticmethod( + "isinstance", + [isaFunction](MlirAttribute other) { return isaFunction(other); }, + nanobind::arg("other_attribute")); + def("__repr__", [superCls, captureTypeName](nanobind::object self) { + return nanobind::repr(superCls(self)) + .attr("replace")(superCls.attr("__name__"), captureTypeName); + }); + if (getTypeIDFunction) { + def_staticmethod("get_static_typeid", + [getTypeIDFunction]() { return getTypeIDFunction(); }); + nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( + getTypeIDFunction())(nanobind::cpp_function( + [thisClass = thisClass](const nanobind::object &mlirAttribute) { + return thisClass(mlirAttribute); + })); + } + } +}; + +/// Creates a custom subclass of mlir.ir.Type, implementing a casting +/// constructor and type checking methods. 
+class mlir_type_subclass : public pure_subclass { +public: + using IsAFunctionTy = bool (*)(MlirType); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + + /// Subclasses by looking up the super-class dynamically. + mlir_type_subclass(nanobind::handle scope, const char *typeClassName, + IsAFunctionTy isaFunction, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) + : mlir_type_subclass( + scope, typeClassName, isaFunction, + nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Type"), + getTypeIDFunction) {} + + /// Subclasses with a provided mlir.ir.Type super-class. This must + /// be used if the subclass is being defined in the same extension module + /// as the mlir.ir class (otherwise, it will trigger a recursive + /// initialization). + mlir_type_subclass(nanobind::handle scope, const char *typeClassName, + IsAFunctionTy isaFunction, + const nanobind::object &superCls, + GetTypeIDFunctionTy getTypeIDFunction = nullptr) + : pure_subclass(scope, typeClassName, superCls) { + // Casting constructor. Note that it hard, if not impossible, to properly + // call chain to parent `__init__` in nanobind due to its special handling + // for init functions that don't have a fully constructed self-reference, + // which makes it impossible to forward it to `__init__` of a superclass. + // Instead, provide a custom `__new__` and call that of a superclass, which + // eventually calls `__init__` of the superclass. Since attribute subclasses + // have no additional members, we can just return the instance thus created + // without amending it. + std::string captureTypeName( + typeClassName); // As string in case if typeClassName is not static. 
+ nanobind::object newCf = nanobind::cpp_function( + [superCls, isaFunction, captureTypeName](nanobind::object cls, + nanobind::object otherType) { + MlirType rawType = nanobind::cast(otherType); + if (!isaFunction(rawType)) { + auto origRepr = + nanobind::cast(nanobind::repr(otherType)); + throw std::invalid_argument((llvm::Twine("Cannot cast type to ") + + captureTypeName + " (from " + + origRepr + ")") + .str()); + } + nanobind::object self = superCls.attr("__new__")(cls, otherType); + return self; + }, + nanobind::name("__new__"), nanobind::arg("cls"), + nanobind::arg("cast_from_type")); + thisClass.attr("__new__") = newCf; + + // 'isinstance' method. + def_staticmethod( + "isinstance", + [isaFunction](MlirType other) { return isaFunction(other); }, + nanobind::arg("other_type")); + def("__repr__", [superCls, captureTypeName](nanobind::object self) { + return nanobind::repr(superCls(self)) + .attr("replace")(superCls.attr("__name__"), captureTypeName); + }); + if (getTypeIDFunction) { + // 'get_static_typeid' method. + // This is modeled as a static method instead of a static property because + // `def_property_readonly_static` is not available in `pure_subclass` and + // we do not want to introduce the complexity that pybind uses to + // implement it. + def_staticmethod("get_static_typeid", + [getTypeIDFunction]() { return getTypeIDFunction(); }); + nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( + getTypeIDFunction())(nanobind::cpp_function( + [thisClass = thisClass](const nanobind::object &mlirType) { + return thisClass(mlirType); + })); + } + } +}; + +/// Creates a custom subclass of mlir.ir.Value, implementing a casting +/// constructor and type checking methods. +class mlir_value_subclass : public pure_subclass { +public: + using IsAFunctionTy = bool (*)(MlirValue); + + /// Subclasses by looking up the super-class dynamically. 
+ mlir_value_subclass(nanobind::handle scope, const char *valueClassName, + IsAFunctionTy isaFunction) + : mlir_value_subclass( + scope, valueClassName, isaFunction, + nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr("Value")) {} + + /// Subclasses with a provided mlir.ir.Value super-class. This must + /// be used if the subclass is being defined in the same extension module + /// as the mlir.ir class (otherwise, it will trigger a recursive + /// initialization). + mlir_value_subclass(nanobind::handle scope, const char *valueClassName, + IsAFunctionTy isaFunction, + const nanobind::object &superCls) + : pure_subclass(scope, valueClassName, superCls) { + // Casting constructor. Note that it hard, if not impossible, to properly + // call chain to parent `__init__` in nanobind due to its special handling + // for init functions that don't have a fully constructed self-reference, + // which makes it impossible to forward it to `__init__` of a superclass. + // Instead, provide a custom `__new__` and call that of a superclass, which + // eventually calls `__init__` of the superclass. Since attribute subclasses + // have no additional members, we can just return the instance thus created + // without amending it. + std::string captureValueName( + valueClassName); // As string in case if valueClassName is not static. 
+ nanobind::object newCf = nanobind::cpp_function( + [superCls, isaFunction, captureValueName](nanobind::object cls, + nanobind::object otherValue) { + MlirValue rawValue = nanobind::cast(otherValue); + if (!isaFunction(rawValue)) { + auto origRepr = + nanobind::cast(nanobind::repr(otherValue)); + throw std::invalid_argument((llvm::Twine("Cannot cast value to ") + + captureValueName + " (from " + + origRepr + ")") + .str()); + } + nanobind::object self = superCls.attr("__new__")(cls, otherValue); + return self; + }, + nanobind::name("__new__"), nanobind::arg("cls"), + nanobind::arg("cast_from_value")); + thisClass.attr("__new__") = newCf; + + // 'isinstance' method. + def_staticmethod( + "isinstance", + [isaFunction](MlirValue other) { return isaFunction(other); }, + nanobind::arg("other_value")); + } +}; + +} // namespace nanobind_adaptors + +/// RAII scope intercepting all diagnostics into a string. The message must be +/// checked before this goes out of scope. +class CollectDiagnosticsToStringScope { +public: + explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) { + handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage, + /*deleteUserData=*/nullptr); + } + ~CollectDiagnosticsToStringScope() { + assert(errorMessage.empty() && "unchecked error message"); + mlirContextDetachDiagnosticHandler(context, handlerID); + } + + [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); } + +private: + static MlirLogicalResult handler(MlirDiagnostic diag, void *data) { + auto printer = +[](MlirStringRef message, void *data) { + *static_cast(data) += + llvm::StringRef(message.data, message.length); + }; + MlirLocation loc = mlirDiagnosticGetLocation(diag); + *static_cast(data) += "at "; + mlirLocationPrint(loc, printer, data); + *static_cast(data) += ": "; + mlirDiagnosticPrint(diag, printer, data); + return mlirLogicalResultSuccess(); + } + + MlirContext context; + MlirDiagnosticHandlerID handlerID; + std::string 
errorMessage = ""; +}; + +} // namespace python +} // namespace mlir + +#endif // MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index df4b9bf71..c8233355d 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -1,4 +1,4 @@ -//===- PybindAdaptors.h - Adaptors for interop with MLIR APIs -------------===// +//===- PybindAdaptors.h - Interop with MLIR APIs via pybind11 -------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,9 +6,10 @@ // //===----------------------------------------------------------------------===// // This file contains adaptors for clients of the core MLIR Python APIs to -// interop via MLIR CAPI types. The facilities here do not depend on -// implementation details of the MLIR Python API and do not introduce C++-level -// dependencies with it (requiring only Python and CAPI-level dependencies). +// interop via MLIR CAPI types, using pybind11. The facilities here do not +// depend on implementation details of the MLIR Python API and do not introduce +// C++-level dependencies with it (requiring only Python and CAPI-level +// dependencies). // // It is encouraged to be used both in-tree and out-of-tree. For in-tree use // cases, it should be used for dialect implementations (versus relying on @@ -611,40 +612,6 @@ class mlir_value_subclass : public pure_subclass { } // namespace adaptors -/// RAII scope intercepting all diagnostics into a string. The message must be -/// checked before this goes out of scope. 
-class CollectDiagnosticsToStringScope { -public: - explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) { - handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage, - /*deleteUserData=*/nullptr); - } - ~CollectDiagnosticsToStringScope() { - assert(errorMessage.empty() && "unchecked error message"); - mlirContextDetachDiagnosticHandler(context, handlerID); - } - - [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); } - -private: - static MlirLogicalResult handler(MlirDiagnostic diag, void *data) { - auto printer = +[](MlirStringRef message, void *data) { - *static_cast(data) += - llvm::StringRef(message.data, message.length); - }; - MlirLocation loc = mlirDiagnosticGetLocation(diag); - *static_cast(data) += "at "; - mlirLocationPrint(loc, printer, data); - *static_cast(data) += ": "; - mlirDiagnosticPrint(diag, printer, data); - return mlirLogicalResultSuccess(); - } - - MlirContext context; - MlirDiagnosticHandlerID handlerID; - std::string errorMessage = ""; -}; - } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp index 42a4c8c07..cccf1370b 100644 --- a/mlir/lib/Bindings/Python/DialectLLVM.cpp +++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp @@ -6,11 +6,13 @@ // //===----------------------------------------------------------------------===// +#include + #include "mlir-c/Dialect/LLVM.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Diagnostics.h" #include "mlir/Bindings/Python/PybindAdaptors.h" -#include namespace py = pybind11; using namespace llvm; diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp index f6b4532b1..0c8c0e0a9 100644 --- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp +++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp @@ -10,14 +10,15 @@ // 
//===----------------------------------------------------------------------===// +#include +#include + #include "mlir-c/Dialect/Transform/Interpreter.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Diagnostics.h" #include "mlir/Bindings/Python/PybindAdaptors.h" -#include -#include - namespace py = pybind11; namespace { diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 23187f256..e1b870b53 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -683,7 +683,9 @@ if(MLIR_INCLUDE_TESTS) MLIRPythonTestSources.Dialects.PythonTest ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" ADD_TO_PARENT MLIRPythonTestSources.Dialects - SOURCES dialects/python_test.py) + SOURCES + dialects/python_test.py + ) set(LLVM_TARGET_DEFINITIONS "${MLIR_MAIN_SRC_DIR}/test/python/python_test_ops.td") mlir_tablegen( @@ -697,12 +699,25 @@ if(MLIR_INCLUDE_TESTS) ADD_TO_PARENT MLIRPythonTestSources.Dialects.PythonTest SOURCES "dialects/_python_test_ops_gen.py") - declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtension - MODULE_NAME _mlirPythonTest + declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtensionPybind11 + MODULE_NAME _mlirPythonTestPybind11 + ADD_TO_PARENT MLIRPythonTestSources.Dialects + ROOT_DIR "${MLIR_SOURCE_DIR}/test/python/lib" + PYTHON_BINDINGS_LIBRARY pybind11 + SOURCES + PythonTestModulePybind11.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIPythonTestDialect + ) + declare_mlir_python_extension(MLIRPythonTestSources.PythonTestExtensionNanobind + MODULE_NAME _mlirPythonTestNanobind ADD_TO_PARENT MLIRPythonTestSources.Dialects ROOT_DIR "${MLIR_SOURCE_DIR}/test/python/lib" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES - PythonTestModule.cpp + PythonTestModuleNanobind.cpp PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS diff --git a/mlir/python/mlir/dialects/python_test.py b/mlir/python/mlir/dialects/python_test.py index b5baa80bc..9380896c8 100644 --- 
a/mlir/python/mlir/dialects/python_test.py +++ b/mlir/python/mlir/dialects/python_test.py @@ -3,15 +3,14 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from ._python_test_ops_gen import * -from .._mlir_libs._mlirPythonTest import ( - TestAttr, - TestType, - TestTensorValue, - TestIntegerRankedTensorType, -) -def register_python_test_dialect(registry): - from .._mlir_libs import _mlirPythonTest +def register_python_test_dialect(registry, use_nanobind): + if use_nanobind: + from .._mlir_libs import _mlirPythonTestNanobind - _mlirPythonTest.register_dialect(registry) + _mlirPythonTestNanobind.register_dialect(registry) + else: + from .._mlir_libs import _mlirPythonTestPybind11 + + _mlirPythonTestPybind11.register_dialect(registry) diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index 272d06683..ab8a91229 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,3 +1,4 @@ +nanobind>=2.0, <3.0 numpy>=1.19.5, <=2.1.2 pybind11>=2.10.0, <=2.13.6 PyYAML>=5.4.0, <=6.0.1 From 7e217ce5458e46e7f5c1989a565641febb14b323 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eliud=20de=20Le=C3=B3n?= Date: Wed, 11 Dec 2024 10:07:21 -0800 Subject: [PATCH 806/915] [mlir][emitc] Add support for C-API/python binding to EmitC dialect (#119476) Added EmitC dialect bindings. 
--- mlir/include/mlir-c/Dialect/EmitC.h | 26 ++++++++++++++++++++++++++ mlir/lib/CAPI/Dialect/CMakeLists.txt | 9 +++++++++ mlir/lib/CAPI/Dialect/EmitC.cpp | 13 +++++++++++++ mlir/python/CMakeLists.txt | 8 ++++++++ mlir/python/mlir/dialects/EmitC.td | 14 ++++++++++++++ mlir/python/mlir/dialects/emitc.py | 5 +++++ 6 files changed, 75 insertions(+) create mode 100644 mlir/include/mlir-c/Dialect/EmitC.h create mode 100644 mlir/lib/CAPI/Dialect/EmitC.cpp create mode 100644 mlir/python/mlir/dialects/EmitC.td create mode 100644 mlir/python/mlir/dialects/emitc.py diff --git a/mlir/include/mlir-c/Dialect/EmitC.h b/mlir/include/mlir-c/Dialect/EmitC.h new file mode 100644 index 000000000..82e698344 --- /dev/null +++ b/mlir/include/mlir-c/Dialect/EmitC.h @@ -0,0 +1,26 @@ +//===-- mlir-c/Dialect/EmitC.h - C API for EmitC dialect ----------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_EmitC_H +#define MLIR_C_DIALECT_EmitC_H + +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(EmitC, emitc); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_EmitC_H diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index 4e141b60f..5ad4bafed 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -40,6 +40,15 @@ add_mlir_upstream_c_api_library(MLIRCAPIControlFlow MLIRControlFlowDialect ) +add_mlir_upstream_c_api_library(MLIRCAPIEmitC + EmitC.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIREmitCDialect +) + add_mlir_upstream_c_api_library(MLIRCAPIMath Math.cpp diff --git a/mlir/lib/CAPI/Dialect/EmitC.cpp b/mlir/lib/CAPI/Dialect/EmitC.cpp new file mode 100644 index 000000000..3dcb7038a --- /dev/null +++ b/mlir/lib/CAPI/Dialect/EmitC.cpp @@ -0,0 +1,13 @@ +//===- EmitC.cpp - C Interface for EmitC dialect --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/EmitC.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/EmitC/IR/EmitC.h" + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(EmitC, emitc, mlir::emitc::EmitCDialect) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index e1b870b53..10866c11b 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -352,6 +352,14 @@ declare_mlir_python_sources( dialects/quant.py _mlir_libs/_mlir/dialects/quant.pyi) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/EmitC.td + SOURCES + dialects/emitc.py + DIALECT_NAME emitc) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/EmitC.td b/mlir/python/mlir/dialects/EmitC.td new file mode 100644 index 000000000..ff0a56d15 --- /dev/null +++ b/mlir/python/mlir/dialects/EmitC.td @@ -0,0 +1,14 @@ +//===-- EmitC.td - Entry point for EmitC bind --------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_EMITC +#define PYTHON_BINDINGS_EMITC + +include "mlir/Dialect/EmitC/IR/EmitC.td" + +#endif diff --git a/mlir/python/mlir/dialects/emitc.py b/mlir/python/mlir/dialects/emitc.py new file mode 100644 index 000000000..99c3286e5 --- /dev/null +++ b/mlir/python/mlir/dialects/emitc.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._emitc_ops_gen import * From b59ef1ba62ff7e658e0f2b4fc1a819f4fe6abd87 Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 15 Dec 2024 16:02:55 -0600 Subject: [PATCH 807/915] MLIR-C: Add accessor for LLVM array type (#119998) --- mlir/include/mlir-c/Dialect/LLVM.h | 3 +++ mlir/lib/CAPI/Dialect/LLVM.cpp | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index ed9b23c34..0992285f9 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -37,6 +37,9 @@ MLIR_CAPI_EXPORTED MlirType mlirLLVMVoidTypeGet(MlirContext ctx); MLIR_CAPI_EXPORTED MlirType mlirLLVMArrayTypeGet(MlirType elementType, unsigned numElements); +/// Returns the element type of the llvm.array type. +MLIR_CAPI_EXPORTED MlirType mlirLLVMArrayTypeGetElementType(MlirType type); + /// Creates an llvm.func type. MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes, diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index c7082445d..6ed82ba1a 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -43,6 +43,10 @@ MlirType mlirLLVMArrayTypeGet(MlirType elementType, unsigned numElements) { return wrap(LLVMArrayType::get(unwrap(elementType), numElements)); } +MlirType mlirLLVMArrayTypeGetElementType(MlirType type) { + return wrap(cast(unwrap(type)).getElementType()); +} + MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes, MlirType const *argumentTypes, bool isVarArg) { SmallVector argumentStorage; From 4c9b06adfac283c4027612ac5e010ef5d09a158b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 18 Dec 2024 14:16:11 -0500 Subject: [PATCH 808/915] [mlir python] Port Python core code to nanobind. (#118583) Why? 
https://nanobind.readthedocs.io/en/latest/why.html says it better than I can, but my primary motivation for this change is to improve MLIR IR construction time from JAX. For a complicated Google-internal LLM model in JAX, this change improves the MLIR lowering time by around 5s (out of around 30s), which is a significant speedup for simply switching binding frameworks. To a large extent, this is a mechanical change, for instance changing `pybind11::` to `nanobind::`. Notes: * this PR needs Nanobind 2.4.0, because it needs a bug fix (https://github.com/wjakob/nanobind/pull/806) that landed in that release. * this PR does not port the in-tree dialect extension modules. They can be ported in a future PR. * I removed the py::sibling() annotations from def_static and def_class in `PybindAdapters.h`. These ask pybind11 to try to form an overload with an existing method, but it's not possible to form mixed pybind11/nanobind overloads this ways and the parent class is now defined in nanobind. Better solutions may be possible here. * nanobind does not contain an exact equivalent of pybind11's buffer protocol support. It was not hard to add a nanobind implementation of a similar API. * nanobind is pickier about casting to std::vector, expecting that the input is a sequence of bool types, not truthy values. In a couple of places I added code to support truthy values during casting. * nanobind distinguishes bytes (`nb::bytes`) from strings (e.g., `std::string`). This required nb::bytes overloads in a few places. 
--- mlir/include/mlir/Bindings/Python/IRTypes.h | 2 +- .../mlir/Bindings/Python/PybindAdaptors.h | 10 +- mlir/lib/Bindings/Python/Globals.h | 39 +- mlir/lib/Bindings/Python/IRAffine.cpp | 265 ++-- mlir/lib/Bindings/Python/IRAttributes.cpp | 672 +++++--- mlir/lib/Bindings/Python/IRCore.cpp | 1412 +++++++++-------- mlir/lib/Bindings/Python/IRInterfaces.cpp | 171 +- mlir/lib/Bindings/Python/IRModule.cpp | 57 +- mlir/lib/Bindings/Python/IRModule.h | 332 ++-- mlir/lib/Bindings/Python/IRTypes.cpp | 200 +-- mlir/lib/Bindings/Python/MainModule.cpp | 56 +- .../Python/{PybindUtils.h => NanobindUtils.h} | 84 +- mlir/lib/Bindings/Python/Pass.cpp | 58 +- mlir/lib/Bindings/Python/Pass.h | 4 +- mlir/lib/Bindings/Python/Rewrite.cpp | 43 +- mlir/lib/Bindings/Python/Rewrite.h | 4 +- mlir/python/CMakeLists.txt | 3 +- mlir/python/requirements.txt | 2 +- 18 files changed, 1853 insertions(+), 1561 deletions(-) rename mlir/lib/Bindings/Python/{PybindUtils.h => NanobindUtils.h} (85%) diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h index 9afad4c23..ba9642cf2 100644 --- a/mlir/include/mlir/Bindings/Python/IRTypes.h +++ b/mlir/include/mlir/Bindings/Python/IRTypes.h @@ -9,7 +9,7 @@ #ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H #define MLIR_BINDINGS_PYTHON_IRTYPES_H -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" namespace mlir { diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index c8233355d..edc69774b 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -374,9 +374,8 @@ class pure_subclass { static_assert(!std::is_member_function_pointer::value, "def_staticmethod(...) 
called with a non-static member " "function pointer"); - py::cpp_function cf( - std::forward(f), py::name(name), py::scope(thisClass), - py::sibling(py::getattr(thisClass, name, py::none())), extra...); + py::cpp_function cf(std::forward(f), py::name(name), + py::scope(thisClass), extra...); thisClass.attr(cf.name()) = py::staticmethod(cf); return *this; } @@ -387,9 +386,8 @@ class pure_subclass { static_assert(!std::is_member_function_pointer::value, "def_classmethod(...) called with a non-static member " "function pointer"); - py::cpp_function cf( - std::forward(f), py::name(name), py::scope(thisClass), - py::sibling(py::getattr(thisClass, name, py::none())), extra...); + py::cpp_function cf(std::forward(f), py::name(name), + py::scope(thisClass), extra...); thisClass.attr(cf.name()) = py::reinterpret_borrow(PyClassMethod_New(cf.ptr())); return *this; diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index a022067f5..0ec522d14 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -9,18 +9,17 @@ #ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H #define MLIR_BINDINGS_PYTHON_GLOBALS_H -#include "PybindUtils.h" +#include +#include +#include +#include "NanobindUtils.h" #include "mlir-c/IR.h" #include "mlir/CAPI/Support.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" -#include -#include -#include - namespace mlir { namespace python { @@ -57,55 +56,55 @@ class PyGlobals { /// Raises an exception if the mapping already exists and replace == false. /// This is intended to be called by implementation code. void registerAttributeBuilder(const std::string &attributeKind, - pybind11::function pyFunc, + nanobind::callable pyFunc, bool replace = false); /// Adds a user-friendly type caster. Raises an exception if the mapping /// already exists and replace == false. This is intended to be called by /// implementation code. 
- void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster, + void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster, bool replace = false); /// Adds a user-friendly value caster. Raises an exception if the mapping /// already exists and replace == false. This is intended to be called by /// implementation code. void registerValueCaster(MlirTypeID mlirTypeID, - pybind11::function valueCaster, + nanobind::callable valueCaster, bool replace = false); /// Adds a concrete implementation dialect class. /// Raises an exception if the mapping already exists. /// This is intended to be called by implementation code. void registerDialectImpl(const std::string &dialectNamespace, - pybind11::object pyClass); + nanobind::object pyClass); /// Adds a concrete implementation operation class. /// Raises an exception if the mapping already exists and replace == false. /// This is intended to be called by implementation code. void registerOperationImpl(const std::string &operationName, - pybind11::object pyClass, bool replace = false); + nanobind::object pyClass, bool replace = false); /// Returns the custom Attribute builder for Attribute kind. - std::optional + std::optional lookupAttributeBuilder(const std::string &attributeKind); /// Returns the custom type caster for MlirTypeID mlirTypeID. - std::optional lookupTypeCaster(MlirTypeID mlirTypeID, + std::optional lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect); /// Returns the custom value caster for MlirTypeID mlirTypeID. - std::optional lookupValueCaster(MlirTypeID mlirTypeID, + std::optional lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect); /// Looks up a registered dialect class by namespace. Note that this may /// trigger loading of the defining module and can arbitrarily re-enter. - std::optional + std::optional lookupDialectClass(const std::string &dialectNamespace); /// Looks up a registered operation class (deriving from OpView) by operation /// name. 
Note that this may trigger a load of the dialect, which can /// arbitrarily re-enter. - std::optional + std::optional lookupOperationClass(llvm::StringRef operationName); private: @@ -113,15 +112,15 @@ class PyGlobals { /// Module name prefixes to search under for dialect implementation modules. std::vector dialectSearchPrefixes; /// Map of dialect namespace to external dialect class object. - llvm::StringMap dialectClassMap; + llvm::StringMap dialectClassMap; /// Map of full operation name to external operation class object. - llvm::StringMap operationClassMap; + llvm::StringMap operationClassMap; /// Map of attribute ODS name to custom builder. - llvm::StringMap attributeBuilderMap; + llvm::StringMap attributeBuilderMap; /// Map of MlirTypeID to custom type caster. - llvm::DenseMap typeCasterMap; + llvm::DenseMap typeCasterMap; /// Map of MlirTypeID to custom value caster. - llvm::DenseMap valueCasterMap; + llvm::DenseMap valueCasterMap; /// Set of dialect namespaces that we have attempted to import implementation /// modules for. 
llvm::StringSet<> loadedDialectModules; diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index b138e131e..2db690309 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -6,20 +6,19 @@ // //===----------------------------------------------------------------------===// +#include +#include +#include + #include #include -#include -#include -#include -#include +#include #include #include #include #include "IRModule.h" - -#include "PybindUtils.h" - +#include "NanobindUtils.h" #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" #include "mlir-c/Bindings/Python/Interop.h" @@ -30,7 +29,7 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" -namespace py = pybind11; +namespace nb = nanobind; using namespace mlir; using namespace mlir::python; @@ -46,23 +45,23 @@ static const char kDumpDocstring[] = /// Throws errors in case of failure, using "action" to describe what the caller /// was attempting to do. template -static void pyListToVector(const py::list &list, +static void pyListToVector(const nb::list &list, llvm::SmallVectorImpl &result, StringRef action) { - result.reserve(py::len(list)); - for (py::handle item : list) { + result.reserve(nb::len(list)); + for (nb::handle item : list) { try { - result.push_back(item.cast()); - } catch (py::cast_error &err) { + result.push_back(nb::cast(item)); + } catch (nb::cast_error &err) { std::string msg = (llvm::Twine("Invalid expression when ") + action + " (" + err.what() + ")") .str(); - throw py::cast_error(msg); - } catch (py::reference_cast_error &err) { + throw std::runtime_error(msg.c_str()); + } catch (std::runtime_error &err) { std::string msg = (llvm::Twine("Invalid expression (None?) 
when ") + action + " (" + err.what() + ")") .str(); - throw py::cast_error(msg); + throw std::runtime_error(msg.c_str()); } } } @@ -94,7 +93,7 @@ class PyConcreteAffineExpr : public BaseTy { // IsAFunctionTy isaFunction // const char *pyClassName // and redefine bindDerived. - using ClassTy = py::class_; + using ClassTy = nb::class_; using IsAFunctionTy = bool (*)(MlirAffineExpr); PyConcreteAffineExpr() = default; @@ -105,24 +104,25 @@ class PyConcreteAffineExpr : public BaseTy { static MlirAffineExpr castFrom(PyAffineExpr &orig) { if (!DerivedTy::isaFunction(orig)) { - auto origRepr = py::repr(py::cast(orig)).cast(); - throw py::value_error((Twine("Cannot cast affine expression to ") + + auto origRepr = nb::cast(nb::repr(nb::cast(orig))); + throw nb::value_error((Twine("Cannot cast affine expression to ") + DerivedTy::pyClassName + " (from " + origRepr + ")") - .str()); + .str() + .c_str()); } return orig; } - static void bind(py::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); - cls.def(py::init(), py::arg("expr")); + static void bind(nb::module_ &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName); + cls.def(nb::init(), nb::arg("expr")); cls.def_static( "isinstance", [](PyAffineExpr &otherAffineExpr) -> bool { return DerivedTy::isaFunction(otherAffineExpr); }, - py::arg("other")); + nb::arg("other")); DerivedTy::bindDerived(cls); } @@ -144,9 +144,9 @@ class PyAffineConstantExpr : public PyConcreteAffineExpr { } static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"), - py::arg("context") = py::none()); - c.def_property_readonly("value", [](PyAffineConstantExpr &self) { + c.def_static("get", &PyAffineConstantExpr::get, nb::arg("value"), + nb::arg("context").none() = nb::none()); + c.def_prop_ro("value", [](PyAffineConstantExpr &self) { return mlirAffineConstantExprGetValue(self); }); } @@ -164,9 +164,9 @@ class PyAffineDimExpr : public PyConcreteAffineExpr { } static void 
bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineDimExpr::get, py::arg("position"), - py::arg("context") = py::none()); - c.def_property_readonly("position", [](PyAffineDimExpr &self) { + c.def_static("get", &PyAffineDimExpr::get, nb::arg("position"), + nb::arg("context").none() = nb::none()); + c.def_prop_ro("position", [](PyAffineDimExpr &self) { return mlirAffineDimExprGetPosition(self); }); } @@ -184,9 +184,9 @@ class PyAffineSymbolExpr : public PyConcreteAffineExpr { } static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"), - py::arg("context") = py::none()); - c.def_property_readonly("position", [](PyAffineSymbolExpr &self) { + c.def_static("get", &PyAffineSymbolExpr::get, nb::arg("position"), + nb::arg("context").none() = nb::none()); + c.def_prop_ro("position", [](PyAffineSymbolExpr &self) { return mlirAffineSymbolExprGetPosition(self); }); } @@ -209,8 +209,8 @@ class PyAffineBinaryExpr : public PyConcreteAffineExpr { } static void bindDerived(ClassTy &c) { - c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs); - c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs); + c.def_prop_ro("lhs", &PyAffineBinaryExpr::lhs); + c.def_prop_ro("rhs", &PyAffineBinaryExpr::rhs); } }; @@ -365,15 +365,14 @@ bool PyAffineExpr::operator==(const PyAffineExpr &other) const { return mlirAffineExprEqual(affineExpr, other.affineExpr); } -py::object PyAffineExpr::getCapsule() { - return py::reinterpret_steal( - mlirPythonAffineExprToCapsule(*this)); +nb::object PyAffineExpr::getCapsule() { + return nb::steal(mlirPythonAffineExprToCapsule(*this)); } -PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) { +PyAffineExpr PyAffineExpr::createFromCapsule(nb::object capsule) { MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); if (mlirAffineExprIsNull(rawAffineExpr)) - throw py::error_already_set(); + throw nb::python_error(); return PyAffineExpr( 
PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)), rawAffineExpr); @@ -424,14 +423,14 @@ bool PyAffineMap::operator==(const PyAffineMap &other) const { return mlirAffineMapEqual(affineMap, other.affineMap); } -py::object PyAffineMap::getCapsule() { - return py::reinterpret_steal(mlirPythonAffineMapToCapsule(*this)); +nb::object PyAffineMap::getCapsule() { + return nb::steal(mlirPythonAffineMapToCapsule(*this)); } -PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) { +PyAffineMap PyAffineMap::createFromCapsule(nb::object capsule) { MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); if (mlirAffineMapIsNull(rawAffineMap)) - throw py::error_already_set(); + throw nb::python_error(); return PyAffineMap( PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)), rawAffineMap); @@ -454,11 +453,10 @@ class PyIntegerSetConstraint { bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } - static void bind(py::module &m) { - py::class_(m, "IntegerSetConstraint", - py::module_local()) - .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr) - .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq); + static void bind(nb::module_ &m) { + nb::class_(m, "IntegerSetConstraint") + .def_prop_ro("expr", &PyIntegerSetConstraint::getExpr) + .def_prop_ro("is_eq", &PyIntegerSetConstraint::isEq); } private: @@ -501,27 +499,25 @@ bool PyIntegerSet::operator==(const PyIntegerSet &other) const { return mlirIntegerSetEqual(integerSet, other.integerSet); } -py::object PyIntegerSet::getCapsule() { - return py::reinterpret_steal( - mlirPythonIntegerSetToCapsule(*this)); +nb::object PyIntegerSet::getCapsule() { + return nb::steal(mlirPythonIntegerSetToCapsule(*this)); } -PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) { +PyIntegerSet PyIntegerSet::createFromCapsule(nb::object capsule) { MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); if 
(mlirIntegerSetIsNull(rawIntegerSet)) - throw py::error_already_set(); + throw nb::python_error(); return PyIntegerSet( PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), rawIntegerSet); } -void mlir::python::populateIRAffine(py::module &m) { +void mlir::python::populateIRAffine(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of PyAffineExpr and derived classes. //---------------------------------------------------------------------------- - py::class_(m, "AffineExpr", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyAffineExpr::getCapsule) + nb::class_(m, "AffineExpr") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineExpr::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) .def("__add__", &PyAffineAddExpr::get) .def("__add__", &PyAffineAddExpr::getRHSConstant) @@ -558,7 +554,7 @@ void mlir::python::populateIRAffine(py::module &m) { .def("__eq__", [](PyAffineExpr &self, PyAffineExpr &other) { return self == other; }) .def("__eq__", - [](PyAffineExpr &self, py::object &other) { return false; }) + [](PyAffineExpr &self, nb::object &other) { return false; }) .def("__str__", [](PyAffineExpr &self) { PyPrintAccumulator printAccum; @@ -579,7 +575,7 @@ void mlir::python::populateIRAffine(py::module &m) { [](PyAffineExpr &self) { return static_cast(llvm::hash_value(self.get().ptr)); }) - .def_property_readonly( + .def_prop_ro( "context", [](PyAffineExpr &self) { return self.getContext().getObject(); }) .def("compose", @@ -632,16 +628,16 @@ void mlir::python::populateIRAffine(py::module &m) { .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant, "Gets an affine expression containing the rounded-up result " "of dividing an expression by a constant.") - .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"), - py::arg("context") = py::none(), + .def_static("get_constant", &PyAffineConstantExpr::get, 
nb::arg("value"), + nb::arg("context").none() = nb::none(), "Gets a constant affine expression with the given value.") .def_static( - "get_dim", &PyAffineDimExpr::get, py::arg("position"), - py::arg("context") = py::none(), + "get_dim", &PyAffineDimExpr::get, nb::arg("position"), + nb::arg("context").none() = nb::none(), "Gets an affine expression of a dimension at the given position.") .def_static( - "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"), - py::arg("context") = py::none(), + "get_symbol", &PyAffineSymbolExpr::get, nb::arg("position"), + nb::arg("context").none() = nb::none(), "Gets an affine expression of a symbol at the given position.") .def( "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); }, @@ -659,13 +655,12 @@ void mlir::python::populateIRAffine(py::module &m) { //---------------------------------------------------------------------------- // Mapping of PyAffineMap. //---------------------------------------------------------------------------- - py::class_(m, "AffineMap", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyAffineMap::getCapsule) + nb::class_(m, "AffineMap") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineMap::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) .def("__eq__", [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) - .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; }) + .def("__eq__", [](PyAffineMap &self, nb::object &other) { return false; }) .def("__str__", [](PyAffineMap &self) { PyPrintAccumulator printAccum; @@ -687,7 +682,7 @@ void mlir::python::populateIRAffine(py::module &m) { return static_cast(llvm::hash_value(self.get().ptr)); }) .def_static("compress_unused_symbols", - [](py::list affineMaps, DefaultingPyMlirContext context) { + [](nb::list affineMaps, DefaultingPyMlirContext context) { SmallVector maps; pyListToVector( affineMaps, maps, "attempting to create an AffineMap"); @@ -704,7 +699,7 
@@ void mlir::python::populateIRAffine(py::module &m) { res.emplace_back(context->getRef(), m); return res; }) - .def_property_readonly( + .def_prop_ro( "context", [](PyAffineMap &self) { return self.getContext().getObject(); }, "Context that owns the Affine Map") @@ -713,7 +708,7 @@ void mlir::python::populateIRAffine(py::module &m) { kDumpDocstring) .def_static( "get", - [](intptr_t dimCount, intptr_t symbolCount, py::list exprs, + [](intptr_t dimCount, intptr_t symbolCount, nb::list exprs, DefaultingPyMlirContext context) { SmallVector affineExprs; pyListToVector( @@ -723,8 +718,8 @@ void mlir::python::populateIRAffine(py::module &m) { affineExprs.size(), affineExprs.data()); return PyAffineMap(context->getRef(), map); }, - py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"), - py::arg("context") = py::none(), + nb::arg("dim_count"), nb::arg("symbol_count"), nb::arg("exprs"), + nb::arg("context").none() = nb::none(), "Gets a map with the given expressions as results.") .def_static( "get_constant", @@ -733,7 +728,7 @@ void mlir::python::populateIRAffine(py::module &m) { mlirAffineMapConstantGet(context->get(), value); return PyAffineMap(context->getRef(), affineMap); }, - py::arg("value"), py::arg("context") = py::none(), + nb::arg("value"), nb::arg("context").none() = nb::none(), "Gets an affine map with a single constant result") .def_static( "get_empty", @@ -741,7 +736,7 @@ void mlir::python::populateIRAffine(py::module &m) { MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); return PyAffineMap(context->getRef(), affineMap); }, - py::arg("context") = py::none(), "Gets an empty affine map.") + nb::arg("context").none() = nb::none(), "Gets an empty affine map.") .def_static( "get_identity", [](intptr_t nDims, DefaultingPyMlirContext context) { @@ -749,7 +744,7 @@ void mlir::python::populateIRAffine(py::module &m) { mlirAffineMapMultiDimIdentityGet(context->get(), nDims); return PyAffineMap(context->getRef(), affineMap); }, - 
py::arg("n_dims"), py::arg("context") = py::none(), + nb::arg("n_dims"), nb::arg("context").none() = nb::none(), "Gets an identity map with the given number of dimensions.") .def_static( "get_minor_identity", @@ -759,8 +754,8 @@ void mlir::python::populateIRAffine(py::module &m) { mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults); return PyAffineMap(context->getRef(), affineMap); }, - py::arg("n_dims"), py::arg("n_results"), - py::arg("context") = py::none(), + nb::arg("n_dims"), nb::arg("n_results"), + nb::arg("context").none() = nb::none(), "Gets a minor identity map with the given number of dimensions and " "results.") .def_static( @@ -768,13 +763,13 @@ void mlir::python::populateIRAffine(py::module &m) { [](std::vector permutation, DefaultingPyMlirContext context) { if (!isPermutation(permutation)) - throw py::cast_error("Invalid permutation when attempting to " - "create an AffineMap"); + throw std::runtime_error("Invalid permutation when attempting to " + "create an AffineMap"); MlirAffineMap affineMap = mlirAffineMapPermutationGet( context->get(), permutation.size(), permutation.data()); return PyAffineMap(context->getRef(), affineMap); }, - py::arg("permutation"), py::arg("context") = py::none(), + nb::arg("permutation"), nb::arg("context").none() = nb::none(), "Gets an affine map that permutes its inputs.") .def( "get_submap", @@ -782,33 +777,33 @@ void mlir::python::populateIRAffine(py::module &m) { intptr_t numResults = mlirAffineMapGetNumResults(self); for (intptr_t pos : resultPos) { if (pos < 0 || pos >= numResults) - throw py::value_error("result position out of bounds"); + throw nb::value_error("result position out of bounds"); } MlirAffineMap affineMap = mlirAffineMapGetSubMap( self, resultPos.size(), resultPos.data()); return PyAffineMap(self.getContext(), affineMap); }, - py::arg("result_positions")) + nb::arg("result_positions")) .def( "get_major_submap", [](PyAffineMap &self, intptr_t nResults) { if (nResults >= 
mlirAffineMapGetNumResults(self)) - throw py::value_error("number of results out of bounds"); + throw nb::value_error("number of results out of bounds"); MlirAffineMap affineMap = mlirAffineMapGetMajorSubMap(self, nResults); return PyAffineMap(self.getContext(), affineMap); }, - py::arg("n_results")) + nb::arg("n_results")) .def( "get_minor_submap", [](PyAffineMap &self, intptr_t nResults) { if (nResults >= mlirAffineMapGetNumResults(self)) - throw py::value_error("number of results out of bounds"); + throw nb::value_error("number of results out of bounds"); MlirAffineMap affineMap = mlirAffineMapGetMinorSubMap(self, nResults); return PyAffineMap(self.getContext(), affineMap); }, - py::arg("n_results")) + nb::arg("n_results")) .def( "replace", [](PyAffineMap &self, PyAffineExpr &expression, @@ -818,39 +813,37 @@ void mlir::python::populateIRAffine(py::module &m) { self, expression, replacement, numResultDims, numResultSyms); return PyAffineMap(self.getContext(), affineMap); }, - py::arg("expr"), py::arg("replacement"), py::arg("n_result_dims"), - py::arg("n_result_syms")) - .def_property_readonly( + nb::arg("expr"), nb::arg("replacement"), nb::arg("n_result_dims"), + nb::arg("n_result_syms")) + .def_prop_ro( "is_permutation", [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) - .def_property_readonly("is_projected_permutation", - [](PyAffineMap &self) { - return mlirAffineMapIsProjectedPermutation(self); - }) - .def_property_readonly( + .def_prop_ro("is_projected_permutation", + [](PyAffineMap &self) { + return mlirAffineMapIsProjectedPermutation(self); + }) + .def_prop_ro( "n_dims", [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); }) - .def_property_readonly( + .def_prop_ro( "n_inputs", [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); }) - .def_property_readonly( + .def_prop_ro( "n_symbols", [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); }) - .def_property_readonly("results", [](PyAffineMap &self) { - 
return PyAffineMapExprList(self); - }); + .def_prop_ro("results", + [](PyAffineMap &self) { return PyAffineMapExprList(self); }); PyAffineMapExprList::bind(m); //---------------------------------------------------------------------------- // Mapping of PyIntegerSet. //---------------------------------------------------------------------------- - py::class_(m, "IntegerSet", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyIntegerSet::getCapsule) + nb::class_(m, "IntegerSet") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyIntegerSet::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) .def("__eq__", [](PyIntegerSet &self, PyIntegerSet &other) { return self == other; }) - .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; }) + .def("__eq__", [](PyIntegerSet &self, nb::object other) { return false; }) .def("__str__", [](PyIntegerSet &self) { PyPrintAccumulator printAccum; @@ -871,7 +864,7 @@ void mlir::python::populateIRAffine(py::module &m) { [](PyIntegerSet &self) { return static_cast(llvm::hash_value(self.get().ptr)); }) - .def_property_readonly( + .def_prop_ro( "context", [](PyIntegerSet &self) { return self.getContext().getObject(); }) .def( @@ -879,14 +872,14 @@ void mlir::python::populateIRAffine(py::module &m) { kDumpDocstring) .def_static( "get", - [](intptr_t numDims, intptr_t numSymbols, py::list exprs, + [](intptr_t numDims, intptr_t numSymbols, nb::list exprs, std::vector eqFlags, DefaultingPyMlirContext context) { if (exprs.size() != eqFlags.size()) - throw py::value_error( + throw nb::value_error( "Expected the number of constraints to match " "that of equality flags"); - if (exprs.empty()) - throw py::value_error("Expected non-empty list of constraints"); + if (exprs.size() == 0) + throw nb::value_error("Expected non-empty list of constraints"); // Copy over to a SmallVector because std::vector has a // specialization for booleans that packs data and does not @@ -901,8 +894,8 
@@ void mlir::python::populateIRAffine(py::module &m) { affineExprs.data(), flags.data()); return PyIntegerSet(context->getRef(), set); }, - py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"), - py::arg("eq_flags"), py::arg("context") = py::none()) + nb::arg("num_dims"), nb::arg("num_symbols"), nb::arg("exprs"), + nb::arg("eq_flags"), nb::arg("context").none() = nb::none()) .def_static( "get_empty", [](intptr_t numDims, intptr_t numSymbols, @@ -911,20 +904,20 @@ void mlir::python::populateIRAffine(py::module &m) { mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); return PyIntegerSet(context->getRef(), set); }, - py::arg("num_dims"), py::arg("num_symbols"), - py::arg("context") = py::none()) + nb::arg("num_dims"), nb::arg("num_symbols"), + nb::arg("context").none() = nb::none()) .def( "get_replaced", - [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, + [](PyIntegerSet &self, nb::list dimExprs, nb::list symbolExprs, intptr_t numResultDims, intptr_t numResultSymbols) { if (static_cast(dimExprs.size()) != mlirIntegerSetGetNumDims(self)) - throw py::value_error( + throw nb::value_error( "Expected the number of dimension replacement expressions " "to match that of dimensions"); if (static_cast(symbolExprs.size()) != mlirIntegerSetGetNumSymbols(self)) - throw py::value_error( + throw nb::value_error( "Expected the number of symbol replacement expressions " "to match that of symbols"); @@ -940,30 +933,30 @@ void mlir::python::populateIRAffine(py::module &m) { numResultDims, numResultSymbols); return PyIntegerSet(self.getContext(), set); }, - py::arg("dim_exprs"), py::arg("symbol_exprs"), - py::arg("num_result_dims"), py::arg("num_result_symbols")) - .def_property_readonly("is_canonical_empty", - [](PyIntegerSet &self) { - return mlirIntegerSetIsCanonicalEmpty(self); - }) - .def_property_readonly( + nb::arg("dim_exprs"), nb::arg("symbol_exprs"), + nb::arg("num_result_dims"), nb::arg("num_result_symbols")) + 
.def_prop_ro("is_canonical_empty", + [](PyIntegerSet &self) { + return mlirIntegerSetIsCanonicalEmpty(self); + }) + .def_prop_ro( "n_dims", [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); }) - .def_property_readonly( + .def_prop_ro( "n_symbols", [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) - .def_property_readonly( + .def_prop_ro( "n_inputs", [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) - .def_property_readonly("n_equalities", - [](PyIntegerSet &self) { - return mlirIntegerSetGetNumEqualities(self); - }) - .def_property_readonly("n_inequalities", - [](PyIntegerSet &self) { - return mlirIntegerSetGetNumInequalities(self); - }) - .def_property_readonly("constraints", [](PyIntegerSet &self) { + .def_prop_ro("n_equalities", + [](PyIntegerSet &self) { + return mlirIntegerSetGetNumEqualities(self); + }) + .def_prop_ro("n_inequalities", + [](PyIntegerSet &self) { + return mlirIntegerSetGetNumInequalities(self); + }) + .def_prop_ro("constraints", [](PyIntegerSet &self) { return PyIntegerSetConstraintList(self); }); PyIntegerSetConstraint::bind(m); diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index cc9532f4e..f9656eb23 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -6,23 +6,29 @@ // //===----------------------------------------------------------------------===// +#include +#include +#include +#include +#include +#include + +#include #include +#include #include #include #include "IRModule.h" - -#include "PybindUtils.h" -#include - -#include "llvm/ADT/ScopeExit.h" -#include "llvm/Support/raw_ostream.h" - +#include "NanobindUtils.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/Support/raw_ostream.h" -namespace py = pybind11; 
+namespace nb = nanobind; +using namespace nanobind::literals; using namespace mlir; using namespace mlir::python; @@ -123,10 +129,119 @@ subsequent processing. namespace { +struct nb_buffer_info { + void *ptr = nullptr; + ssize_t itemsize = 0; + ssize_t size = 0; + const char *format = nullptr; + ssize_t ndim = 0; + SmallVector shape; + SmallVector strides; + bool readonly = false; + + nb_buffer_info( + void *ptr, ssize_t itemsize, const char *format, ssize_t ndim, + SmallVector shape_in, SmallVector strides_in, + bool readonly = false, + std::unique_ptr owned_view_in = + std::unique_ptr(nullptr, nullptr)) + : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim), + shape(std::move(shape_in)), strides(std::move(strides_in)), + readonly(readonly), owned_view(std::move(owned_view_in)) { + size = 1; + for (ssize_t i = 0; i < ndim; ++i) { + size *= shape[i]; + } + } + + explicit nb_buffer_info(Py_buffer *view) + : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim, + {view->shape, view->shape + view->ndim}, + // TODO(phawkins): check for null strides + {view->strides, view->strides + view->ndim}, + view->readonly != 0, + std::unique_ptr( + view, PyBuffer_Release)) {} + + nb_buffer_info(const nb_buffer_info &) = delete; + nb_buffer_info(nb_buffer_info &&) = default; + nb_buffer_info &operator=(const nb_buffer_info &) = delete; + nb_buffer_info &operator=(nb_buffer_info &&) = default; + +private: + std::unique_ptr owned_view; +}; + +class nb_buffer : public nb::object { + NB_OBJECT_DEFAULT(nb_buffer, object, "buffer", PyObject_CheckBuffer); + + nb_buffer_info request() const { + int flags = PyBUF_STRIDES | PyBUF_FORMAT; + auto *view = new Py_buffer(); + if (PyObject_GetBuffer(ptr(), view, flags) != 0) { + delete view; + throw nb::python_error(); + } + return nb_buffer_info(view); + } +}; + +template +struct nb_format_descriptor {}; + +template <> +struct nb_format_descriptor { + static const char *format() { return "?"; } +}; +template <> +struct 
nb_format_descriptor { + static const char *format() { return "b"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "B"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "h"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "H"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "i"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "I"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "q"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "Q"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "f"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "d"; } +}; + static MlirStringRef toMlirStringRef(const std::string &s) { return mlirStringRefCreate(s.data(), s.size()); } +static MlirStringRef toMlirStringRef(const nb::bytes &s) { + return mlirStringRefCreate(static_cast(s.data()), s.size()); +} + class PyAffineMapAttribute : public PyConcreteAttribute { public: static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; @@ -142,9 +257,9 @@ class PyAffineMapAttribute : public PyConcreteAttribute { MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); return PyAffineMapAttribute(affineMap.getContext(), attr); }, - py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); - c.def_property_readonly("value", mlirAffineMapAttrGetValue, - "Returns the value of the AffineMap attribute"); + nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); + c.def_prop_ro("value", mlirAffineMapAttrGetValue, + "Returns the value of the AffineMap attribute"); } }; @@ -164,25 +279,24 @@ class PyIntegerSetAttribute MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get()); return 
PyIntegerSetAttribute(integerSet.getContext(), attr); }, - py::arg("integer_set"), "Gets an attribute wrapping an IntegerSet."); + nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet."); } }; template -static T pyTryCast(py::handle object) { +static T pyTryCast(nb::handle object) { try { - return object.cast(); - } catch (py::cast_error &err) { - std::string msg = - std::string( - "Invalid attribute when attempting to create an ArrayAttribute (") + - err.what() + ")"; - throw py::cast_error(msg); - } catch (py::reference_cast_error &err) { + return nb::cast(object); + } catch (nb::cast_error &err) { + std::string msg = std::string("Invalid attribute when attempting to " + "create an ArrayAttribute (") + + err.what() + ")"; + throw std::runtime_error(msg.c_str()); + } catch (std::runtime_error &err) { std::string msg = std::string("Invalid attribute (None?) when attempting " "to create an ArrayAttribute (") + err.what() + ")"; - throw py::cast_error(msg); + throw std::runtime_error(msg.c_str()); } } @@ -205,14 +319,13 @@ class PyDenseArrayAttribute : public PyConcreteAttribute { EltTy dunderNext() { // Throw if the index has reached the end. if (nextIndex >= mlirDenseArrayGetNumElements(attr.get())) - throw py::stop_iteration(); + throw nb::stop_iteration(); return DerivedT::getElement(attr.get(), nextIndex++); } /// Bind the iterator class. - static void bind(py::module &m) { - py::class_(m, DerivedT::pyIteratorName, - py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, DerivedT::pyIteratorName) .def("__iter__", &PyDenseArrayIterator::dunderIter) .def("__next__", &PyDenseArrayIterator::dunderNext); } @@ -230,17 +343,35 @@ class PyDenseArrayAttribute : public PyConcreteAttribute { /// Bind the attribute class. static void bindDerived(typename PyConcreteAttribute::ClassTy &c) { // Bind the constructor. 
- c.def_static( - "get", - [](const std::vector &values, DefaultingPyMlirContext ctx) { - return getAttribute(values, ctx->getRef()); - }, - py::arg("values"), py::arg("context") = py::none(), - "Gets a uniqued dense array attribute"); + if constexpr (std::is_same_v) { + c.def_static( + "get", + [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) { + std::vector values; + for (nb::handle py_value : py_values) { + int is_true = PyObject_IsTrue(py_value.ptr()); + if (is_true < 0) { + throw nb::python_error(); + } + values.push_back(is_true); + } + return getAttribute(values, ctx->getRef()); + }, + nb::arg("values"), nb::arg("context").none() = nb::none(), + "Gets a uniqued dense array attribute"); + } else { + c.def_static( + "get", + [](const std::vector &values, DefaultingPyMlirContext ctx) { + return getAttribute(values, ctx->getRef()); + }, + nb::arg("values"), nb::arg("context").none() = nb::none(), + "Gets a uniqued dense array attribute"); + } // Bind the array methods. c.def("__getitem__", [](DerivedT &arr, intptr_t i) { if (i >= mlirDenseArrayGetNumElements(arr)) - throw py::index_error("DenseArray index out of range"); + throw nb::index_error("DenseArray index out of range"); return arr.getItem(i); }); c.def("__len__", [](const DerivedT &arr) { @@ -248,13 +379,13 @@ class PyDenseArrayAttribute : public PyConcreteAttribute { }); c.def("__iter__", [](const DerivedT &arr) { return PyDenseArrayIterator(arr); }); - c.def("__add__", [](DerivedT &arr, const py::list &extras) { + c.def("__add__", [](DerivedT &arr, const nb::list &extras) { std::vector values; intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); - values.reserve(numOldElements + py::len(extras)); + values.reserve(numOldElements + nb::len(extras)); for (intptr_t i = 0; i < numOldElements; ++i) values.push_back(arr.getItem(i)); - for (py::handle attr : extras) + for (nb::handle attr : extras) values.push_back(pyTryCast(attr)); return getAttribute(values, arr.getContext()); }); @@ 
-358,13 +489,12 @@ class PyArrayAttribute : public PyConcreteAttribute { MlirAttribute dunderNext() { // TODO: Throw is an inefficient way to stop iteration. if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) - throw py::stop_iteration(); + throw nb::stop_iteration(); return mlirArrayAttrGetElement(attr.get(), nextIndex++); } - static void bind(py::module &m) { - py::class_(m, "ArrayAttributeIterator", - py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "ArrayAttributeIterator") .def("__iter__", &PyArrayAttributeIterator::dunderIter) .def("__next__", &PyArrayAttributeIterator::dunderNext); } @@ -381,9 +511,9 @@ class PyArrayAttribute : public PyConcreteAttribute { static void bindDerived(ClassTy &c) { c.def_static( "get", - [](py::list attributes, DefaultingPyMlirContext context) { + [](nb::list attributes, DefaultingPyMlirContext context) { SmallVector mlirAttributes; - mlirAttributes.reserve(py::len(attributes)); + mlirAttributes.reserve(nb::len(attributes)); for (auto attribute : attributes) { mlirAttributes.push_back(pyTryCast(attribute)); } @@ -391,12 +521,12 @@ class PyArrayAttribute : public PyConcreteAttribute { context->get(), mlirAttributes.size(), mlirAttributes.data()); return PyArrayAttribute(context->getRef(), attr); }, - py::arg("attributes"), py::arg("context") = py::none(), + nb::arg("attributes"), nb::arg("context").none() = nb::none(), "Gets a uniqued Array attribute"); c.def("__getitem__", [](PyArrayAttribute &arr, intptr_t i) { if (i >= mlirArrayAttrGetNumElements(arr)) - throw py::index_error("ArrayAttribute index out of range"); + throw nb::index_error("ArrayAttribute index out of range"); return arr.getItem(i); }) .def("__len__", @@ -406,13 +536,13 @@ class PyArrayAttribute : public PyConcreteAttribute { .def("__iter__", [](const PyArrayAttribute &arr) { return PyArrayAttributeIterator(arr); }); - c.def("__add__", [](PyArrayAttribute arr, py::list extras) { + c.def("__add__", [](PyArrayAttribute arr, nb::list 
extras) { std::vector attributes; intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); - attributes.reserve(numOldElements + py::len(extras)); + attributes.reserve(numOldElements + nb::len(extras)); for (intptr_t i = 0; i < numOldElements; ++i) attributes.push_back(arr.getItem(i)); - for (py::handle attr : extras) + for (nb::handle attr : extras) attributes.push_back(pyTryCast(attr)); MlirAttribute arrayAttr = mlirArrayAttrGet( arr.getContext()->get(), attributes.size(), attributes.data()); @@ -440,7 +570,7 @@ class PyFloatAttribute : public PyConcreteAttribute { throw MLIRError("Invalid attribute", errors.take()); return PyFloatAttribute(type.getContext(), attr); }, - py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), + nb::arg("type"), nb::arg("value"), nb::arg("loc").none() = nb::none(), "Gets an uniqued float point attribute associated to a type"); c.def_static( "get_f32", @@ -449,7 +579,7 @@ class PyFloatAttribute : public PyConcreteAttribute { context->get(), mlirF32TypeGet(context->get()), value); return PyFloatAttribute(context->getRef(), attr); }, - py::arg("value"), py::arg("context") = py::none(), + nb::arg("value"), nb::arg("context").none() = nb::none(), "Gets an uniqued float point attribute associated to a f32 type"); c.def_static( "get_f64", @@ -458,10 +588,10 @@ class PyFloatAttribute : public PyConcreteAttribute { context->get(), mlirF64TypeGet(context->get()), value); return PyFloatAttribute(context->getRef(), attr); }, - py::arg("value"), py::arg("context") = py::none(), + nb::arg("value"), nb::arg("context").none() = nb::none(), "Gets an uniqued float point attribute associated to a f64 type"); - c.def_property_readonly("value", mlirFloatAttrGetValueDouble, - "Returns the value of the float attribute"); + c.def_prop_ro("value", mlirFloatAttrGetValueDouble, + "Returns the value of the float attribute"); c.def("__float__", mlirFloatAttrGetValueDouble, "Converts the value of the float attribute to a Python float"); } @@ -481,20 
+611,20 @@ class PyIntegerAttribute : public PyConcreteAttribute { MlirAttribute attr = mlirIntegerAttrGet(type, value); return PyIntegerAttribute(type.getContext(), attr); }, - py::arg("type"), py::arg("value"), + nb::arg("type"), nb::arg("value"), "Gets an uniqued integer attribute associated to a type"); - c.def_property_readonly("value", toPyInt, - "Returns the value of the integer attribute"); + c.def_prop_ro("value", toPyInt, + "Returns the value of the integer attribute"); c.def("__int__", toPyInt, "Converts the value of the integer attribute to a Python int"); - c.def_property_readonly_static("static_typeid", - [](py::object & /*class*/) -> MlirTypeID { - return mlirIntegerAttrGetTypeID(); - }); + c.def_prop_ro_static("static_typeid", + [](nb::object & /*class*/) -> MlirTypeID { + return mlirIntegerAttrGetTypeID(); + }); } private: - static py::int_ toPyInt(PyIntegerAttribute &self) { + static int64_t toPyInt(PyIntegerAttribute &self) { MlirType type = mlirAttributeGetType(self); if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) return mlirIntegerAttrGetValueInt(self); @@ -518,10 +648,10 @@ class PyBoolAttribute : public PyConcreteAttribute { MlirAttribute attr = mlirBoolAttrGet(context->get(), value); return PyBoolAttribute(context->getRef(), attr); }, - py::arg("value"), py::arg("context") = py::none(), + nb::arg("value"), nb::arg("context").none() = nb::none(), "Gets an uniqued bool attribute"); - c.def_property_readonly("value", mlirBoolAttrGetValue, - "Returns the value of the bool attribute"); + c.def_prop_ro("value", mlirBoolAttrGetValue, + "Returns the value of the bool attribute"); c.def("__bool__", mlirBoolAttrGetValue, "Converts the value of the bool attribute to a Python bool"); } @@ -555,9 +685,9 @@ class PySymbolRefAttribute : public PyConcreteAttribute { DefaultingPyMlirContext context) { return PySymbolRefAttribute::fromList(symbols, context.resolve()); }, - py::arg("symbols"), py::arg("context") = py::none(), + 
nb::arg("symbols"), nb::arg("context").none() = nb::none(), "Gets a uniqued SymbolRef attribute from a list of symbol names"); - c.def_property_readonly( + c.def_prop_ro( "value", [](PySymbolRefAttribute &self) { std::vector symbols = { @@ -589,13 +719,13 @@ class PyFlatSymbolRefAttribute mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); return PyFlatSymbolRefAttribute(context->getRef(), attr); }, - py::arg("value"), py::arg("context") = py::none(), + nb::arg("value"), nb::arg("context").none() = nb::none(), "Gets a uniqued FlatSymbolRef attribute"); - c.def_property_readonly( + c.def_prop_ro( "value", [](PyFlatSymbolRefAttribute &self) { MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); - return py::str(stringRef.data, stringRef.length); + return nb::str(stringRef.data, stringRef.length); }, "Returns the value of the FlatSymbolRef attribute as a string"); } @@ -612,29 +742,29 @@ class PyOpaqueAttribute : public PyConcreteAttribute { static void bindDerived(ClassTy &c) { c.def_static( "get", - [](std::string dialectNamespace, py::buffer buffer, PyType &type, + [](std::string dialectNamespace, nb_buffer buffer, PyType &type, DefaultingPyMlirContext context) { - const py::buffer_info bufferInfo = buffer.request(); + const nb_buffer_info bufferInfo = buffer.request(); intptr_t bufferSize = bufferInfo.size; MlirAttribute attr = mlirOpaqueAttrGet( context->get(), toMlirStringRef(dialectNamespace), bufferSize, static_cast(bufferInfo.ptr), type); return PyOpaqueAttribute(context->getRef(), attr); }, - py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"), - py::arg("context") = py::none(), "Gets an Opaque attribute."); - c.def_property_readonly( + nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"), + nb::arg("context").none() = nb::none(), "Gets an Opaque attribute."); + c.def_prop_ro( "dialect_namespace", [](PyOpaqueAttribute &self) { MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); - return 
py::str(stringRef.data, stringRef.length); + return nb::str(stringRef.data, stringRef.length); }, "Returns the dialect namespace for the Opaque attribute as a string"); - c.def_property_readonly( + c.def_prop_ro( "data", [](PyOpaqueAttribute &self) { MlirStringRef stringRef = mlirOpaqueAttrGetData(self); - return py::bytes(stringRef.data, stringRef.length); + return nb::bytes(stringRef.data, stringRef.length); }, "Returns the data for the Opaqued attributes as `bytes`"); } @@ -656,7 +786,16 @@ class PyStringAttribute : public PyConcreteAttribute { mlirStringAttrGet(context->get(), toMlirStringRef(value)); return PyStringAttribute(context->getRef(), attr); }, - py::arg("value"), py::arg("context") = py::none(), + nb::arg("value"), nb::arg("context").none() = nb::none(), + "Gets a uniqued string attribute"); + c.def_static( + "get", + [](nb::bytes value, DefaultingPyMlirContext context) { + MlirAttribute attr = + mlirStringAttrGet(context->get(), toMlirStringRef(value)); + return PyStringAttribute(context->getRef(), attr); + }, + nb::arg("value"), nb::arg("context").none() = nb::none(), "Gets a uniqued string attribute"); c.def_static( "get_typed", @@ -665,20 +804,20 @@ class PyStringAttribute : public PyConcreteAttribute { mlirStringAttrTypedGet(type, toMlirStringRef(value)); return PyStringAttribute(type.getContext(), attr); }, - py::arg("type"), py::arg("value"), + nb::arg("type"), nb::arg("value"), "Gets a uniqued string attribute associated to a type"); - c.def_property_readonly( + c.def_prop_ro( "value", [](PyStringAttribute &self) { MlirStringRef stringRef = mlirStringAttrGetValue(self); - return py::str(stringRef.data, stringRef.length); + return nb::str(stringRef.data, stringRef.length); }, "Returns the value of the string attribute"); - c.def_property_readonly( + c.def_prop_ro( "value_bytes", [](PyStringAttribute &self) { MlirStringRef stringRef = mlirStringAttrGetValue(self); - return py::bytes(stringRef.data, stringRef.length); + return 
nb::bytes(stringRef.data, stringRef.length); }, "Returns the value of the string attribute as `bytes`"); } @@ -693,12 +832,11 @@ class PyDenseElementsAttribute using PyConcreteAttribute::PyConcreteAttribute; static PyDenseElementsAttribute - getFromList(py::list attributes, std::optional explicitType, + getFromList(nb::list attributes, std::optional explicitType, DefaultingPyMlirContext contextWrapper) { - - const size_t numAttributes = py::len(attributes); + const size_t numAttributes = nb::len(attributes); if (numAttributes == 0) - throw py::value_error("Attributes list must be non-empty."); + throw nb::value_error("Attributes list must be non-empty."); MlirType shapedType; if (explicitType) { @@ -708,8 +846,8 @@ class PyDenseElementsAttribute std::string message; llvm::raw_string_ostream os(message); os << "Expected a static ShapedType for the shaped_type parameter: " - << py::repr(py::cast(*explicitType)); - throw py::value_error(message); + << nb::cast(nb::repr(nb::cast(*explicitType))); + throw nb::value_error(message.c_str()); } shapedType = *explicitType; } else { @@ -722,7 +860,7 @@ class PyDenseElementsAttribute SmallVector mlirAttributes; mlirAttributes.reserve(numAttributes); - for (const py::handle &attribute : attributes) { + for (const nb::handle &attribute : attributes) { MlirAttribute mlirAttribute = pyTryCast(attribute); MlirType attrType = mlirAttributeGetType(mlirAttribute); mlirAttributes.push_back(mlirAttribute); @@ -731,9 +869,11 @@ class PyDenseElementsAttribute std::string message; llvm::raw_string_ostream os(message); os << "All attributes must be of the same type and match " - << "the type parameter: expected=" << py::repr(py::cast(shapedType)) - << ", but got=" << py::repr(py::cast(attrType)); - throw py::value_error(message); + << "the type parameter: expected=" + << nb::cast(nb::repr(nb::cast(shapedType))) + << ", but got=" + << nb::cast(nb::repr(nb::cast(attrType))); + throw nb::value_error(message.c_str()); } } @@ -744,7 +884,7 @@ 
class PyDenseElementsAttribute } static PyDenseElementsAttribute - getFromBuffer(py::buffer array, bool signless, + getFromBuffer(nb_buffer array, bool signless, std::optional explicitType, std::optional> explicitShape, DefaultingPyMlirContext contextWrapper) { @@ -755,7 +895,7 @@ class PyDenseElementsAttribute } Py_buffer view; if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) { - throw py::error_already_set(); + throw nb::python_error(); } auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); @@ -778,25 +918,25 @@ class PyDenseElementsAttribute if (!mlirAttributeIsAInteger(elementAttr) && !mlirAttributeIsAFloat(elementAttr)) { std::string message = "Illegal element type for DenseElementsAttr: "; - message.append(py::repr(py::cast(elementAttr))); - throw py::value_error(message); + message.append(nb::cast(nb::repr(nb::cast(elementAttr)))); + throw nb::value_error(message.c_str()); } if (!mlirTypeIsAShaped(shapedType) || !mlirShapedTypeHasStaticShape(shapedType)) { std::string message = "Expected a static ShapedType for the shaped_type parameter: "; - message.append(py::repr(py::cast(shapedType))); - throw py::value_error(message); + message.append(nb::cast(nb::repr(nb::cast(shapedType)))); + throw nb::value_error(message.c_str()); } MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); MlirType attrType = mlirAttributeGetType(elementAttr); if (!mlirTypeEqual(shapedElementType, attrType)) { std::string message = "Shaped element type and attribute type must be equal: shaped="; - message.append(py::repr(py::cast(shapedType))); + message.append(nb::cast(nb::repr(nb::cast(shapedType)))); message.append(", element="); - message.append(py::repr(py::cast(elementAttr))); - throw py::value_error(message); + message.append(nb::cast(nb::repr(nb::cast(elementAttr)))); + throw nb::value_error(message.c_str()); } MlirAttribute elements = @@ -806,7 +946,7 @@ class PyDenseElementsAttribute intptr_t dunderLen() { return 
mlirElementsAttrGetNumElements(*this); } - py::buffer_info accessBuffer() { + std::unique_ptr accessBuffer() { MlirType shapedType = mlirAttributeGetType(*this); MlirType elementType = mlirShapedTypeGetElementType(shapedType); std::string format; @@ -889,32 +1029,36 @@ class PyDenseElementsAttribute static void bindDerived(ClassTy &c) { c.def("__len__", &PyDenseElementsAttribute::dunderLen) .def_static("get", PyDenseElementsAttribute::getFromBuffer, - py::arg("array"), py::arg("signless") = true, - py::arg("type") = py::none(), py::arg("shape") = py::none(), - py::arg("context") = py::none(), + nb::arg("array"), nb::arg("signless") = true, + nb::arg("type").none() = nb::none(), + nb::arg("shape").none() = nb::none(), + nb::arg("context").none() = nb::none(), kDenseElementsAttrGetDocstring) .def_static("get", PyDenseElementsAttribute::getFromList, - py::arg("attrs"), py::arg("type") = py::none(), - py::arg("context") = py::none(), + nb::arg("attrs"), nb::arg("type").none() = nb::none(), + nb::arg("context").none() = nb::none(), kDenseElementsAttrGetFromListDocstring) .def_static("get_splat", PyDenseElementsAttribute::getSplat, - py::arg("shaped_type"), py::arg("element_attr"), + nb::arg("shaped_type"), nb::arg("element_attr"), "Gets a DenseElementsAttr where all values are the same") - .def_property_readonly("is_splat", - [](PyDenseElementsAttribute &self) -> bool { - return mlirDenseElementsAttrIsSplat(self); - }) - .def("get_splat_value", - [](PyDenseElementsAttribute &self) { - if (!mlirDenseElementsAttrIsSplat(self)) - throw py::value_error( - "get_splat_value called on a non-splat attribute"); - return mlirDenseElementsAttrGetSplatValue(self); - }) - .def_buffer(&PyDenseElementsAttribute::accessBuffer); + .def_prop_ro("is_splat", + [](PyDenseElementsAttribute &self) -> bool { + return mlirDenseElementsAttrIsSplat(self); + }) + .def("get_splat_value", [](PyDenseElementsAttribute &self) { + if (!mlirDenseElementsAttrIsSplat(self)) + throw nb::value_error( + 
"get_splat_value called on a non-splat attribute"); + return mlirDenseElementsAttrGetSplatValue(self); + }); } + static PyType_Slot slots[]; + private: + static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags); + static void bf_releasebuffer(PyObject *, Py_buffer *buffer); + static bool isUnsignedIntegerFormat(std::string_view format) { if (format.empty()) return false; @@ -1039,27 +1183,27 @@ class PyDenseElementsAttribute return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf); } - // There is a complication for boolean numpy arrays, as numpy represents them - // as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans - // per byte. + // There is a complication for boolean numpy arrays, as numpy represents + // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 + // booleans per byte. static MlirAttribute getBitpackedAttributeFromBooleanBuffer( Py_buffer &view, std::optional> explicitShape, MlirContext &context) { if (llvm::endianness::native != llvm::endianness::little) { - // Given we have no good way of testing the behavior on big-endian systems - // we will throw - throw py::type_error("Constructing a bit-packed MLIR attribute is " + // Given we have no good way of testing the behavior on big-endian + // systems we will throw + throw nb::type_error("Constructing a bit-packed MLIR attribute is " "unsupported on big-endian systems"); } + nb::ndarray, nb::c_contig> unpackedArray( + /*data=*/static_cast(view.buf), + /*shape=*/{static_cast(view.len)}); - py::array_t unpackedArray(view.len, - static_cast(view.buf)); - - py::module numpy = py::module::import("numpy"); - py::object packbitsFunc = numpy.attr("packbits"); - py::object packedBooleans = - packbitsFunc(unpackedArray, "bitorder"_a = "little"); - py::buffer_info pythonBuffer = packedBooleans.cast().request(); + nb::module_ numpy = nb::module_::import_("numpy"); + nb::object packbitsFunc = numpy.attr("packbits"); + nb::object packedBooleans = 
+ packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little"); + nb_buffer_info pythonBuffer = nb::cast(packedBooleans).request(); MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view); @@ -1073,11 +1217,11 @@ class PyDenseElementsAttribute // This does the opposite transformation of // `getBitpackedAttributeFromBooleanBuffer` - py::buffer_info getBooleanBufferFromBitpackedAttribute() { + std::unique_ptr getBooleanBufferFromBitpackedAttribute() { if (llvm::endianness::native != llvm::endianness::little) { - // Given we have no good way of testing the behavior on big-endian systems - // we will throw - throw py::type_error("Constructing a numpy array from a MLIR attribute " + // Given we have no good way of testing the behavior on big-endian + // systems we will throw + throw nb::type_error("Constructing a numpy array from a MLIR attribute " "is unsupported on big-endian systems"); } @@ -1085,21 +1229,24 @@ class PyDenseElementsAttribute int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8); uint8_t *bitpackedData = static_cast( const_cast(mlirDenseElementsAttrGetRawData(*this))); - py::array_t packedArray(numBitpackedBytes, bitpackedData); + nb::ndarray, nb::c_contig> packedArray( + /*data=*/bitpackedData, + /*shape=*/{static_cast(numBitpackedBytes)}); - py::module numpy = py::module::import("numpy"); - py::object unpackbitsFunc = numpy.attr("unpackbits"); - py::object equalFunc = numpy.attr("equal"); - py::object reshapeFunc = numpy.attr("reshape"); - py::array unpackedBooleans = - unpackbitsFunc(packedArray, "bitorder"_a = "little"); + nb::module_ numpy = nb::module_::import_("numpy"); + nb::object unpackbitsFunc = numpy.attr("unpackbits"); + nb::object equalFunc = numpy.attr("equal"); + nb::object reshapeFunc = numpy.attr("reshape"); + nb::object unpackedBooleans = + unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little"); // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array. 
// We need to: // 1. Slice away the padded bits // 2. Make the boolean array have the correct shape // 3. Convert the array to a boolean array - unpackedBooleans = unpackedBooleans[py::slice(0, numBooleans, 1)]; + unpackedBooleans = unpackedBooleans[nb::slice( + nb::int_(0), nb::int_(numBooleans), nb::int_(1))]; unpackedBooleans = equalFunc(unpackedBooleans, 1); MlirType shapedType = mlirAttributeGetType(*this); @@ -1110,15 +1257,15 @@ class PyDenseElementsAttribute } unpackedBooleans = reshapeFunc(unpackedBooleans, shape); - // Make sure the returned py::buffer_view claims ownership of the data in + // Make sure the returned nb::buffer_view claims ownership of the data in // `pythonBuffer` so it remains valid when Python reads it - py::buffer pythonBuffer = unpackedBooleans.cast(); - return pythonBuffer.request(); + nb_buffer pythonBuffer = nb::cast(unpackedBooleans); + return std::make_unique(pythonBuffer.request()); } template - py::buffer_info bufferInfo(MlirType shapedType, - const char *explicitFormat = nullptr) { + std::unique_ptr + bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) { intptr_t rank = mlirShapedTypeGetRank(shapedType); // Prepare the data for the buffer_info. // Buffer is configured for read-only access below. @@ -1142,19 +1289,69 @@ class PyDenseElementsAttribute } strides.push_back(sizeof(Type)); } - std::string format; + const char *format; if (explicitFormat) { format = explicitFormat; } else { - format = py::format_descriptor::format(); + format = nb_format_descriptor::format(); } - return py::buffer_info(data, sizeof(Type), format, rank, shape, strides, - /*readonly=*/true); + return std::make_unique( + data, sizeof(Type), format, rank, std::move(shape), std::move(strides), + /*readonly=*/true); } }; // namespace -/// Refinement of the PyDenseElementsAttribute for attributes containing integer -/// (and boolean) values. Supports element access. 
+PyType_Slot PyDenseElementsAttribute::slots[] = { + {Py_bf_getbuffer, + reinterpret_cast(PyDenseElementsAttribute::bf_getbuffer)}, + {Py_bf_releasebuffer, + reinterpret_cast(PyDenseElementsAttribute::bf_releasebuffer)}, + {0, nullptr}, +}; + +/*static*/ int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj, + Py_buffer *view, + int flags) { + view->obj = nullptr; + std::unique_ptr info; + try { + auto *attr = nb::cast(nb::handle(obj)); + info = attr->accessBuffer(); + } catch (nb::python_error &e) { + e.restore(); + nb::chain_error(PyExc_BufferError, "Error converting attribute to buffer"); + return -1; + } + view->obj = obj; + view->ndim = 1; + view->buf = info->ptr; + view->itemsize = info->itemsize; + view->len = info->itemsize; + for (auto s : info->shape) { + view->len *= s; + } + view->readonly = info->readonly; + if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { + view->format = const_cast(info->format); + } + if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { + view->ndim = static_cast(info->ndim); + view->strides = info->strides.data(); + view->shape = info->shape.data(); + } + view->suboffsets = nullptr; + view->internal = info.release(); + Py_INCREF(obj); + return 0; +} + +/*static*/ void PyDenseElementsAttribute::bf_releasebuffer(PyObject *, + Py_buffer *view) { + delete reinterpret_cast(view->internal); +} + +/// Refinement of the PyDenseElementsAttribute for attributes containing +/// integer (and boolean) values. Supports element access. class PyDenseIntElementsAttribute : public PyConcreteAttribute { @@ -1163,11 +1360,11 @@ class PyDenseIntElementsAttribute static constexpr const char *pyClassName = "DenseIntElementsAttr"; using PyConcreteAttribute::PyConcreteAttribute; - /// Returns the element at the given linear position. Asserts if the index is - /// out of range. - py::int_ dunderGetItem(intptr_t pos) { + /// Returns the element at the given linear position. Asserts if the index + /// is out of range. 
+ nb::object dunderGetItem(intptr_t pos) { if (pos < 0 || pos >= dunderLen()) { - throw py::index_error("attempt to access out of bounds element"); + throw nb::index_error("attempt to access out of bounds element"); } MlirType type = mlirAttributeGetType(*this); @@ -1175,7 +1372,7 @@ class PyDenseIntElementsAttribute assert(mlirTypeIsAInteger(type) && "expected integer element type in dense int elements attribute"); // Dispatch element extraction to an appropriate C function based on the - // elemental type of the attribute. py::int_ is implicitly constructible + // elemental type of the attribute. nb::int_ is implicitly constructible // from any C++ integral type and handles bitwidth correctly. // TODO: consider caching the type properties in the constructor to avoid // querying them on each element access. @@ -1183,38 +1380,38 @@ class PyDenseIntElementsAttribute bool isUnsigned = mlirIntegerTypeIsUnsigned(type); if (isUnsigned) { if (width == 1) { - return mlirDenseElementsAttrGetBoolValue(*this, pos); + return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); } if (width == 8) { - return mlirDenseElementsAttrGetUInt8Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos)); } if (width == 16) { - return mlirDenseElementsAttrGetUInt16Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos)); } if (width == 32) { - return mlirDenseElementsAttrGetUInt32Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos)); } if (width == 64) { - return mlirDenseElementsAttrGetUInt64Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos)); } } else { if (width == 1) { - return mlirDenseElementsAttrGetBoolValue(*this, pos); + return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); } if (width == 8) { - return mlirDenseElementsAttrGetInt8Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos)); } if (width == 16) { - 
return mlirDenseElementsAttrGetInt16Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos)); } if (width == 32) { - return mlirDenseElementsAttrGetInt32Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos)); } if (width == 64) { - return mlirDenseElementsAttrGetInt64Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos)); } } - throw py::type_error("Unsupported integer type"); + throw nb::type_error("Unsupported integer type"); } static void bindDerived(ClassTy &c) { @@ -1231,7 +1428,7 @@ class PyDenseResourceElementsAttribute using PyConcreteAttribute::PyConcreteAttribute; static PyDenseResourceElementsAttribute - getFromBuffer(py::buffer buffer, const std::string &name, const PyType &type, + getFromBuffer(nb_buffer buffer, const std::string &name, const PyType &type, std::optional alignment, bool isMutable, DefaultingPyMlirContext contextWrapper) { if (!mlirTypeIsAShaped(type)) { @@ -1244,7 +1441,7 @@ class PyDenseResourceElementsAttribute int flags = PyBUF_STRIDES; std::unique_ptr view = std::make_unique(); if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) { - throw py::error_already_set(); + throw nb::python_error(); } // This scope releaser will only release if we haven't yet transferred @@ -1289,12 +1486,12 @@ class PyDenseResourceElementsAttribute } static void bindDerived(ClassTy &c) { - c.def_static("get_from_buffer", - PyDenseResourceElementsAttribute::getFromBuffer, - py::arg("array"), py::arg("name"), py::arg("type"), - py::arg("alignment") = py::none(), - py::arg("is_mutable") = false, py::arg("context") = py::none(), - kDenseResourceElementsAttrGetFromBufferDocstring); + c.def_static( + "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer, + nb::arg("array"), nb::arg("name"), nb::arg("type"), + nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false, + nb::arg("context").none() = nb::none(), + 
kDenseResourceElementsAttrGetFromBufferDocstring); } }; @@ -1318,12 +1515,12 @@ class PyDictAttribute : public PyConcreteAttribute { c.def("__len__", &PyDictAttribute::dunderLen); c.def_static( "get", - [](py::dict attributes, DefaultingPyMlirContext context) { + [](nb::dict attributes, DefaultingPyMlirContext context) { SmallVector mlirNamedAttributes; mlirNamedAttributes.reserve(attributes.size()); - for (auto &it : attributes) { - auto &mlirAttr = it.second.cast(); - auto name = it.first.cast(); + for (std::pair it : attributes) { + auto &mlirAttr = nb::cast(it.second); + auto name = nb::cast(it.first); mlirNamedAttributes.push_back(mlirNamedAttributeGet( mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), toMlirStringRef(name)), @@ -1334,18 +1531,18 @@ class PyDictAttribute : public PyConcreteAttribute { mlirNamedAttributes.data()); return PyDictAttribute(context->getRef(), attr); }, - py::arg("value") = py::dict(), py::arg("context") = py::none(), + nb::arg("value") = nb::dict(), nb::arg("context").none() = nb::none(), "Gets an uniqued dict attribute"); c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { MlirAttribute attr = mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); if (mlirAttributeIsNull(attr)) - throw py::key_error("attempt to access a non-existent attribute"); + throw nb::key_error("attempt to access a non-existent attribute"); return attr; }); c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { if (index < 0 || index >= self.dunderLen()) { - throw py::index_error("attempt to access out of bounds attribute"); + throw nb::index_error("attempt to access out of bounds attribute"); } MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); return PyNamedAttribute( @@ -1365,25 +1562,25 @@ class PyDenseFPElementsAttribute static constexpr const char *pyClassName = "DenseFPElementsAttr"; using PyConcreteAttribute::PyConcreteAttribute; - py::float_ dunderGetItem(intptr_t pos) { + 
nb::float_ dunderGetItem(intptr_t pos) { if (pos < 0 || pos >= dunderLen()) { - throw py::index_error("attempt to access out of bounds element"); + throw nb::index_error("attempt to access out of bounds element"); } MlirType type = mlirAttributeGetType(*this); type = mlirShapedTypeGetElementType(type); // Dispatch element extraction to an appropriate C function based on the - // elemental type of the attribute. py::float_ is implicitly constructible + // elemental type of the attribute. nb::float_ is implicitly constructible // from float and double. // TODO: consider caching the type properties in the constructor to avoid // querying them on each element access. if (mlirTypeIsAF32(type)) { - return mlirDenseElementsAttrGetFloatValue(*this, pos); + return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos)); } if (mlirTypeIsAF64(type)) { - return mlirDenseElementsAttrGetDoubleValue(*this, pos); + return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos)); } - throw py::type_error("Unsupported floating-point type"); + throw nb::type_error("Unsupported floating-point type"); } static void bindDerived(ClassTy &c) { @@ -1406,9 +1603,9 @@ class PyTypeAttribute : public PyConcreteAttribute { MlirAttribute attr = mlirTypeAttrGet(value.get()); return PyTypeAttribute(context->getRef(), attr); }, - py::arg("value"), py::arg("context") = py::none(), + nb::arg("value"), nb::arg("context").none() = nb::none(), "Gets a uniqued Type attribute"); - c.def_property_readonly("value", [](PyTypeAttribute &self) { + c.def_prop_ro("value", [](PyTypeAttribute &self) { return mlirTypeAttrGetValue(self.get()); }); } @@ -1430,7 +1627,7 @@ class PyUnitAttribute : public PyConcreteAttribute { return PyUnitAttribute(context->getRef(), mlirUnitAttrGet(context->get())); }, - py::arg("context") = py::none(), "Create a Unit attribute."); + nb::arg("context").none() = nb::none(), "Create a Unit attribute."); } }; @@ -1453,7 +1650,8 @@ class PyStridedLayoutAttribute ctx->get(), offset, 
strides.size(), strides.data()); return PyStridedLayoutAttribute(ctx->getRef(), attr); }, - py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(), + nb::arg("offset"), nb::arg("strides"), + nb::arg("context").none() = nb::none(), "Gets a strided layout attribute."); c.def_static( "get_fully_dynamic", @@ -1465,16 +1663,17 @@ class PyStridedLayoutAttribute ctx->get(), dynamic, strides.size(), strides.data()); return PyStridedLayoutAttribute(ctx->getRef(), attr); }, - py::arg("rank"), py::arg("context") = py::none(), - "Gets a strided layout attribute with dynamic offset and strides of a " + nb::arg("rank"), nb::arg("context").none() = nb::none(), + "Gets a strided layout attribute with dynamic offset and strides of " + "a " "given rank."); - c.def_property_readonly( + c.def_prop_ro( "offset", [](PyStridedLayoutAttribute &self) { return mlirStridedLayoutAttrGetOffset(self); }, "Returns the value of the float point attribute"); - c.def_property_readonly( + c.def_prop_ro( "strides", [](PyStridedLayoutAttribute &self) { intptr_t size = mlirStridedLayoutAttrGetNumStrides(self); @@ -1488,63 +1687,64 @@ class PyStridedLayoutAttribute } }; -py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { +nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseBoolArrayAttribute(pyAttribute)); + return nb::cast(PyDenseBoolArrayAttribute(pyAttribute)); if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseI8ArrayAttribute(pyAttribute)); + return nb::cast(PyDenseI8ArrayAttribute(pyAttribute)); if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseI16ArrayAttribute(pyAttribute)); + return nb::cast(PyDenseI16ArrayAttribute(pyAttribute)); if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseI32ArrayAttribute(pyAttribute)); + return nb::cast(PyDenseI32ArrayAttribute(pyAttribute)); if 
(PyDenseI64ArrayAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseI64ArrayAttribute(pyAttribute)); + return nb::cast(PyDenseI64ArrayAttribute(pyAttribute)); if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseF32ArrayAttribute(pyAttribute)); + return nb::cast(PyDenseF32ArrayAttribute(pyAttribute)); if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseF64ArrayAttribute(pyAttribute)); + return nb::cast(PyDenseF64ArrayAttribute(pyAttribute)); std::string msg = std::string("Can't cast unknown element type DenseArrayAttr (") + - std::string(py::repr(py::cast(pyAttribute))) + ")"; - throw py::cast_error(msg); + nb::cast(nb::repr(nb::cast(pyAttribute))) + ")"; + throw nb::type_error(msg.c_str()); } -py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { +nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { if (PyDenseFPElementsAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseFPElementsAttribute(pyAttribute)); + return nb::cast(PyDenseFPElementsAttribute(pyAttribute)); if (PyDenseIntElementsAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseIntElementsAttribute(pyAttribute)); + return nb::cast(PyDenseIntElementsAttribute(pyAttribute)); std::string msg = std::string( "Can't cast unknown element type DenseIntOrFPElementsAttr (") + - std::string(py::repr(py::cast(pyAttribute))) + ")"; - throw py::cast_error(msg); + nb::cast(nb::repr(nb::cast(pyAttribute))) + ")"; + throw nb::type_error(msg.c_str()); } -py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { +nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { if (PyBoolAttribute::isaFunction(pyAttribute)) - return py::cast(PyBoolAttribute(pyAttribute)); + return nb::cast(PyBoolAttribute(pyAttribute)); if (PyIntegerAttribute::isaFunction(pyAttribute)) - return py::cast(PyIntegerAttribute(pyAttribute)); + return nb::cast(PyIntegerAttribute(pyAttribute)); std::string 
msg = std::string("Can't cast unknown element type DenseArrayAttr (") + - std::string(py::repr(py::cast(pyAttribute))) + ")"; - throw py::cast_error(msg); + nb::cast(nb::repr(nb::cast(pyAttribute))) + ")"; + throw nb::type_error(msg.c_str()); } -py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { +nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute)) - return py::cast(PyFlatSymbolRefAttribute(pyAttribute)); + return nb::cast(PyFlatSymbolRefAttribute(pyAttribute)); if (PySymbolRefAttribute::isaFunction(pyAttribute)) - return py::cast(PySymbolRefAttribute(pyAttribute)); + return nb::cast(PySymbolRefAttribute(pyAttribute)); std::string msg = std::string("Can't cast unknown SymbolRef attribute (") + - std::string(py::repr(py::cast(pyAttribute))) + ")"; - throw py::cast_error(msg); + nb::cast(nb::repr(nb::cast(pyAttribute))) + + ")"; + throw nb::type_error(msg.c_str()); } } // namespace -void mlir::python::populateIRAttributes(py::module &m) { +void mlir::python::populateIRAttributes(nb::module_ &m) { PyAffineMapAttribute::bind(m); PyDenseBoolArrayAttribute::bind(m); PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m); @@ -1562,24 +1762,26 @@ void mlir::python::populateIRAttributes(py::module &m) { PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m); PyGlobals::get().registerTypeCaster( mlirDenseArrayAttrGetTypeID(), - pybind11::cpp_function(denseArrayAttributeCaster)); + nb::cast(nb::cpp_function(denseArrayAttributeCaster))); PyArrayAttribute::bind(m); PyArrayAttribute::PyArrayAttributeIterator::bind(m); PyBoolAttribute::bind(m); - PyDenseElementsAttribute::bind(m); + PyDenseElementsAttribute::bind(m, PyDenseElementsAttribute::slots); PyDenseFPElementsAttribute::bind(m); PyDenseIntElementsAttribute::bind(m); PyGlobals::get().registerTypeCaster( mlirDenseIntOrFPElementsAttrGetTypeID(), - pybind11::cpp_function(denseIntOrFPElementsAttributeCaster)); + 
nb::cast( + nb::cpp_function(denseIntOrFPElementsAttributeCaster))); PyDenseResourceElementsAttribute::bind(m); PyDictAttribute::bind(m); PySymbolRefAttribute::bind(m); PyGlobals::get().registerTypeCaster( mlirSymbolRefAttrGetTypeID(), - pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster)); + nb::cast( + nb::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster))); PyFlatSymbolRefAttribute::bind(m); PyOpaqueAttribute::bind(m); @@ -1590,7 +1792,7 @@ void mlir::python::populateIRAttributes(py::module &m) { PyTypeAttribute::bind(m); PyGlobals::get().registerTypeCaster( mlirIntegerAttrGetTypeID(), - pybind11::cpp_function(integerOrBoolAttributeCaster)); + nb::cast(nb::cpp_function(integerOrBoolAttributeCaster))); PyUnitAttribute::bind(m); PyStridedLayoutAttribute::bind(m); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 3e96f8c60..e1c56a398 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -6,26 +6,31 @@ // //===----------------------------------------------------------------------===// -#include "IRModule.h" +#include +#include +#include +#include +#include +#include -#include "Globals.h" -#include "PybindUtils.h" +#include +#include +#include "Globals.h" +#include "IRModule.h" +#include "NanobindUtils.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/Debug.h" #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include -#include - -namespace py = pybind11; -using namespace py::literals; +namespace nb = nanobind; +using namespace nb::literals; using namespace mlir; using namespace mlir::python; @@ -190,18 +195,18 @@ operations. /// Helper for creating an @classmethod. template -py::object classmethod(Func f, Args... 
args) { - py::object cf = py::cpp_function(f, args...); - return py::reinterpret_borrow((PyClassMethod_New(cf.ptr()))); +nb::object classmethod(Func f, Args... args) { + nb::object cf = nb::cpp_function(f, args...); + return nb::borrow((PyClassMethod_New(cf.ptr()))); } -static py::object +static nb::object createCustomDialectWrapper(const std::string &dialectNamespace, - py::object dialectDescriptor) { + nb::object dialectDescriptor) { auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); if (!dialectClass) { // Use the base class. - return py::cast(PyDialect(std::move(dialectDescriptor))); + return nb::cast(PyDialect(std::move(dialectDescriptor))); } // Create the custom implementation. @@ -212,42 +217,47 @@ static MlirStringRef toMlirStringRef(const std::string &s) { return mlirStringRefCreate(s.data(), s.size()); } +static MlirStringRef toMlirStringRef(const nb::bytes &s) { + return mlirStringRefCreate(static_cast(s.data()), s.size()); +} + /// Create a block, using the current location context if no locations are /// specified. 
-static MlirBlock createBlock(const py::sequence &pyArgTypes, - const std::optional &pyArgLocs) { +static MlirBlock createBlock(const nb::sequence &pyArgTypes, + const std::optional &pyArgLocs) { SmallVector argTypes; - argTypes.reserve(pyArgTypes.size()); + argTypes.reserve(nb::len(pyArgTypes)); for (const auto &pyType : pyArgTypes) - argTypes.push_back(pyType.cast()); + argTypes.push_back(nb::cast(pyType)); SmallVector argLocs; if (pyArgLocs) { - argLocs.reserve(pyArgLocs->size()); + argLocs.reserve(nb::len(*pyArgLocs)); for (const auto &pyLoc : *pyArgLocs) - argLocs.push_back(pyLoc.cast()); + argLocs.push_back(nb::cast(pyLoc)); } else if (!argTypes.empty()) { argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve()); } if (argTypes.size() != argLocs.size()) - throw py::value_error(("Expected " + Twine(argTypes.size()) + + throw nb::value_error(("Expected " + Twine(argTypes.size()) + " locations, got: " + Twine(argLocs.size())) - .str()); + .str() + .c_str()); return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); } /// Wrapper for the global LLVM debugging flag. struct PyGlobalDebugFlag { - static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } + static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); } - static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); } + static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); } - static void bind(py::module &m) { + static void bind(nb::module_ &m) { // Debug flags. 
- py::class_(m, "_GlobalDebug", py::module_local()) - .def_property_static("flag", &PyGlobalDebugFlag::get, - &PyGlobalDebugFlag::set, "LLVM-wide debug flag") + nb::class_(m, "_GlobalDebug") + .def_prop_rw_static("flag", &PyGlobalDebugFlag::get, + &PyGlobalDebugFlag::set, "LLVM-wide debug flag") .def_static( "set_types", [](const std::string &type) { @@ -268,20 +278,20 @@ struct PyAttrBuilderMap { static bool dunderContains(const std::string &attributeKind) { return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value(); } - static py::function dundeGetItemNamed(const std::string &attributeKind) { + static nb::callable dundeGetItemNamed(const std::string &attributeKind) { auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind); if (!builder) - throw py::key_error(attributeKind); + throw nb::key_error(attributeKind.c_str()); return *builder; } static void dundeSetItemNamed(const std::string &attributeKind, - py::function func, bool replace) { + nb::callable func, bool replace) { PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), replace); } - static void bind(py::module &m) { - py::class_(m, "AttrBuilder", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "AttrBuilder") .def_static("contains", &PyAttrBuilderMap::dunderContains) .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed) .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed, @@ -295,8 +305,8 @@ struct PyAttrBuilderMap { // PyBlock //------------------------------------------------------------------------------ -py::object PyBlock::getCapsule() { - return py::reinterpret_steal(mlirPythonBlockToCapsule(get())); +nb::object PyBlock::getCapsule() { + return nb::steal(mlirPythonBlockToCapsule(get())); } //------------------------------------------------------------------------------ @@ -315,14 +325,14 @@ class PyRegionIterator { PyRegion dunderNext() { operation->checkValid(); if (nextIndex >= 
mlirOperationGetNumRegions(operation->get())) { - throw py::stop_iteration(); + throw nb::stop_iteration(); } MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); return PyRegion(operation, region); } - static void bind(py::module &m) { - py::class_(m, "RegionIterator", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "RegionIterator") .def("__iter__", &PyRegionIterator::dunderIter) .def("__next__", &PyRegionIterator::dunderNext); } @@ -351,14 +361,14 @@ class PyRegionList { PyRegion dunderGetItem(intptr_t index) { // dunderLen checks validity. if (index < 0 || index >= dunderLen()) { - throw py::index_error("attempt to access out of bounds region"); + throw nb::index_error("attempt to access out of bounds region"); } MlirRegion region = mlirOperationGetRegion(operation->get(), index); return PyRegion(operation, region); } - static void bind(py::module &m) { - py::class_(m, "RegionSequence", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "RegionSequence") .def("__len__", &PyRegionList::dunderLen) .def("__iter__", &PyRegionList::dunderIter) .def("__getitem__", &PyRegionList::dunderGetItem); @@ -378,7 +388,7 @@ class PyBlockIterator { PyBlock dunderNext() { operation->checkValid(); if (mlirBlockIsNull(next)) { - throw py::stop_iteration(); + throw nb::stop_iteration(); } PyBlock returnBlock(operation, next); @@ -386,8 +396,8 @@ class PyBlockIterator { return returnBlock; } - static void bind(py::module &m) { - py::class_(m, "BlockIterator", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "BlockIterator") .def("__iter__", &PyBlockIterator::dunderIter) .def("__next__", &PyBlockIterator::dunderNext); } @@ -424,7 +434,7 @@ class PyBlockList { PyBlock dunderGetItem(intptr_t index) { operation->checkValid(); if (index < 0) { - throw py::index_error("attempt to access out of bounds block"); + throw nb::index_error("attempt to access out of bounds block"); } MlirBlock block = 
mlirRegionGetFirstBlock(region); while (!mlirBlockIsNull(block)) { @@ -434,24 +444,26 @@ class PyBlockList { block = mlirBlockGetNextInRegion(block); index -= 1; } - throw py::index_error("attempt to access out of bounds block"); + throw nb::index_error("attempt to access out of bounds block"); } - PyBlock appendBlock(const py::args &pyArgTypes, - const std::optional &pyArgLocs) { + PyBlock appendBlock(const nb::args &pyArgTypes, + const std::optional &pyArgLocs) { operation->checkValid(); - MlirBlock block = createBlock(pyArgTypes, pyArgLocs); + MlirBlock block = + createBlock(nb::cast(pyArgTypes), pyArgLocs); mlirRegionAppendOwnedBlock(region, block); return PyBlock(operation, block); } - static void bind(py::module &m) { - py::class_(m, "BlockList", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "BlockList") .def("__getitem__", &PyBlockList::dunderGetItem) .def("__iter__", &PyBlockList::dunderIter) .def("__len__", &PyBlockList::dunderLen) .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring, - py::arg("arg_locs") = std::nullopt); + nb::arg("args"), nb::kw_only(), + nb::arg("arg_locs") = std::nullopt); } private: @@ -466,10 +478,10 @@ class PyOperationIterator { PyOperationIterator &dunderIter() { return *this; } - py::object dunderNext() { + nb::object dunderNext() { parentOperation->checkValid(); if (mlirOperationIsNull(next)) { - throw py::stop_iteration(); + throw nb::stop_iteration(); } PyOperationRef returnOperation = @@ -478,8 +490,8 @@ class PyOperationIterator { return returnOperation->createOpView(); } - static void bind(py::module &m) { - py::class_(m, "OperationIterator", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "OperationIterator") .def("__iter__", &PyOperationIterator::dunderIter) .def("__next__", &PyOperationIterator::dunderNext); } @@ -515,10 +527,10 @@ class PyOperationList { return count; } - py::object dunderGetItem(intptr_t index) { + nb::object dunderGetItem(intptr_t index) 
{ parentOperation->checkValid(); if (index < 0) { - throw py::index_error("attempt to access out of bounds operation"); + throw nb::index_error("attempt to access out of bounds operation"); } MlirOperation childOp = mlirBlockGetFirstOperation(block); while (!mlirOperationIsNull(childOp)) { @@ -529,11 +541,11 @@ class PyOperationList { childOp = mlirOperationGetNextInBlock(childOp); index -= 1; } - throw py::index_error("attempt to access out of bounds operation"); + throw nb::index_error("attempt to access out of bounds operation"); } - static void bind(py::module &m) { - py::class_(m, "OperationList", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "OperationList") .def("__getitem__", &PyOperationList::dunderGetItem) .def("__iter__", &PyOperationList::dunderIter) .def("__len__", &PyOperationList::dunderLen); @@ -548,7 +560,7 @@ class PyOpOperand { public: PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {} - py::object getOwner() { + nb::object getOwner() { MlirOperation owner = mlirOpOperandGetOwner(opOperand); PyMlirContextRef context = PyMlirContext::forContext(mlirOperationGetContext(owner)); @@ -557,11 +569,10 @@ class PyOpOperand { size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); } - static void bind(py::module &m) { - py::class_(m, "OpOperand", py::module_local()) - .def_property_readonly("owner", &PyOpOperand::getOwner) - .def_property_readonly("operand_number", - &PyOpOperand::getOperandNumber); + static void bind(nb::module_ &m) { + nb::class_(m, "OpOperand") + .def_prop_ro("owner", &PyOpOperand::getOwner) + .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber); } private: @@ -576,15 +587,15 @@ class PyOpOperandIterator { PyOpOperand dunderNext() { if (mlirOpOperandIsNull(opOperand)) - throw py::stop_iteration(); + throw nb::stop_iteration(); PyOpOperand returnOpOperand(opOperand); opOperand = mlirOpOperandGetNextUse(opOperand); return returnOpOperand; } - static void 
bind(py::module &m) { - py::class_(m, "OpOperandIterator", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "OpOperandIterator") .def("__iter__", &PyOpOperandIterator::dunderIter) .def("__next__", &PyOpOperandIterator::dunderNext); } @@ -600,7 +611,7 @@ class PyOpOperandIterator { //------------------------------------------------------------------------------ PyMlirContext::PyMlirContext(MlirContext context) : context(context) { - py::gil_scoped_acquire acquire; + nb::gil_scoped_acquire acquire; auto &liveContexts = getLiveContexts(); liveContexts[context.ptr] = this; } @@ -609,41 +620,36 @@ PyMlirContext::~PyMlirContext() { // Note that the only public way to construct an instance is via the // forContext method, which always puts the associated handle into // liveContexts. - py::gil_scoped_acquire acquire; + nb::gil_scoped_acquire acquire; getLiveContexts().erase(context.ptr); mlirContextDestroy(context); } -py::object PyMlirContext::getCapsule() { - return py::reinterpret_steal(mlirPythonContextToCapsule(get())); +nb::object PyMlirContext::getCapsule() { + return nb::steal(mlirPythonContextToCapsule(get())); } -py::object PyMlirContext::createFromCapsule(py::object capsule) { +nb::object PyMlirContext::createFromCapsule(nb::object capsule) { MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); if (mlirContextIsNull(rawContext)) - throw py::error_already_set(); + throw nb::python_error(); return forContext(rawContext).releaseObject(); } -PyMlirContext *PyMlirContext::createNewContextForInit() { - MlirContext context = mlirContextCreateWithThreading(false); - return new PyMlirContext(context); -} - PyMlirContextRef PyMlirContext::forContext(MlirContext context) { - py::gil_scoped_acquire acquire; + nb::gil_scoped_acquire acquire; auto &liveContexts = getLiveContexts(); auto it = liveContexts.find(context.ptr); if (it == liveContexts.end()) { // Create. 
PyMlirContext *unownedContextWrapper = new PyMlirContext(context); - py::object pyRef = py::cast(unownedContextWrapper); - assert(pyRef && "cast to py::object failed"); + nb::object pyRef = nb::cast(unownedContextWrapper); + assert(pyRef && "cast to nb::object failed"); liveContexts[context.ptr] = unownedContextWrapper; return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); } // Use existing. - py::object pyRef = py::cast(it->second); + nb::object pyRef = nb::cast(it->second); return PyMlirContextRef(it->second, std::move(pyRef)); } @@ -717,23 +723,23 @@ void PyMlirContext::clearOperationAndInside(PyOperationBase &op) { size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } -pybind11::object PyMlirContext::contextEnter() { - return PyThreadContextEntry::pushContext(*this); +nb::object PyMlirContext::contextEnter(nb::object context) { + return PyThreadContextEntry::pushContext(context); } -void PyMlirContext::contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb) { +void PyMlirContext::contextExit(const nb::object &excType, + const nb::object &excVal, + const nb::object &excTb) { PyThreadContextEntry::popContext(*this); } -py::object PyMlirContext::attachDiagnosticHandler(py::object callback) { +nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) { // Note that ownership is transferred to the delete callback below by way of // an explicit inc_ref (borrow). 
PyDiagnosticHandler *pyHandler = new PyDiagnosticHandler(get(), std::move(callback)); - py::object pyHandlerObject = - py::cast(pyHandler, py::return_value_policy::take_ownership); + nb::object pyHandlerObject = + nb::cast(pyHandler, nb::rv_policy::take_ownership); pyHandlerObject.inc_ref(); // In these C callbacks, the userData is a PyDiagnosticHandler* that is @@ -741,17 +747,17 @@ py::object PyMlirContext::attachDiagnosticHandler(py::object callback) { auto handlerCallback = +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult { PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic); - py::object pyDiagnosticObject = - py::cast(pyDiagnostic, py::return_value_policy::take_ownership); + nb::object pyDiagnosticObject = + nb::cast(pyDiagnostic, nb::rv_policy::take_ownership); auto *pyHandler = static_cast(userData); bool result = false; { // Since this can be called from arbitrary C++ contexts, always get the // gil. - py::gil_scoped_acquire gil; + nb::gil_scoped_acquire gil; try { - result = py::cast(pyHandler->callback(pyDiagnostic)); + result = nb::cast(pyHandler->callback(pyDiagnostic)); } catch (std::exception &e) { fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n", e.what()); @@ -768,8 +774,7 @@ py::object PyMlirContext::attachDiagnosticHandler(py::object callback) { pyHandler->registeredID.reset(); // Decrement reference, balancing the inc_ref() above. 
- py::object pyHandlerObject = - py::cast(pyHandler, py::return_value_policy::reference); + nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference); pyHandlerObject.dec_ref(); }; @@ -819,9 +824,9 @@ PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { return &stack.back(); } -void PyThreadContextEntry::push(FrameKind frameKind, py::object context, - py::object insertionPoint, - py::object location) { +void PyThreadContextEntry::push(FrameKind frameKind, nb::object context, + nb::object insertionPoint, + nb::object location) { auto &stack = getStack(); stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), std::move(location)); @@ -844,19 +849,19 @@ void PyThreadContextEntry::push(FrameKind frameKind, py::object context, PyMlirContext *PyThreadContextEntry::getContext() { if (!context) return nullptr; - return py::cast(context); + return nb::cast(context); } PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { if (!insertionPoint) return nullptr; - return py::cast(insertionPoint); + return nb::cast(insertionPoint); } PyLocation *PyThreadContextEntry::getLocation() { if (!location) return nullptr; - return py::cast(location); + return nb::cast(location); } PyMlirContext *PyThreadContextEntry::getDefaultContext() { @@ -874,12 +879,11 @@ PyLocation *PyThreadContextEntry::getDefaultLocation() { return tos ? 
tos->getLocation() : nullptr; } -py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { - py::object contextObj = py::cast(context); - push(FrameKind::Context, /*context=*/contextObj, - /*insertionPoint=*/py::object(), - /*location=*/py::object()); - return contextObj; +nb::object PyThreadContextEntry::pushContext(nb::object context) { + push(FrameKind::Context, /*context=*/context, + /*insertionPoint=*/nb::object(), + /*location=*/nb::object()); + return context; } void PyThreadContextEntry::popContext(PyMlirContext &context) { @@ -892,15 +896,16 @@ void PyThreadContextEntry::popContext(PyMlirContext &context) { stack.pop_back(); } -py::object -PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { - py::object contextObj = +nb::object +PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) { + PyInsertionPoint &insertionPoint = + nb::cast(insertionPointObj); + nb::object contextObj = insertionPoint.getBlock().getParentOperation()->getContext().getObject(); - py::object insertionPointObj = py::cast(insertionPoint); push(FrameKind::InsertionPoint, /*context=*/contextObj, /*insertionPoint=*/insertionPointObj, - /*location=*/py::object()); + /*location=*/nb::object()); return insertionPointObj; } @@ -915,11 +920,11 @@ void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { stack.pop_back(); } -py::object PyThreadContextEntry::pushLocation(PyLocation &location) { - py::object contextObj = location.getContext().getObject(); - py::object locationObj = py::cast(location); +nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) { + PyLocation &location = nb::cast(locationObj); + nb::object contextObj = location.getContext().getObject(); push(FrameKind::Location, /*context=*/contextObj, - /*insertionPoint=*/py::object(), + /*insertionPoint=*/nb::object(), /*location=*/locationObj); return locationObj; } @@ -941,15 +946,15 @@ void PyThreadContextEntry::popLocation(PyLocation 
&location) { void PyDiagnostic::invalidate() { valid = false; if (materializedNotes) { - for (auto ¬eObject : *materializedNotes) { - PyDiagnostic *note = py::cast(noteObject); + for (nb::handle noteObject : *materializedNotes) { + PyDiagnostic *note = nb::cast(noteObject); note->invalidate(); } } } PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context, - py::object callback) + nb::object callback) : context(context), callback(std::move(callback)) {} PyDiagnosticHandler::~PyDiagnosticHandler() = default; @@ -984,32 +989,36 @@ PyLocation PyDiagnostic::getLocation() { return PyLocation(PyMlirContext::forContext(context), loc); } -py::str PyDiagnostic::getMessage() { +nb::str PyDiagnostic::getMessage() { checkValid(); - py::object fileObject = py::module::import("io").attr("StringIO")(); + nb::object fileObject = nb::module_::import_("io").attr("StringIO")(); PyFileAccumulator accum(fileObject, /*binary=*/false); mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData()); - return fileObject.attr("getvalue")(); + return nb::cast(fileObject.attr("getvalue")()); } -py::tuple PyDiagnostic::getNotes() { +nb::tuple PyDiagnostic::getNotes() { checkValid(); if (materializedNotes) return *materializedNotes; intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic); - materializedNotes = py::tuple(numNotes); + nb::tuple notes = nb::steal(PyTuple_New(numNotes)); for (intptr_t i = 0; i < numNotes; ++i) { MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i); - (*materializedNotes)[i] = PyDiagnostic(noteDiag); + nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag)); + PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr()); } + materializedNotes = std::move(notes); + return *materializedNotes; } PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() { std::vector notes; - for (py::handle n : getNotes()) - notes.emplace_back(n.cast().getInfo()); - return {getSeverity(), getLocation(), getMessage(), std::move(notes)}; + for (nb::handle n : 
getNotes()) + notes.emplace_back(nb::cast(n).getInfo()); + return {getSeverity(), getLocation(), nb::cast(getMessage()), + std::move(notes)}; } //------------------------------------------------------------------------------ @@ -1023,22 +1032,21 @@ MlirDialect PyDialects::getDialectForKey(const std::string &key, if (mlirDialectIsNull(dialect)) { std::string msg = (Twine("Dialect '") + key + "' not found").str(); if (attrError) - throw py::attribute_error(msg); - throw py::index_error(msg); + throw nb::attribute_error(msg.c_str()); + throw nb::index_error(msg.c_str()); } return dialect; } -py::object PyDialectRegistry::getCapsule() { - return py::reinterpret_steal( - mlirPythonDialectRegistryToCapsule(*this)); +nb::object PyDialectRegistry::getCapsule() { + return nb::steal(mlirPythonDialectRegistryToCapsule(*this)); } -PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) { +PyDialectRegistry PyDialectRegistry::createFromCapsule(nb::object capsule) { MlirDialectRegistry rawRegistry = mlirPythonCapsuleToDialectRegistry(capsule.ptr()); if (mlirDialectRegistryIsNull(rawRegistry)) - throw py::error_already_set(); + throw nb::python_error(); return PyDialectRegistry(rawRegistry); } @@ -1046,25 +1054,25 @@ PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) { // PyLocation //------------------------------------------------------------------------------ -py::object PyLocation::getCapsule() { - return py::reinterpret_steal(mlirPythonLocationToCapsule(*this)); +nb::object PyLocation::getCapsule() { + return nb::steal(mlirPythonLocationToCapsule(*this)); } -PyLocation PyLocation::createFromCapsule(py::object capsule) { +PyLocation PyLocation::createFromCapsule(nb::object capsule) { MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); if (mlirLocationIsNull(rawLoc)) - throw py::error_already_set(); + throw nb::python_error(); return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), rawLoc); } 
-py::object PyLocation::contextEnter() { - return PyThreadContextEntry::pushLocation(*this); +nb::object PyLocation::contextEnter(nb::object locationObj) { + return PyThreadContextEntry::pushLocation(locationObj); } -void PyLocation::contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb) { +void PyLocation::contextExit(const nb::object &excType, + const nb::object &excVal, + const nb::object &excTb) { PyThreadContextEntry::popLocation(*this); } @@ -1087,7 +1095,7 @@ PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) : BaseContextObject(std::move(contextRef)), module(module) {} PyModule::~PyModule() { - py::gil_scoped_acquire acquire; + nb::gil_scoped_acquire acquire; auto &liveModules = getContext()->liveModules; assert(liveModules.count(module.ptr) == 1 && "destroying module not in live map"); @@ -1099,7 +1107,7 @@ PyModuleRef PyModule::forModule(MlirModule module) { MlirContext context = mlirModuleGetContext(module); PyMlirContextRef contextRef = PyMlirContext::forContext(context); - py::gil_scoped_acquire acquire; + nb::gil_scoped_acquire acquire; auto &liveModules = contextRef->liveModules; auto it = liveModules.find(module.ptr); if (it == liveModules.end()) { @@ -1108,8 +1116,7 @@ PyModuleRef PyModule::forModule(MlirModule module) { // Note that the default return value policy on cast is automatic_reference, // which does not take ownership (delete will not be called). // Just be explicit. - py::object pyRef = - py::cast(unownedModule, py::return_value_policy::take_ownership); + nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership); unownedModule->handle = pyRef; liveModules[module.ptr] = std::make_pair(unownedModule->handle, unownedModule); @@ -1117,19 +1124,19 @@ PyModuleRef PyModule::forModule(MlirModule module) { } // Use existing. 
PyModule *existing = it->second.second; - py::object pyRef = py::reinterpret_borrow(it->second.first); + nb::object pyRef = nb::borrow(it->second.first); return PyModuleRef(existing, std::move(pyRef)); } -py::object PyModule::createFromCapsule(py::object capsule) { +nb::object PyModule::createFromCapsule(nb::object capsule) { MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); if (mlirModuleIsNull(rawModule)) - throw py::error_already_set(); + throw nb::python_error(); return forModule(rawModule).releaseObject(); } -py::object PyModule::getCapsule() { - return py::reinterpret_steal(mlirPythonModuleToCapsule(get())); +nb::object PyModule::getCapsule() { + return nb::steal(mlirPythonModuleToCapsule(get())); } //------------------------------------------------------------------------------ @@ -1158,7 +1165,7 @@ PyOperation::~PyOperation() { PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, MlirOperation operation, - py::object parentKeepAlive) { + nb::object parentKeepAlive) { auto &liveOperations = contextRef->liveOperations; // Create. PyOperation *unownedOperation = @@ -1166,8 +1173,7 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, // Note that the default return value policy on cast is automatic_reference, // which does not take ownership (delete will not be called). // Just be explicit. 
- py::object pyRef = - py::cast(unownedOperation, py::return_value_policy::take_ownership); + nb::object pyRef = nb::cast(unownedOperation, nb::rv_policy::take_ownership); unownedOperation->handle = pyRef; if (parentKeepAlive) { unownedOperation->parentKeepAlive = std::move(parentKeepAlive); @@ -1178,7 +1184,7 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, MlirOperation operation, - py::object parentKeepAlive) { + nb::object parentKeepAlive) { auto &liveOperations = contextRef->liveOperations; auto it = liveOperations.find(operation.ptr); if (it == liveOperations.end()) { @@ -1188,13 +1194,13 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, } // Use existing. PyOperation *existing = it->second.second; - py::object pyRef = py::reinterpret_borrow(it->second.first); + nb::object pyRef = nb::borrow(it->second.first); return PyOperationRef(existing, std::move(pyRef)); } PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, MlirOperation operation, - py::object parentKeepAlive) { + nb::object parentKeepAlive) { auto &liveOperations = contextRef->liveOperations; assert(liveOperations.count(operation.ptr) == 0 && "cannot create detached operation that already exists"); @@ -1227,12 +1233,12 @@ void PyOperation::checkValid() const { void PyOperationBase::print(std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified, py::object fileObject, + bool assumeVerified, nb::object fileObject, bool binary, bool skipRegions) { PyOperation &operation = getOperation(); operation.checkValid(); if (fileObject.is_none()) - fileObject = py::module::import("sys").attr("stdout"); + fileObject = nb::module_::import_("sys").attr("stdout"); MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); if (largeElementsLimit) @@ -1255,18 +1261,18 @@ void 
PyOperationBase::print(std::optional largeElementsLimit, mlirOpPrintingFlagsDestroy(flags); } -void PyOperationBase::print(PyAsmState &state, py::object fileObject, +void PyOperationBase::print(PyAsmState &state, nb::object fileObject, bool binary) { PyOperation &operation = getOperation(); operation.checkValid(); if (fileObject.is_none()) - fileObject = py::module::import("sys").attr("stdout"); + fileObject = nb::module_::import_("sys").attr("stdout"); PyFileAccumulator accum(fileObject, binary); mlirOperationPrintWithState(operation, state.get(), accum.getCallback(), accum.getUserData()); } -void PyOperationBase::writeBytecode(const py::object &fileObject, +void PyOperationBase::writeBytecode(const nb::object &fileObject, std::optional bytecodeVersion) { PyOperation &operation = getOperation(); operation.checkValid(); @@ -1282,9 +1288,10 @@ void PyOperationBase::writeBytecode(const py::object &fileObject, operation, config, accum.getCallback(), accum.getUserData()); mlirBytecodeWriterConfigDestroy(config); if (mlirLogicalResultIsFailure(res)) - throw py::value_error((Twine("Unable to honor desired bytecode version ") + + throw nb::value_error((Twine("Unable to honor desired bytecode version ") + Twine(*bytecodeVersion)) - .str()); + .str() + .c_str()); } void PyOperationBase::walk( @@ -1296,7 +1303,7 @@ void PyOperationBase::walk( std::function callback; bool gotException; std::string exceptionWhat; - py::object exceptionType; + nb::object exceptionType; }; UserData userData{callback, false, {}, {}}; MlirOperationWalkCallback walkCallback = [](MlirOperation op, @@ -1304,10 +1311,10 @@ void PyOperationBase::walk( UserData *calleeUserData = static_cast(userData); try { return (calleeUserData->callback)(op); - } catch (py::error_already_set &e) { + } catch (nb::python_error &e) { calleeUserData->gotException = true; - calleeUserData->exceptionWhat = e.what(); - calleeUserData->exceptionType = e.type(); + calleeUserData->exceptionWhat = std::string(e.what()); + 
calleeUserData->exceptionType = nb::borrow(e.type()); return MlirWalkResult::MlirWalkResultInterrupt; } }; @@ -1319,16 +1326,16 @@ void PyOperationBase::walk( } } -py::object PyOperationBase::getAsm(bool binary, +nb::object PyOperationBase::getAsm(bool binary, std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified, bool skipRegions) { - py::object fileObject; + nb::object fileObject; if (binary) { - fileObject = py::module::import("io").attr("BytesIO")(); + fileObject = nb::module_::import_("io").attr("BytesIO")(); } else { - fileObject = py::module::import("io").attr("StringIO")(); + fileObject = nb::module_::import_("io").attr("StringIO")(); } print(/*largeElementsLimit=*/largeElementsLimit, /*enableDebugInfo=*/enableDebugInfo, @@ -1372,7 +1379,7 @@ bool PyOperationBase::verify() { std::optional PyOperation::getParentOperation() { checkValid(); if (!isAttached()) - throw py::value_error("Detached operations have no parent"); + throw nb::value_error("Detached operations have no parent"); MlirOperation operation = mlirOperationGetParentOperation(get()); if (mlirOperationIsNull(operation)) return {}; @@ -1388,42 +1395,42 @@ PyBlock PyOperation::getBlock() { return PyBlock{std::move(*parentOperation), block}; } -py::object PyOperation::getCapsule() { +nb::object PyOperation::getCapsule() { checkValid(); - return py::reinterpret_steal(mlirPythonOperationToCapsule(get())); + return nb::steal(mlirPythonOperationToCapsule(get())); } -py::object PyOperation::createFromCapsule(py::object capsule) { +nb::object PyOperation::createFromCapsule(nb::object capsule) { MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); if (mlirOperationIsNull(rawOperation)) - throw py::error_already_set(); + throw nb::python_error(); MlirContext rawCtxt = mlirOperationGetContext(rawOperation); return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) .releaseObject(); } 
static void maybeInsertOperation(PyOperationRef &op, - const py::object &maybeIp) { + const nb::object &maybeIp) { // InsertPoint active? - if (!maybeIp.is(py::cast(false))) { + if (!maybeIp.is(nb::cast(false))) { PyInsertionPoint *ip; if (maybeIp.is_none()) { ip = PyThreadContextEntry::getDefaultInsertionPoint(); } else { - ip = py::cast(maybeIp); + ip = nb::cast(maybeIp); } if (ip) ip->insert(*op.get()); } } -py::object PyOperation::create(const std::string &name, +nb::object PyOperation::create(const std::string &name, std::optional> results, std::optional> operands, - std::optional attributes, + std::optional attributes, std::optional> successors, int regions, DefaultingPyLocation location, - const py::object &maybeIp, bool inferType) { + const nb::object &maybeIp, bool inferType) { llvm::SmallVector mlirOperands; llvm::SmallVector mlirResults; llvm::SmallVector mlirSuccessors; @@ -1431,14 +1438,14 @@ py::object PyOperation::create(const std::string &name, // General parameter validation. if (regions < 0) - throw py::value_error("number of regions must be >= 0"); + throw nb::value_error("number of regions must be >= 0"); // Unpack/validate operands. if (operands) { mlirOperands.reserve(operands->size()); for (PyValue *operand : *operands) { if (!operand) - throw py::value_error("operand value cannot be None"); + throw nb::value_error("operand value cannot be None"); mlirOperands.push_back(operand->get()); } } @@ -1449,38 +1456,38 @@ py::object PyOperation::create(const std::string &name, for (PyType *result : *results) { // TODO: Verify result type originate from the same context. if (!result) - throw py::value_error("result type cannot be None"); + throw nb::value_error("result type cannot be None"); mlirResults.push_back(*result); } } // Unpack/validate attributes. 
if (attributes) { mlirAttributes.reserve(attributes->size()); - for (auto &it : *attributes) { + for (std::pair it : *attributes) { std::string key; try { - key = it.first.cast(); - } catch (py::cast_error &err) { + key = nb::cast(it.first); + } catch (nb::cast_error &err) { std::string msg = "Invalid attribute key (not a string) when " "attempting to create the operation \"" + name + "\" (" + err.what() + ")"; - throw py::cast_error(msg); + throw nb::type_error(msg.c_str()); } try { - auto &attribute = it.second.cast(); + auto &attribute = nb::cast(it.second); // TODO: Verify attribute originates from the same context. mlirAttributes.emplace_back(std::move(key), attribute); - } catch (py::reference_cast_error &) { + } catch (nb::cast_error &err) { + std::string msg = "Invalid attribute value for the key \"" + key + + "\" when attempting to create the operation \"" + + name + "\" (" + err.what() + ")"; + throw nb::type_error(msg.c_str()); + } catch (std::runtime_error &) { // This exception seems thrown when the value is "None". std::string msg = "Found an invalid (`None`?) attribute value for the key \"" + key + "\" when attempting to create the operation \"" + name + "\""; - throw py::cast_error(msg); - } catch (py::cast_error &err) { - std::string msg = "Invalid attribute value for the key \"" + key + - "\" when attempting to create the operation \"" + - name + "\" (" + err.what() + ")"; - throw py::cast_error(msg); + throw std::runtime_error(msg); } } } @@ -1490,7 +1497,7 @@ py::object PyOperation::create(const std::string &name, for (auto *successor : *successors) { // TODO: Verify successor originate from the same context. if (!successor) - throw py::value_error("successor block cannot be None"); + throw nb::value_error("successor block cannot be None"); mlirSuccessors.push_back(successor->get()); } } @@ -1535,7 +1542,7 @@ py::object PyOperation::create(const std::string &name, // Construct the operation. 
MlirOperation operation = mlirOperationCreate(&state); if (!operation.ptr) - throw py::value_error("Operation creation failed"); + throw nb::value_error("Operation creation failed"); PyOperationRef created = PyOperation::createDetached(location->getContext(), operation); maybeInsertOperation(created, maybeIp); @@ -1543,7 +1550,7 @@ py::object PyOperation::create(const std::string &name, return created.getObject(); } -py::object PyOperation::clone(const py::object &maybeIp) { +nb::object PyOperation::clone(const nb::object &maybeIp) { MlirOperation clonedOperation = mlirOperationClone(operation); PyOperationRef cloned = PyOperation::createDetached(getContext(), clonedOperation); @@ -1552,15 +1559,15 @@ py::object PyOperation::clone(const py::object &maybeIp) { return cloned->createOpView(); } -py::object PyOperation::createOpView() { +nb::object PyOperation::createOpView() { checkValid(); MlirIdentifier ident = mlirOperationGetName(get()); MlirStringRef identStr = mlirIdentifierStr(ident); auto operationCls = PyGlobals::get().lookupOperationClass( StringRef(identStr.data, identStr.length)); if (operationCls) - return PyOpView::constructDerived(*operationCls, *getRef().get()); - return py::cast(PyOpView(getRef().getObject())); + return PyOpView::constructDerived(*operationCls, getRef().getObject()); + return nb::cast(PyOpView(getRef().getObject())); } void PyOperation::erase() { @@ -1573,8 +1580,8 @@ void PyOperation::erase() { // PyOpView //------------------------------------------------------------------------------ -static void populateResultTypes(StringRef name, py::list resultTypeList, - const py::object &resultSegmentSpecObj, +static void populateResultTypes(StringRef name, nb::list resultTypeList, + const nb::object &resultSegmentSpecObj, std::vector &resultSegmentLengths, std::vector &resultTypes) { resultTypes.reserve(resultTypeList.size()); @@ -1582,26 +1589,28 @@ static void populateResultTypes(StringRef name, py::list resultTypeList, // Non-variadic 
result unpacking. for (const auto &it : llvm::enumerate(resultTypeList)) { try { - resultTypes.push_back(py::cast(it.value())); + resultTypes.push_back(nb::cast(it.value())); if (!resultTypes.back()) - throw py::cast_error(); - } catch (py::cast_error &err) { - throw py::value_error((llvm::Twine("Result ") + + throw nb::cast_error(); + } catch (nb::cast_error &err) { + throw nb::value_error((llvm::Twine("Result ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Type (" + err.what() + ")") - .str()); + .str() + .c_str()); } } } else { // Sized result unpacking. - auto resultSegmentSpec = py::cast>(resultSegmentSpecObj); + auto resultSegmentSpec = nb::cast>(resultSegmentSpecObj); if (resultSegmentSpec.size() != resultTypeList.size()) { - throw py::value_error((llvm::Twine("Operation \"") + name + + throw nb::value_error((llvm::Twine("Operation \"") + name + "\" requires " + llvm::Twine(resultSegmentSpec.size()) + " result segments but was provided " + llvm::Twine(resultTypeList.size())) - .str()); + .str() + .c_str()); } resultSegmentLengths.reserve(resultTypeList.size()); for (const auto &it : @@ -1610,7 +1619,7 @@ static void populateResultTypes(StringRef name, py::list resultTypeList, if (segmentSpec == 1 || segmentSpec == 0) { // Unpack unary element. try { - auto *resultType = py::cast(std::get<0>(it.value())); + auto *resultType = nb::cast(std::get<0>(it.value())); if (resultType) { resultTypes.push_back(resultType); resultSegmentLengths.push_back(1); @@ -1618,14 +1627,20 @@ static void populateResultTypes(StringRef name, py::list resultTypeList, // Allowed to be optional. 
resultSegmentLengths.push_back(0); } else { - throw py::cast_error("was None and result is not optional"); + throw nb::value_error( + (llvm::Twine("Result ") + llvm::Twine(it.index()) + + " of operation \"" + name + + "\" must be a Type (was None and result is not optional)") + .str() + .c_str()); } - } catch (py::cast_error &err) { - throw py::value_error((llvm::Twine("Result ") + + } catch (nb::cast_error &err) { + throw nb::value_error((llvm::Twine("Result ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Type (" + err.what() + ")") - .str()); + .str() + .c_str()); } } else if (segmentSpec == -1) { // Unpack sequence by appending. @@ -1635,72 +1650,75 @@ static void populateResultTypes(StringRef name, py::list resultTypeList, resultSegmentLengths.push_back(0); } else { // Unpack the list. - auto segment = py::cast(std::get<0>(it.value())); - for (py::object segmentItem : segment) { - resultTypes.push_back(py::cast(segmentItem)); + auto segment = nb::cast(std::get<0>(it.value())); + for (nb::handle segmentItem : segment) { + resultTypes.push_back(nb::cast(segmentItem)); if (!resultTypes.back()) { - throw py::cast_error("contained a None item"); + throw nb::type_error("contained a None item"); } } - resultSegmentLengths.push_back(segment.size()); + resultSegmentLengths.push_back(nb::len(segment)); } } catch (std::exception &err) { // NOTE: Sloppy to be using a catch-all here, but there are at least // three different unrelated exceptions that can be thrown in the // above "casts". Just keep the scope above small and catch them all. 
- throw py::value_error((llvm::Twine("Result ") + + throw nb::value_error((llvm::Twine("Result ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Sequence of Types (" + err.what() + ")") - .str()); + .str() + .c_str()); } } else { - throw py::value_error("Unexpected segment spec"); + throw nb::value_error("Unexpected segment spec"); } } } } -py::object PyOpView::buildGeneric( - const py::object &cls, std::optional resultTypeList, - py::list operandList, std::optional attributes, +nb::object PyOpView::buildGeneric( + const nb::object &cls, std::optional resultTypeList, + nb::list operandList, std::optional attributes, std::optional> successors, std::optional regions, DefaultingPyLocation location, - const py::object &maybeIp) { + const nb::object &maybeIp) { PyMlirContextRef context = location->getContext(); // Class level operation construction metadata. - std::string name = py::cast(cls.attr("OPERATION_NAME")); + std::string name = nb::cast(cls.attr("OPERATION_NAME")); // Operand and result segment specs are either none, which does no // variadic unpacking, or a list of ints with segment sizes, where each // element is either a positive number (typically 1 for a scalar) or -1 to // indicate that it is derived from the length of the same-indexed operand // or result (implying that it is a list at that position). - py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); - py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); + nb::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); + nb::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); std::vector operandSegmentLengths; std::vector resultSegmentLengths; // Validate/determine region count. 
- auto opRegionSpec = py::cast>(cls.attr("_ODS_REGIONS")); + auto opRegionSpec = nb::cast>(cls.attr("_ODS_REGIONS")); int opMinRegionCount = std::get<0>(opRegionSpec); bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); if (!regions) { regions = opMinRegionCount; } if (*regions < opMinRegionCount) { - throw py::value_error( + throw nb::value_error( (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + llvm::Twine(opMinRegionCount) + " regions but was built with regions=" + llvm::Twine(*regions)) - .str()); + .str() + .c_str()); } if (opHasNoVariadicRegions && *regions > opMinRegionCount) { - throw py::value_error( + throw nb::value_error( (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + llvm::Twine(opMinRegionCount) + " regions but was built with regions=" + llvm::Twine(*regions)) - .str()); + .str() + .c_str()); } // Unpack results. @@ -1717,26 +1735,28 @@ py::object PyOpView::buildGeneric( // Non-sized operand unpacking. for (const auto &it : llvm::enumerate(operandList)) { try { - operands.push_back(py::cast(it.value())); + operands.push_back(nb::cast(it.value())); if (!operands.back()) - throw py::cast_error(); - } catch (py::cast_error &err) { - throw py::value_error((llvm::Twine("Operand ") + + throw nb::cast_error(); + } catch (nb::cast_error &err) { + throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Value (" + err.what() + ")") - .str()); + .str() + .c_str()); } } } else { // Sized operand unpacking. 
- auto operandSegmentSpec = py::cast>(operandSegmentSpecObj); + auto operandSegmentSpec = nb::cast>(operandSegmentSpecObj); if (operandSegmentSpec.size() != operandList.size()) { - throw py::value_error((llvm::Twine("Operation \"") + name + + throw nb::value_error((llvm::Twine("Operation \"") + name + "\" requires " + llvm::Twine(operandSegmentSpec.size()) + "operand segments but was provided " + llvm::Twine(operandList.size())) - .str()); + .str() + .c_str()); } operandSegmentLengths.reserve(operandList.size()); for (const auto &it : @@ -1745,7 +1765,7 @@ py::object PyOpView::buildGeneric( if (segmentSpec == 1 || segmentSpec == 0) { // Unpack unary element. try { - auto *operandValue = py::cast(std::get<0>(it.value())); + auto *operandValue = nb::cast(std::get<0>(it.value())); if (operandValue) { operands.push_back(operandValue); operandSegmentLengths.push_back(1); @@ -1753,14 +1773,20 @@ py::object PyOpView::buildGeneric( // Allowed to be optional. operandSegmentLengths.push_back(0); } else { - throw py::cast_error("was None and operand is not optional"); + throw nb::value_error( + (llvm::Twine("Operand ") + llvm::Twine(it.index()) + + " of operation \"" + name + + "\" must be a Value (was None and operand is not optional)") + .str() + .c_str()); } - } catch (py::cast_error &err) { - throw py::value_error((llvm::Twine("Operand ") + + } catch (nb::cast_error &err) { + throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Value (" + err.what() + ")") - .str()); + .str() + .c_str()); } } else if (segmentSpec == -1) { // Unpack sequence by appending. @@ -1770,27 +1796,28 @@ py::object PyOpView::buildGeneric( operandSegmentLengths.push_back(0); } else { // Unpack the list. 
- auto segment = py::cast(std::get<0>(it.value())); - for (py::object segmentItem : segment) { - operands.push_back(py::cast(segmentItem)); + auto segment = nb::cast(std::get<0>(it.value())); + for (nb::handle segmentItem : segment) { + operands.push_back(nb::cast(segmentItem)); if (!operands.back()) { - throw py::cast_error("contained a None item"); + throw nb::type_error("contained a None item"); } } - operandSegmentLengths.push_back(segment.size()); + operandSegmentLengths.push_back(nb::len(segment)); } } catch (std::exception &err) { // NOTE: Sloppy to be using a catch-all here, but there are at least // three different unrelated exceptions that can be thrown in the // above "casts". Just keep the scope above small and catch them all. - throw py::value_error((llvm::Twine("Operand ") + + throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Sequence of Values (" + err.what() + ")") - .str()); + .str() + .c_str()); } } else { - throw py::value_error("Unexpected segment spec"); + throw nb::value_error("Unexpected segment spec"); } } } @@ -1799,13 +1826,13 @@ py::object PyOpView::buildGeneric( if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { // Dup. if (attributes) { - attributes = py::dict(*attributes); + attributes = nb::dict(*attributes); } else { - attributes = py::dict(); + attributes = nb::dict(); } if (attributes->contains("resultSegmentSizes") || attributes->contains("operandSegmentSizes")) { - throw py::value_error("Manually setting a 'resultSegmentSizes' or " + throw nb::value_error("Manually setting a 'resultSegmentSizes' or " "'operandSegmentSizes' attribute is unsupported. " "Use Operation.create for such low-level access."); } @@ -1839,21 +1866,18 @@ py::object PyOpView::buildGeneric( !resultTypeList); } -pybind11::object PyOpView::constructDerived(const pybind11::object &cls, - const PyOperation &operation) { - // TODO: pybind11 2.6 supports a more direct form. 
- // Upgrade many years from now. - // auto opViewType = py::type::of(); - py::handle opViewType = py::detail::get_type_handle(typeid(PyOpView), true); - py::object instance = cls.attr("__new__")(cls); +nb::object PyOpView::constructDerived(const nb::object &cls, + const nb::object &operation) { + nb::handle opViewType = nb::type(); + nb::object instance = cls.attr("__new__")(cls); opViewType.attr("__init__")(instance, operation); return instance; } -PyOpView::PyOpView(const py::object &operationObject) +PyOpView::PyOpView(const nb::object &operationObject) // Casting through the PyOperationBase base-class and then back to the // Operation lets us accept any PyOperationBase subclass. - : operation(py::cast(operationObject).getOperation()), + : operation(nb::cast(operationObject).getOperation()), operationObject(operation.getRef().getObject()) {} //------------------------------------------------------------------------------ @@ -1869,7 +1893,7 @@ PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) void PyInsertionPoint::insert(PyOperationBase &operationBase) { PyOperation &operation = operationBase.getOperation(); if (operation.isAttached()) - throw py::value_error( + throw nb::value_error( "Attempt to insert operation that is already attached"); block.getParentOperation()->checkValid(); MlirOperation beforeOp = {nullptr}; @@ -1882,7 +1906,7 @@ void PyInsertionPoint::insert(PyOperationBase &operationBase) { // already end in a known terminator (violating this will cause assertion // failures later). if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { - throw py::index_error("Cannot insert operation at the end of a block " + throw nb::index_error("Cannot insert operation at the end of a block " "that already has a terminator. 
Did you mean to " "use 'InsertionPoint.at_block_terminator(block)' " "versus 'InsertionPoint(block)'?"); @@ -1908,19 +1932,19 @@ PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { MlirOperation terminator = mlirBlockGetTerminator(block.get()); if (mlirOperationIsNull(terminator)) - throw py::value_error("Block has no terminator"); + throw nb::value_error("Block has no terminator"); PyOperationRef terminatorOpRef = PyOperation::forOperation( block.getParentOperation()->getContext(), terminator); return PyInsertionPoint{block, std::move(terminatorOpRef)}; } -py::object PyInsertionPoint::contextEnter() { - return PyThreadContextEntry::pushInsertionPoint(*this); +nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) { + return PyThreadContextEntry::pushInsertionPoint(insertPoint); } -void PyInsertionPoint::contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb) { +void PyInsertionPoint::contextExit(const nb::object &excType, + const nb::object &excVal, + const nb::object &excTb) { PyThreadContextEntry::popInsertionPoint(*this); } @@ -1932,14 +1956,14 @@ bool PyAttribute::operator==(const PyAttribute &other) const { return mlirAttributeEqual(attr, other.attr); } -py::object PyAttribute::getCapsule() { - return py::reinterpret_steal(mlirPythonAttributeToCapsule(*this)); +nb::object PyAttribute::getCapsule() { + return nb::steal(mlirPythonAttributeToCapsule(*this)); } -PyAttribute PyAttribute::createFromCapsule(py::object capsule) { +PyAttribute PyAttribute::createFromCapsule(nb::object capsule) { MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); if (mlirAttributeIsNull(rawAttr)) - throw py::error_already_set(); + throw nb::python_error(); return PyAttribute( PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); } @@ -1964,14 +1988,14 @@ bool PyType::operator==(const PyType &other) 
const { return mlirTypeEqual(type, other.type); } -py::object PyType::getCapsule() { - return py::reinterpret_steal(mlirPythonTypeToCapsule(*this)); +nb::object PyType::getCapsule() { + return nb::steal(mlirPythonTypeToCapsule(*this)); } -PyType PyType::createFromCapsule(py::object capsule) { +PyType PyType::createFromCapsule(nb::object capsule) { MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); if (mlirTypeIsNull(rawType)) - throw py::error_already_set(); + throw nb::python_error(); return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), rawType); } @@ -1980,14 +2004,14 @@ PyType PyType::createFromCapsule(py::object capsule) { // PyTypeID. //------------------------------------------------------------------------------ -py::object PyTypeID::getCapsule() { - return py::reinterpret_steal(mlirPythonTypeIDToCapsule(*this)); +nb::object PyTypeID::getCapsule() { + return nb::steal(mlirPythonTypeIDToCapsule(*this)); } -PyTypeID PyTypeID::createFromCapsule(py::object capsule) { +PyTypeID PyTypeID::createFromCapsule(nb::object capsule) { MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr()); if (mlirTypeIDIsNull(mlirTypeID)) - throw py::error_already_set(); + throw nb::python_error(); return PyTypeID(mlirTypeID); } bool PyTypeID::operator==(const PyTypeID &other) const { @@ -1998,36 +2022,36 @@ bool PyTypeID::operator==(const PyTypeID &other) const { // PyValue and subclasses. 
//------------------------------------------------------------------------------ -pybind11::object PyValue::getCapsule() { - return py::reinterpret_steal(mlirPythonValueToCapsule(get())); +nb::object PyValue::getCapsule() { + return nb::steal(mlirPythonValueToCapsule(get())); } -pybind11::object PyValue::maybeDownCast() { +nb::object PyValue::maybeDownCast() { MlirType type = mlirValueGetType(get()); MlirTypeID mlirTypeID = mlirTypeGetTypeID(type); assert(!mlirTypeIDIsNull(mlirTypeID) && "mlirTypeID was expected to be non-null."); - std::optional valueCaster = + std::optional valueCaster = PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type)); - // py::return_value_policy::move means use std::move to move the return value + // nb::rv_policy::move means use std::move to move the return value // contents into a new instance that will be owned by Python. - py::object thisObj = py::cast(this, py::return_value_policy::move); + nb::object thisObj = nb::cast(this, nb::rv_policy::move); if (!valueCaster) return thisObj; return valueCaster.value()(thisObj); } -PyValue PyValue::createFromCapsule(pybind11::object capsule) { +PyValue PyValue::createFromCapsule(nb::object capsule) { MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); if (mlirValueIsNull(value)) - throw py::error_already_set(); + throw nb::python_error(); MlirOperation owner; if (mlirValueIsAOpResult(value)) owner = mlirOpResultGetOwner(value); if (mlirValueIsABlockArgument(value)) owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); if (mlirOperationIsNull(owner)) - throw py::error_already_set(); + throw nb::python_error(); MlirContext ctx = mlirOperationGetContext(owner); PyOperationRef ownerRef = PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); @@ -2042,16 +2066,17 @@ PySymbolTable::PySymbolTable(PyOperationBase &operation) : operation(operation.getOperation().getRef()) { symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); if 
(mlirSymbolTableIsNull(symbolTable)) { - throw py::cast_error("Operation is not a Symbol Table."); + throw nb::type_error("Operation is not a Symbol Table."); } } -py::object PySymbolTable::dunderGetItem(const std::string &name) { +nb::object PySymbolTable::dunderGetItem(const std::string &name) { operation->checkValid(); MlirOperation symbol = mlirSymbolTableLookup( symbolTable, mlirStringRefCreate(name.data(), name.length())); if (mlirOperationIsNull(symbol)) - throw py::key_error("Symbol '" + name + "' not in the symbol table."); + throw nb::key_error( + ("Symbol '" + name + "' not in the symbol table.").c_str()); return PyOperation::forOperation(operation->getContext(), symbol, operation.getObject()) @@ -2069,8 +2094,8 @@ void PySymbolTable::erase(PyOperationBase &symbol) { } void PySymbolTable::dunderDel(const std::string &name) { - py::object operation = dunderGetItem(name); - erase(py::cast(operation)); + nb::object operation = dunderGetItem(name); + erase(nb::cast(operation)); } MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) { @@ -2079,7 +2104,7 @@ MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) { MlirAttribute symbolAttr = mlirOperationGetAttributeByName( symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); if (mlirAttributeIsNull(symbolAttr)) - throw py::value_error("Expected operation to have a symbol name."); + throw nb::value_error("Expected operation to have a symbol name."); return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()); } @@ -2091,7 +2116,7 @@ MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { MlirAttribute existingNameAttr = mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingNameAttr)) - throw py::value_error("Expected operation to have a symbol name."); + throw nb::value_error("Expected operation to have a symbol name."); return existingNameAttr; } @@ -2104,7 +2129,7 @@ void PySymbolTable::setSymbolName(PyOperationBase 
&symbol, MlirAttribute existingNameAttr = mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingNameAttr)) - throw py::value_error("Expected operation to have a symbol name."); + throw nb::value_error("Expected operation to have a symbol name."); MlirAttribute newNameAttr = mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name)); mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); @@ -2117,7 +2142,7 @@ MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { MlirAttribute existingVisAttr = mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingVisAttr)) - throw py::value_error("Expected operation to have a symbol visibility."); + throw nb::value_error("Expected operation to have a symbol visibility."); return existingVisAttr; } @@ -2125,7 +2150,7 @@ void PySymbolTable::setVisibility(PyOperationBase &symbol, const std::string &visibility) { if (visibility != "public" && visibility != "private" && visibility != "nested") - throw py::value_error( + throw nb::value_error( "Expected visibility to be 'public', 'private' or 'nested'"); PyOperation &operation = symbol.getOperation(); operation.checkValid(); @@ -2133,7 +2158,7 @@ void PySymbolTable::setVisibility(PyOperationBase &symbol, MlirAttribute existingVisAttr = mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingVisAttr)) - throw py::value_error("Expected operation to have a symbol visibility."); + throw nb::value_error("Expected operation to have a symbol visibility."); MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(visibility)); mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr); @@ -2148,20 +2173,20 @@ void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol, toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol), from.getOperation()))) - throw py::value_error("Symbol rename 
failed"); + throw nb::value_error("Symbol rename failed"); } void PySymbolTable::walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, - py::object callback) { + nb::object callback) { PyOperation &fromOperation = from.getOperation(); fromOperation.checkValid(); struct UserData { PyMlirContextRef context; - py::object callback; + nb::object callback; bool gotException; std::string exceptionWhat; - py::object exceptionType; + nb::object exceptionType; }; UserData userData{ fromOperation.getContext(), std::move(callback), false, {}, {}}; @@ -2175,10 +2200,10 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from, return; try { calleeUserData->callback(pyFoundOp.getObject(), isVisible); - } catch (py::error_already_set &e) { + } catch (nb::python_error &e) { calleeUserData->gotException = true; calleeUserData->exceptionWhat = e.what(); - calleeUserData->exceptionType = e.type(); + calleeUserData->exceptionType = nb::borrow(e.type()); } }, static_cast(&userData)); @@ -2200,7 +2225,7 @@ class PyConcreteValue : public PyValue { // IsAFunctionTy isaFunction // const char *pyClassName // and redefine bindDerived. - using ClassTy = py::class_; + using ClassTy = nb::class_; using IsAFunctionTy = bool (*)(MlirValue); PyConcreteValue() = default; @@ -2213,25 +2238,26 @@ class PyConcreteValue : public PyValue { /// type mismatches. static MlirValue castFrom(PyValue &orig) { if (!DerivedTy::isaFunction(orig.get())) { - auto origRepr = py::repr(py::cast(orig)).cast(); - throw py::value_error((Twine("Cannot cast value to ") + + auto origRepr = nb::cast(nb::repr(nb::cast(orig))); + throw nb::value_error((Twine("Cannot cast value to ") + DerivedTy::pyClassName + " (from " + origRepr + ")") - .str()); + .str() + .c_str()); } return orig.get(); } /// Binds the Python module objects to functions of this class. 
- static void bind(py::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); - cls.def(py::init(), py::keep_alive<0, 1>(), py::arg("value")); + static void bind(nb::module_ &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName); + cls.def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value")); cls.def_static( "isinstance", [](PyValue &otherValue) -> bool { return DerivedTy::isaFunction(otherValue); }, - py::arg("other_value")); + nb::arg("other_value")); cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](DerivedTy &self) { return self.maybeDownCast(); }); DerivedTy::bindDerived(cls); @@ -2249,11 +2275,11 @@ class PyBlockArgument : public PyConcreteValue { using PyConcreteValue::PyConcreteValue; static void bindDerived(ClassTy &c) { - c.def_property_readonly("owner", [](PyBlockArgument &self) { + c.def_prop_ro("owner", [](PyBlockArgument &self) { return PyBlock(self.getParentOperation(), mlirBlockArgumentGetOwner(self.get())); }); - c.def_property_readonly("arg_number", [](PyBlockArgument &self) { + c.def_prop_ro("arg_number", [](PyBlockArgument &self) { return mlirBlockArgumentGetArgNumber(self.get()); }); c.def( @@ -2261,7 +2287,7 @@ class PyBlockArgument : public PyConcreteValue { [](PyBlockArgument &self, PyType type) { return mlirBlockArgumentSetType(self.get(), type); }, - py::arg("type")); + nb::arg("type")); } }; @@ -2273,14 +2299,14 @@ class PyOpResult : public PyConcreteValue { using PyConcreteValue::PyConcreteValue; static void bindDerived(ClassTy &c) { - c.def_property_readonly("owner", [](PyOpResult &self) { + c.def_prop_ro("owner", [](PyOpResult &self) { assert( mlirOperationEqual(self.getParentOperation()->get(), mlirOpResultGetOwner(self.get())) && "expected the owner of the value in Python to match that in the IR"); return self.getParentOperation().getObject(); }); - c.def_property_readonly("result_number", [](PyOpResult &self) { + c.def_prop_ro("result_number", [](PyOpResult &self) { return 
mlirOpResultGetResultNumber(self.get()); }); } @@ -2317,7 +2343,7 @@ class PyBlockArgumentList operation(std::move(operation)), block(block) {} static void bindDerived(ClassTy &c) { - c.def_property_readonly("types", [](PyBlockArgumentList &self) { + c.def_prop_ro("types", [](PyBlockArgumentList &self) { return getValueTypes(self, self.operation->getContext()); }); } @@ -2422,10 +2448,10 @@ class PyOpResultList : public Sliceable { operation(std::move(operation)) {} static void bindDerived(ClassTy &c) { - c.def_property_readonly("types", [](PyOpResultList &self) { + c.def_prop_ro("types", [](PyOpResultList &self) { return getValueTypes(self, self.operation->getContext()); }); - c.def_property_readonly("owner", [](PyOpResultList &self) { + c.def_prop_ro("owner", [](PyOpResultList &self) { return self.operation->createOpView(); }); } @@ -2508,14 +2534,14 @@ class PyOpAttributeMap { MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name)); if (mlirAttributeIsNull(attr)) { - throw py::key_error("attempt to access a non-existent attribute"); + throw nb::key_error("attempt to access a non-existent attribute"); } return attr; } PyNamedAttribute dunderGetItemIndexed(intptr_t index) { if (index < 0 || index >= dunderLen()) { - throw py::index_error("attempt to access out of bounds attribute"); + throw nb::index_error("attempt to access out of bounds attribute"); } MlirNamedAttribute namedAttr = mlirOperationGetAttribute(operation->get(), index); @@ -2534,7 +2560,7 @@ class PyOpAttributeMap { int removed = mlirOperationRemoveAttributeByName(operation->get(), toMlirStringRef(name)); if (!removed) - throw py::key_error("attempt to delete a non-existent attribute"); + throw nb::key_error("attempt to delete a non-existent attribute"); } intptr_t dunderLen() { @@ -2546,8 +2572,8 @@ class PyOpAttributeMap { operation->get(), toMlirStringRef(name))); } - static void bind(py::module &m) { - py::class_(m, "OpAttributeMap", py::module_local()) + 
static void bind(nb::module_ &m) { + nb::class_(m, "OpAttributeMap") .def("__contains__", &PyOpAttributeMap::dunderContains) .def("__len__", &PyOpAttributeMap::dunderLen) .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) @@ -2566,21 +2592,21 @@ class PyOpAttributeMap { // Populates the core exports of the 'ir' submodule. //------------------------------------------------------------------------------ -void mlir::python::populateIRCore(py::module &m) { +void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- // Enums. //---------------------------------------------------------------------------- - py::enum_(m, "DiagnosticSeverity", py::module_local()) + nb::enum_(m, "DiagnosticSeverity") .value("ERROR", MlirDiagnosticError) .value("WARNING", MlirDiagnosticWarning) .value("NOTE", MlirDiagnosticNote) .value("REMARK", MlirDiagnosticRemark); - py::enum_(m, "WalkOrder", py::module_local()) + nb::enum_(m, "WalkOrder") .value("PRE_ORDER", MlirWalkPreOrder) .value("POST_ORDER", MlirWalkPostOrder); - py::enum_(m, "WalkResult", py::module_local()) + nb::enum_(m, "WalkResult") .value("ADVANCE", MlirWalkResultAdvance) .value("INTERRUPT", MlirWalkResultInterrupt) .value("SKIP", MlirWalkResultSkip); @@ -2588,33 +2614,37 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of Diagnostics. 
//---------------------------------------------------------------------------- - py::class_(m, "Diagnostic", py::module_local()) - .def_property_readonly("severity", &PyDiagnostic::getSeverity) - .def_property_readonly("location", &PyDiagnostic::getLocation) - .def_property_readonly("message", &PyDiagnostic::getMessage) - .def_property_readonly("notes", &PyDiagnostic::getNotes) - .def("__str__", [](PyDiagnostic &self) -> py::str { + nb::class_(m, "Diagnostic") + .def_prop_ro("severity", &PyDiagnostic::getSeverity) + .def_prop_ro("location", &PyDiagnostic::getLocation) + .def_prop_ro("message", &PyDiagnostic::getMessage) + .def_prop_ro("notes", &PyDiagnostic::getNotes) + .def("__str__", [](PyDiagnostic &self) -> nb::str { if (!self.isValid()) - return ""; + return nb::str(""); return self.getMessage(); }); - py::class_(m, "DiagnosticInfo", - py::module_local()) - .def(py::init<>([](PyDiagnostic diag) { return diag.getInfo(); })) - .def_readonly("severity", &PyDiagnostic::DiagnosticInfo::severity) - .def_readonly("location", &PyDiagnostic::DiagnosticInfo::location) - .def_readonly("message", &PyDiagnostic::DiagnosticInfo::message) - .def_readonly("notes", &PyDiagnostic::DiagnosticInfo::notes) + nb::class_(m, "DiagnosticInfo") + .def("__init__", + [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) { + new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo()); + }) + .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity) + .def_ro("location", &PyDiagnostic::DiagnosticInfo::location) + .def_ro("message", &PyDiagnostic::DiagnosticInfo::message) + .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes) .def("__str__", [](PyDiagnostic::DiagnosticInfo &self) { return self.message; }); - py::class_(m, "DiagnosticHandler", py::module_local()) + nb::class_(m, "DiagnosticHandler") .def("detach", &PyDiagnosticHandler::detach) - .def_property_readonly("attached", &PyDiagnosticHandler::isAttached) - .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError) 
+ .def_prop_ro("attached", &PyDiagnosticHandler::isAttached) + .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError) .def("__enter__", &PyDiagnosticHandler::contextEnter) - .def("__exit__", &PyDiagnosticHandler::contextExit); + .def("__exit__", &PyDiagnosticHandler::contextExit, + nb::arg("exc_type").none(), nb::arg("exc_value").none(), + nb::arg("traceback").none()); //---------------------------------------------------------------------------- // Mapping of MlirContext. @@ -2622,8 +2652,12 @@ void mlir::python::populateIRCore(py::module &m) { // __init__.py will subclass it with site-specific functionality and set a // "Context" attribute on this module. //---------------------------------------------------------------------------- - py::class_(m, "_BaseContext", py::module_local()) - .def(py::init<>(&PyMlirContext::createNewContextForInit)) + nb::class_(m, "_BaseContext") + .def("__init__", + [](PyMlirContext &self) { + MlirContext context = mlirContextCreateWithThreading(false); + new (&self) PyMlirContext(context); + }) .def_static("_get_live_count", &PyMlirContext::getLiveCount) .def("_get_context_again", [](PyMlirContext &self) { @@ -2635,28 +2669,28 @@ void mlir::python::populateIRCore(py::module &m) { &PyMlirContext::getLiveOperationObjects) .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) .def("_clear_live_operations_inside", - py::overload_cast( + nb::overload_cast( &PyMlirContext::clearOperationsInside)) .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyMlirContext::getCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) .def("__enter__", &PyMlirContext::contextEnter) - .def("__exit__", &PyMlirContext::contextExit) - .def_property_readonly_static( + .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(), + nb::arg("exc_value").none(), 
nb::arg("traceback").none()) + .def_prop_ro_static( "current", - [](py::object & /*class*/) { + [](nb::object & /*class*/) { auto *context = PyThreadContextEntry::getDefaultContext(); if (!context) - return py::none().cast(); - return py::cast(context); + return nb::none(); + return nb::cast(context); }, "Gets the Context bound to the current thread or raises ValueError") - .def_property_readonly( + .def_prop_ro( "dialects", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, "Gets a container for accessing dialects by name") - .def_property_readonly( + .def_prop_ro( "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, "Alias for 'dialect'") .def( @@ -2665,14 +2699,14 @@ void mlir::python::populateIRCore(py::module &m) { MlirDialect dialect = mlirContextGetOrLoadDialect( self.get(), {name.data(), name.size()}); if (mlirDialectIsNull(dialect)) { - throw py::value_error( - (Twine("Dialect '") + name + "' not found").str()); + throw nb::value_error( + (Twine("Dialect '") + name + "' not found").str().c_str()); } return PyDialectDescriptor(self.getRef(), dialect); }, - py::arg("dialect_name"), + nb::arg("dialect_name"), "Gets or loads a dialect by name, returning its descriptor object") - .def_property( + .def_prop_rw( "allow_unregistered_dialects", [](PyMlirContext &self) -> bool { return mlirContextGetAllowUnregisteredDialects(self.get()); @@ -2681,32 +2715,32 @@ void mlir::python::populateIRCore(py::module &m) { mlirContextSetAllowUnregisteredDialects(self.get(), value); }) .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler, - py::arg("callback"), + nb::arg("callback"), "Attaches a diagnostic handler that will receive callbacks") .def( "enable_multithreading", [](PyMlirContext &self, bool enable) { mlirContextEnableMultithreading(self.get(), enable); }, - py::arg("enable")) + nb::arg("enable")) .def( "is_registered_operation", [](PyMlirContext &self, std::string &name) { return mlirContextIsRegisteredOperation( 
self.get(), MlirStringRef{name.data(), name.size()}); }, - py::arg("operation_name")) + nb::arg("operation_name")) .def( "append_dialect_registry", [](PyMlirContext &self, PyDialectRegistry ®istry) { mlirContextAppendDialectRegistry(self.get(), registry); }, - py::arg("registry")) - .def_property("emit_error_diagnostics", nullptr, - &PyMlirContext::setEmitErrorDiagnostics, - "Emit error diagnostics to diagnostic handlers. By default " - "error diagnostics are captured and reported through " - "MLIRError exceptions.") + nb::arg("registry")) + .def_prop_rw("emit_error_diagnostics", nullptr, + &PyMlirContext::setEmitErrorDiagnostics, + "Emit error diagnostics to diagnostic handlers. By default " + "error diagnostics are captured and reported through " + "MLIRError exceptions.") .def("load_all_available_dialects", [](PyMlirContext &self) { mlirContextLoadAllAvailableDialects(self.get()); }); @@ -2714,13 +2748,12 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of PyDialectDescriptor //---------------------------------------------------------------------------- - py::class_(m, "DialectDescriptor", py::module_local()) - .def_property_readonly("namespace", - [](PyDialectDescriptor &self) { - MlirStringRef ns = - mlirDialectGetNamespace(self.get()); - return py::str(ns.data, ns.length); - }) + nb::class_(m, "DialectDescriptor") + .def_prop_ro("namespace", + [](PyDialectDescriptor &self) { + MlirStringRef ns = mlirDialectGetNamespace(self.get()); + return nb::str(ns.data, ns.length); + }) .def("__repr__", [](PyDialectDescriptor &self) { MlirStringRef ns = mlirDialectGetNamespace(self.get()); std::string repr("(m, "Dialects", py::module_local()) + nb::class_(m, "Dialects") .def("__getitem__", [=](PyDialects &self, std::string keyName) { MlirDialect dialect = self.getDialectForKey(keyName, /*attrError=*/false); - py::object descriptor = - 
py::cast(PyDialectDescriptor{self.getContext(), dialect}); + nb::object descriptor = + nb::cast(PyDialectDescriptor{self.getContext(), dialect}); return createCustomDialectWrapper(keyName, std::move(descriptor)); }) .def("__getattr__", [=](PyDialects &self, std::string attrName) { MlirDialect dialect = self.getDialectForKey(attrName, /*attrError=*/true); - py::object descriptor = - py::cast(PyDialectDescriptor{self.getContext(), dialect}); + nb::object descriptor = + nb::cast(PyDialectDescriptor{self.getContext(), dialect}); return createCustomDialectWrapper(attrName, std::move(descriptor)); }); //---------------------------------------------------------------------------- // Mapping of PyDialect //---------------------------------------------------------------------------- - py::class_(m, "Dialect", py::module_local()) - .def(py::init(), py::arg("descriptor")) - .def_property_readonly( - "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) - .def("__repr__", [](py::object self) { + nb::class_(m, "Dialect") + .def(nb::init(), nb::arg("descriptor")) + .def_prop_ro("descriptor", + [](PyDialect &self) { return self.getDescriptor(); }) + .def("__repr__", [](nb::object self) { auto clazz = self.attr("__class__"); - return py::str(""); + return nb::str(""); }); //---------------------------------------------------------------------------- // Mapping of PyDialectRegistry //---------------------------------------------------------------------------- - py::class_(m, "DialectRegistry", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyDialectRegistry::getCapsule) + nb::class_(m, "DialectRegistry") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule) - .def(py::init<>()); + .def(nb::init<>()); //---------------------------------------------------------------------------- // Mapping of Location 
//---------------------------------------------------------------------------- - py::class_(m, "Location", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) + nb::class_(m, "Location") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) .def("__enter__", &PyLocation::contextEnter) - .def("__exit__", &PyLocation::contextExit) + .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(), + nb::arg("exc_value").none(), nb::arg("traceback").none()) .def("__eq__", [](PyLocation &self, PyLocation &other) -> bool { return mlirLocationEqual(self, other); }) - .def("__eq__", [](PyLocation &self, py::object other) { return false; }) - .def_property_readonly_static( + .def("__eq__", [](PyLocation &self, nb::object other) { return false; }) + .def_prop_ro_static( "current", - [](py::object & /*class*/) { + [](nb::object & /*class*/) { auto *loc = PyThreadContextEntry::getDefaultLocation(); if (!loc) - throw py::value_error("No current Location"); + throw nb::value_error("No current Location"); return loc; }, "Gets the Location bound to the current thread or raises ValueError") @@ -2801,14 +2834,14 @@ void mlir::python::populateIRCore(py::module &m) { return PyLocation(context->getRef(), mlirLocationUnknownGet(context->get())); }, - py::arg("context") = py::none(), + nb::arg("context").none() = nb::none(), "Gets a Location representing an unknown location") .def_static( "callsite", [](PyLocation callee, const std::vector &frames, DefaultingPyMlirContext context) { if (frames.empty()) - throw py::value_error("No caller frames provided"); + throw nb::value_error("No caller frames provided"); MlirLocation caller = frames.back().get(); for (const PyLocation &frame : llvm::reverse(llvm::ArrayRef(frames).drop_back())) @@ -2816,7 +2849,8 @@ void mlir::python::populateIRCore(py::module &m) { return PyLocation(context->getRef(), 
mlirLocationCallSiteGet(callee.get(), caller)); }, - py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(), + nb::arg("callee"), nb::arg("frames"), + nb::arg("context").none() = nb::none(), kContextGetCallSiteLocationDocstring) .def_static( "file", @@ -2827,8 +2861,9 @@ void mlir::python::populateIRCore(py::module &m) { mlirLocationFileLineColGet( context->get(), toMlirStringRef(filename), line, col)); }, - py::arg("filename"), py::arg("line"), py::arg("col"), - py::arg("context") = py::none(), kContextGetFileLocationDocstring) + nb::arg("filename"), nb::arg("line"), nb::arg("col"), + nb::arg("context").none() = nb::none(), + kContextGetFileLocationDocstring) .def_static( "fused", [](const std::vector &pyLocations, @@ -2843,8 +2878,9 @@ void mlir::python::populateIRCore(py::module &m) { metadata ? metadata->get() : MlirAttribute{0}); return PyLocation(context->getRef(), location); }, - py::arg("locations"), py::arg("metadata") = py::none(), - py::arg("context") = py::none(), kContextGetFusedLocationDocstring) + nb::arg("locations"), nb::arg("metadata").none() = nb::none(), + nb::arg("context").none() = nb::none(), + kContextGetFusedLocationDocstring) .def_static( "name", [](std::string name, std::optional childLoc, @@ -2856,21 +2892,22 @@ void mlir::python::populateIRCore(py::module &m) { childLoc ? 
childLoc->get() : mlirLocationUnknownGet(context->get()))); }, - py::arg("name"), py::arg("childLoc") = py::none(), - py::arg("context") = py::none(), kContextGetNameLocationDocString) + nb::arg("name"), nb::arg("childLoc").none() = nb::none(), + nb::arg("context").none() = nb::none(), + kContextGetNameLocationDocString) .def_static( "from_attr", [](PyAttribute &attribute, DefaultingPyMlirContext context) { return PyLocation(context->getRef(), mlirLocationFromAttribute(attribute)); }, - py::arg("attribute"), py::arg("context") = py::none(), + nb::arg("attribute"), nb::arg("context").none() = nb::none(), "Gets a Location from a LocationAttr") - .def_property_readonly( + .def_prop_ro( "context", [](PyLocation &self) { return self.getContext().getObject(); }, "Context that owns the Location") - .def_property_readonly( + .def_prop_ro( "attr", [](PyLocation &self) { return mlirLocationGetAttribute(self); }, "Get the underlying LocationAttr") @@ -2879,7 +2916,7 @@ void mlir::python::populateIRCore(py::module &m) { [](PyLocation &self, std::string message) { mlirEmitError(self, message.c_str()); }, - py::arg("message"), "Emits an error at this location") + nb::arg("message"), "Emits an error at this location") .def("__repr__", [](PyLocation &self) { PyPrintAccumulator printAccum; mlirLocationPrint(self, printAccum.getCallback(), @@ -2890,8 +2927,8 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of Module //---------------------------------------------------------------------------- - py::class_(m, "Module", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) + nb::class_(m, "Module", nb::is_weak_referenceable()) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) .def_static( "parse", @@ -2903,7 +2940,19 @@ void mlir::python::populateIRCore(py::module &m) { 
throw MLIRError("Unable to parse module assembly", errors.take()); return PyModule::forModule(module).releaseObject(); }, - py::arg("asm"), py::arg("context") = py::none(), + nb::arg("asm"), nb::arg("context").none() = nb::none(), + kModuleParseDocstring) + .def_static( + "parse", + [](nb::bytes moduleAsm, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirModule module = mlirModuleCreateParse( + context->get(), toMlirStringRef(moduleAsm)); + if (mlirModuleIsNull(module)) + throw MLIRError("Unable to parse module assembly", errors.take()); + return PyModule::forModule(module).releaseObject(); + }, + nb::arg("asm"), nb::arg("context").none() = nb::none(), kModuleParseDocstring) .def_static( "create", @@ -2911,12 +2960,12 @@ void mlir::python::populateIRCore(py::module &m) { MlirModule module = mlirModuleCreateEmpty(loc); return PyModule::forModule(module).releaseObject(); }, - py::arg("loc") = py::none(), "Creates an empty module") - .def_property_readonly( + nb::arg("loc").none() = nb::none(), "Creates an empty module") + .def_prop_ro( "context", [](PyModule &self) { return self.getContext().getObject(); }, "Context that created the Module") - .def_property_readonly( + .def_prop_ro( "operation", [](PyModule &self) { return PyOperation::forOperation(self.getContext(), @@ -2925,7 +2974,7 @@ void mlir::python::populateIRCore(py::module &m) { .releaseObject(); }, "Accesses the module as an operation") - .def_property_readonly( + .def_prop_ro( "body", [](PyModule &self) { PyOperationRef moduleOp = PyOperation::forOperation( @@ -2943,7 +2992,7 @@ void mlir::python::populateIRCore(py::module &m) { kDumpDocstring) .def( "__str__", - [](py::object self) { + [](nb::object self) { // Defer to the operation's __str__. 
return self.attr("operation").attr("__str__")(); }, @@ -2952,27 +3001,26 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of Operation. //---------------------------------------------------------------------------- - py::class_(m, "_OperationBase", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - [](PyOperationBase &self) { - return self.getOperation().getCapsule(); - }) + nb::class_(m, "_OperationBase") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, + [](PyOperationBase &self) { + return self.getOperation().getCapsule(); + }) .def("__eq__", [](PyOperationBase &self, PyOperationBase &other) { return &self.getOperation() == &other.getOperation(); }) .def("__eq__", - [](PyOperationBase &self, py::object other) { return false; }) + [](PyOperationBase &self, nb::object other) { return false; }) .def("__hash__", [](PyOperationBase &self) { return static_cast(llvm::hash_value(&self.getOperation())); }) - .def_property_readonly("attributes", - [](PyOperationBase &self) { - return PyOpAttributeMap( - self.getOperation().getRef()); - }) - .def_property_readonly( + .def_prop_ro("attributes", + [](PyOperationBase &self) { + return PyOpAttributeMap(self.getOperation().getRef()); + }) + .def_prop_ro( "context", [](PyOperationBase &self) { PyOperation &concreteOperation = self.getOperation(); @@ -2980,46 +3028,44 @@ void mlir::python::populateIRCore(py::module &m) { return concreteOperation.getContext().getObject(); }, "Context that owns the Operation") - .def_property_readonly("name", - [](PyOperationBase &self) { - auto &concreteOperation = self.getOperation(); - concreteOperation.checkValid(); - MlirOperation operation = - concreteOperation.get(); - MlirStringRef name = mlirIdentifierStr( - mlirOperationGetName(operation)); - return py::str(name.data, name.length); - }) - .def_property_readonly("operands", - [](PyOperationBase &self) { - return PyOpOperandList( - 
self.getOperation().getRef()); - }) - .def_property_readonly("regions", - [](PyOperationBase &self) { - return PyRegionList( - self.getOperation().getRef()); - }) - .def_property_readonly( + .def_prop_ro("name", + [](PyOperationBase &self) { + auto &concreteOperation = self.getOperation(); + concreteOperation.checkValid(); + MlirOperation operation = concreteOperation.get(); + MlirStringRef name = + mlirIdentifierStr(mlirOperationGetName(operation)); + return nb::str(name.data, name.length); + }) + .def_prop_ro("operands", + [](PyOperationBase &self) { + return PyOpOperandList(self.getOperation().getRef()); + }) + .def_prop_ro("regions", + [](PyOperationBase &self) { + return PyRegionList(self.getOperation().getRef()); + }) + .def_prop_ro( "results", [](PyOperationBase &self) { return PyOpResultList(self.getOperation().getRef()); }, "Returns the list of Operation results.") - .def_property_readonly( + .def_prop_ro( "result", [](PyOperationBase &self) { auto &operation = self.getOperation(); auto numResults = mlirOperationGetNumResults(operation); if (numResults != 1) { auto name = mlirIdentifierStr(mlirOperationGetName(operation)); - throw py::value_error( + throw nb::value_error( (Twine("Cannot call .result on operation ") + StringRef(name.data, name.length) + " which has " + Twine(numResults) + " results (it is only valid for operations with a " "single result)") - .str()); + .str() + .c_str()); } return PyOpResult(operation.getRef(), mlirOperationGetResult(operation, 0)) @@ -3027,7 +3073,7 @@ void mlir::python::populateIRCore(py::module &m) { }, "Shortcut to get an op result if it has only one (throws an error " "otherwise).") - .def_property_readonly( + .def_prop_ro( "location", [](PyOperationBase &self) { PyOperation &operation = self.getOperation(); @@ -3036,14 +3082,13 @@ void mlir::python::populateIRCore(py::module &m) { }, "Returns the source location the operation was defined or derived " "from.") - .def_property_readonly("parent", - [](PyOperationBase 
&self) -> py::object { - auto parent = - self.getOperation().getParentOperation(); - if (parent) - return parent->getObject(); - return py::none(); - }) + .def_prop_ro("parent", + [](PyOperationBase &self) -> nb::object { + auto parent = self.getOperation().getParentOperation(); + if (parent) + return parent->getObject(); + return nb::none(); + }) .def( "__str__", [](PyOperationBase &self) { @@ -3058,75 +3103,76 @@ void mlir::python::populateIRCore(py::module &m) { }, "Returns the assembly form of the operation.") .def("print", - py::overload_cast( + nb::overload_cast( &PyOperationBase::print), - py::arg("state"), py::arg("file") = py::none(), - py::arg("binary") = false, kOperationPrintStateDocstring) + nb::arg("state"), nb::arg("file").none() = nb::none(), + nb::arg("binary") = false, kOperationPrintStateDocstring) .def("print", - py::overload_cast, bool, bool, bool, bool, - bool, py::object, bool, bool>( + nb::overload_cast, bool, bool, bool, bool, + bool, nb::object, bool, bool>( &PyOperationBase::print), // Careful: Lots of arguments must match up with print method. 
- py::arg("large_elements_limit") = py::none(), - py::arg("enable_debug_info") = false, - py::arg("pretty_debug_info") = false, - py::arg("print_generic_op_form") = false, - py::arg("use_local_scope") = false, - py::arg("assume_verified") = false, py::arg("file") = py::none(), - py::arg("binary") = false, py::arg("skip_regions") = false, - kOperationPrintDocstring) - .def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"), - py::arg("desired_version") = py::none(), + nb::arg("large_elements_limit").none() = nb::none(), + nb::arg("enable_debug_info") = false, + nb::arg("pretty_debug_info") = false, + nb::arg("print_generic_op_form") = false, + nb::arg("use_local_scope") = false, + nb::arg("assume_verified") = false, + nb::arg("file").none() = nb::none(), nb::arg("binary") = false, + nb::arg("skip_regions") = false, kOperationPrintDocstring) + .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"), + nb::arg("desired_version").none() = nb::none(), kOperationPrintBytecodeDocstring) .def("get_asm", &PyOperationBase::getAsm, // Careful: Lots of arguments must match up with get_asm method. - py::arg("binary") = false, - py::arg("large_elements_limit") = py::none(), - py::arg("enable_debug_info") = false, - py::arg("pretty_debug_info") = false, - py::arg("print_generic_op_form") = false, - py::arg("use_local_scope") = false, - py::arg("assume_verified") = false, py::arg("skip_regions") = false, + nb::arg("binary") = false, + nb::arg("large_elements_limit").none() = nb::none(), + nb::arg("enable_debug_info") = false, + nb::arg("pretty_debug_info") = false, + nb::arg("print_generic_op_form") = false, + nb::arg("use_local_scope") = false, + nb::arg("assume_verified") = false, nb::arg("skip_regions") = false, kOperationGetAsmDocstring) .def("verify", &PyOperationBase::verify, "Verify the operation. 
Raises MLIRError if verification fails, and " "returns true otherwise.") - .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), + .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"), "Puts self immediately after the other operation in its parent " "block.") - .def("move_before", &PyOperationBase::moveBefore, py::arg("other"), + .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"), "Puts self immediately before the other operation in its parent " "block.") .def( "clone", - [](PyOperationBase &self, py::object ip) { + [](PyOperationBase &self, nb::object ip) { return self.getOperation().clone(ip); }, - py::arg("ip") = py::none()) + nb::arg("ip").none() = nb::none()) .def( "detach_from_parent", [](PyOperationBase &self) { PyOperation &operation = self.getOperation(); operation.checkValid(); if (!operation.isAttached()) - throw py::value_error("Detached operation has no parent."); + throw nb::value_error("Detached operation has no parent."); operation.detachFromParent(); return operation.createOpView(); }, "Detaches the operation from its parent block.") .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); }) - .def("walk", &PyOperationBase::walk, py::arg("callback"), - py::arg("walk_order") = MlirWalkPostOrder); - - py::class_(m, "Operation", py::module_local()) - .def_static("create", &PyOperation::create, py::arg("name"), - py::arg("results") = py::none(), - py::arg("operands") = py::none(), - py::arg("attributes") = py::none(), - py::arg("successors") = py::none(), py::arg("regions") = 0, - py::arg("loc") = py::none(), py::arg("ip") = py::none(), - py::arg("infer_type") = false, kOperationCreateDocstring) + .def("walk", &PyOperationBase::walk, nb::arg("callback"), + nb::arg("walk_order") = MlirWalkPostOrder); + + nb::class_(m, "Operation") + .def_static("create", &PyOperation::create, nb::arg("name"), + nb::arg("results").none() = nb::none(), + nb::arg("operands").none() = nb::none(), + 
nb::arg("attributes").none() = nb::none(), + nb::arg("successors").none() = nb::none(), + nb::arg("regions") = 0, nb::arg("loc").none() = nb::none(), + nb::arg("ip").none() = nb::none(), + nb::arg("infer_type") = false, kOperationCreateDocstring) .def_static( "parse", [](const std::string &sourceStr, const std::string &sourceName, @@ -3134,16 +3180,15 @@ void mlir::python::populateIRCore(py::module &m) { return PyOperation::parse(context->getRef(), sourceStr, sourceName) ->createOpView(); }, - py::arg("source"), py::kw_only(), py::arg("source_name") = "", - py::arg("context") = py::none(), + nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "", + nb::arg("context").none() = nb::none(), "Parses an operation. Supports both text assembly format and binary " "bytecode format.") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyOperation::getCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) - .def_property_readonly("operation", [](py::object self) { return self; }) - .def_property_readonly("opview", &PyOperation::createOpView) - .def_property_readonly( + .def_prop_ro("operation", [](nb::object self) { return self; }) + .def_prop_ro("opview", &PyOperation::createOpView) + .def_prop_ro( "successors", [](PyOperationBase &self) { return PyOpSuccessors(self.getOperation().getRef()); @@ -3151,30 +3196,33 @@ void mlir::python::populateIRCore(py::module &m) { "Returns the list of Operation successors."); auto opViewClass = - py::class_(m, "OpView", py::module_local()) - .def(py::init(), py::arg("operation")) - .def_property_readonly("operation", &PyOpView::getOperationObject) - .def_property_readonly("opview", [](py::object self) { return self; }) + nb::class_(m, "OpView") + .def(nb::init(), nb::arg("operation")) + .def_prop_ro("operation", &PyOpView::getOperationObject) + .def_prop_ro("opview", [](nb::object self) { return self; }) .def( "__str__", - [](PyOpView 
&self) { return py::str(self.getOperationObject()); }) - .def_property_readonly( + [](PyOpView &self) { return nb::str(self.getOperationObject()); }) + .def_prop_ro( "successors", [](PyOperationBase &self) { return PyOpSuccessors(self.getOperation().getRef()); }, "Returns the list of Operation successors."); - opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); - opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); - opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); + opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true); + opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none(); + opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none(); opViewClass.attr("build_generic") = classmethod( - &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), - py::arg("operands") = py::none(), py::arg("attributes") = py::none(), - py::arg("successors") = py::none(), py::arg("regions") = py::none(), - py::arg("loc") = py::none(), py::arg("ip") = py::none(), + &PyOpView::buildGeneric, nb::arg("cls"), + nb::arg("results").none() = nb::none(), + nb::arg("operands").none() = nb::none(), + nb::arg("attributes").none() = nb::none(), + nb::arg("successors").none() = nb::none(), + nb::arg("regions").none() = nb::none(), + nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(), "Builds a specific, generated OpView based on class level attributes."); opViewClass.attr("parse") = classmethod( - [](const py::object &cls, const std::string &sourceStr, + [](const nb::object &cls, const std::string &sourceStr, const std::string &sourceName, DefaultingPyMlirContext context) { PyOperationRef parsed = PyOperation::parse(context->getRef(), sourceStr, sourceName); @@ -3185,30 +3233,30 @@ void mlir::python::populateIRCore(py::module &m) { // `OpView` subclasses, and is not intended to be used on `OpView` // directly. 
std::string clsOpName = - py::cast(cls.attr("OPERATION_NAME")); + nb::cast(cls.attr("OPERATION_NAME")); MlirStringRef identifier = mlirIdentifierStr(mlirOperationGetName(*parsed.get())); std::string_view parsedOpName(identifier.data, identifier.length); if (clsOpName != parsedOpName) throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" + parsedOpName + "'"); - return PyOpView::constructDerived(cls, *parsed.get()); + return PyOpView::constructDerived(cls, parsed.getObject()); }, - py::arg("cls"), py::arg("source"), py::kw_only(), - py::arg("source_name") = "", py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("source"), nb::kw_only(), + nb::arg("source_name") = "", nb::arg("context").none() = nb::none(), "Parses a specific, generated OpView based on class level attributes"); //---------------------------------------------------------------------------- // Mapping of PyRegion. //---------------------------------------------------------------------------- - py::class_(m, "Region", py::module_local()) - .def_property_readonly( + nb::class_(m, "Region") + .def_prop_ro( "blocks", [](PyRegion &self) { return PyBlockList(self.getParentOperation(), self.get()); }, "Returns a forward-optimized sequence of blocks.") - .def_property_readonly( + .def_prop_ro( "owner", [](PyRegion &self) { return self.getParentOperation()->createOpView(); @@ -3226,27 +3274,27 @@ void mlir::python::populateIRCore(py::module &m) { [](PyRegion &self, PyRegion &other) { return self.get().ptr == other.get().ptr; }) - .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); + .def("__eq__", [](PyRegion &self, nb::object &other) { return false; }); //---------------------------------------------------------------------------- // Mapping of PyBlock. 
//---------------------------------------------------------------------------- - py::class_(m, "Block", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule) - .def_property_readonly( + nb::class_(m, "Block") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule) + .def_prop_ro( "owner", [](PyBlock &self) { return self.getParentOperation()->createOpView(); }, "Returns the owning operation of this block.") - .def_property_readonly( + .def_prop_ro( "region", [](PyBlock &self) { MlirRegion region = mlirBlockGetParentRegion(self.get()); return PyRegion(self.getParentOperation(), region); }, "Returns the owning region of this block.") - .def_property_readonly( + .def_prop_ro( "arguments", [](PyBlock &self) { return PyBlockArgumentList(self.getParentOperation(), self.get()); @@ -3265,7 +3313,7 @@ void mlir::python::populateIRCore(py::module &m) { return mlirBlockEraseArgument(self.get(), index); }, "Erase the argument at 'index' and remove it from the argument list.") - .def_property_readonly( + .def_prop_ro( "operations", [](PyBlock &self) { return PyOperationList(self.getParentOperation(), self.get()); @@ -3273,15 +3321,15 @@ void mlir::python::populateIRCore(py::module &m) { "Returns a forward-optimized sequence of operations.") .def_static( "create_at_start", - [](PyRegion &parent, const py::list &pyArgTypes, - const std::optional &pyArgLocs) { + [](PyRegion &parent, const nb::sequence &pyArgTypes, + const std::optional &pyArgLocs) { parent.checkValid(); MlirBlock block = createBlock(pyArgTypes, pyArgLocs); mlirRegionInsertOwnedBlock(parent, 0, block); return PyBlock(parent.getParentOperation(), block); }, - py::arg("parent"), py::arg("arg_types") = py::list(), - py::arg("arg_locs") = std::nullopt, + nb::arg("parent"), nb::arg("arg_types") = nb::list(), + nb::arg("arg_locs") = std::nullopt, "Creates and returns a new Block at the beginning of the given " "region (with given argument types and locations).") .def( @@ 
-3295,28 +3343,32 @@ void mlir::python::populateIRCore(py::module &m) { "Append this block to a region, transferring ownership if necessary") .def( "create_before", - [](PyBlock &self, const py::args &pyArgTypes, - const std::optional &pyArgLocs) { + [](PyBlock &self, const nb::args &pyArgTypes, + const std::optional &pyArgLocs) { self.checkValid(); - MlirBlock block = createBlock(pyArgTypes, pyArgLocs); + MlirBlock block = + createBlock(nb::cast(pyArgTypes), pyArgLocs); MlirRegion region = mlirBlockGetParentRegion(self.get()); mlirRegionInsertOwnedBlockBefore(region, self.get(), block); return PyBlock(self.getParentOperation(), block); }, - py::arg("arg_locs") = std::nullopt, + nb::arg("arg_types"), nb::kw_only(), + nb::arg("arg_locs") = std::nullopt, "Creates and returns a new Block before this block " "(with given argument types and locations).") .def( "create_after", - [](PyBlock &self, const py::args &pyArgTypes, - const std::optional &pyArgLocs) { + [](PyBlock &self, const nb::args &pyArgTypes, + const std::optional &pyArgLocs) { self.checkValid(); - MlirBlock block = createBlock(pyArgTypes, pyArgLocs); + MlirBlock block = + createBlock(nb::cast(pyArgTypes), pyArgLocs); MlirRegion region = mlirBlockGetParentRegion(self.get()); mlirRegionInsertOwnedBlockAfter(region, self.get(), block); return PyBlock(self.getParentOperation(), block); }, - py::arg("arg_locs") = std::nullopt, + nb::arg("arg_types"), nb::kw_only(), + nb::arg("arg_locs") = std::nullopt, "Creates and returns a new Block after this block " "(with given argument types and locations).") .def( @@ -3333,7 +3385,7 @@ void mlir::python::populateIRCore(py::module &m) { [](PyBlock &self, PyBlock &other) { return self.get().ptr == other.get().ptr; }) - .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) + .def("__eq__", [](PyBlock &self, nb::object &other) { return false; }) .def("__hash__", [](PyBlock &self) { return static_cast(llvm::hash_value(self.get().ptr)); @@ -3359,7 +3411,7 @@ 
void mlir::python::populateIRCore(py::module &m) { operation.getOperation().setAttached( self.getParentOperation().getObject()); }, - py::arg("operation"), + nb::arg("operation"), "Appends an operation to this block. If the operation is currently " "in another block, it will be moved."); @@ -3367,39 +3419,41 @@ void mlir::python::populateIRCore(py::module &m) { // Mapping of PyInsertionPoint. //---------------------------------------------------------------------------- - py::class_(m, "InsertionPoint", py::module_local()) - .def(py::init(), py::arg("block"), + nb::class_(m, "InsertionPoint") + .def(nb::init(), nb::arg("block"), "Inserts after the last operation but still inside the block.") .def("__enter__", &PyInsertionPoint::contextEnter) - .def("__exit__", &PyInsertionPoint::contextExit) - .def_property_readonly_static( + .def("__exit__", &PyInsertionPoint::contextExit, + nb::arg("exc_type").none(), nb::arg("exc_value").none(), + nb::arg("traceback").none()) + .def_prop_ro_static( "current", - [](py::object & /*class*/) { + [](nb::object & /*class*/) { auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); if (!ip) - throw py::value_error("No current InsertionPoint"); + throw nb::value_error("No current InsertionPoint"); return ip; }, "Gets the InsertionPoint bound to the current thread or raises " "ValueError if none has been set") - .def(py::init(), py::arg("beforeOperation"), + .def(nb::init(), nb::arg("beforeOperation"), "Inserts before a referenced operation.") .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, - py::arg("block"), "Inserts at the beginning of the block.") + nb::arg("block"), "Inserts at the beginning of the block.") .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, - py::arg("block"), "Inserts before the block terminator.") - .def("insert", &PyInsertionPoint::insert, py::arg("operation"), + nb::arg("block"), "Inserts before the block terminator.") + .def("insert", &PyInsertionPoint::insert, 
nb::arg("operation"), "Inserts an operation.") - .def_property_readonly( + .def_prop_ro( "block", [](PyInsertionPoint &self) { return self.getBlock(); }, "Returns the block that this InsertionPoint points to.") - .def_property_readonly( + .def_prop_ro( "ref_operation", - [](PyInsertionPoint &self) -> py::object { + [](PyInsertionPoint &self) -> nb::object { auto refOperation = self.getRefOperation(); if (refOperation) return refOperation->getObject(); - return py::none(); + return nb::none(); }, "The reference operation before which new operations are " "inserted, or None if the insertion point is at the end of " @@ -3408,13 +3462,12 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of PyAttribute. //---------------------------------------------------------------------------- - py::class_(m, "Attribute", py::module_local()) + nb::class_(m, "Attribute") // Delegate to the PyAttribute copy constructor, which will also lifetime // extend the backing context which owns the MlirAttribute. - .def(py::init(), py::arg("cast_from_type"), + .def(nb::init(), nb::arg("cast_from_type"), "Casts the passed attribute to the generic Attribute") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyAttribute::getCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) .def_static( "parse", @@ -3426,24 +3479,24 @@ void mlir::python::populateIRCore(py::module &m) { throw MLIRError("Unable to parse attribute", errors.take()); return attr; }, - py::arg("asm"), py::arg("context") = py::none(), + nb::arg("asm"), nb::arg("context").none() = nb::none(), "Parses an attribute from an assembly form. 
Raises an MLIRError on " "failure.") - .def_property_readonly( + .def_prop_ro( "context", [](PyAttribute &self) { return self.getContext().getObject(); }, "Context that owns the Attribute") - .def_property_readonly( - "type", [](PyAttribute &self) { return mlirAttributeGetType(self); }) + .def_prop_ro("type", + [](PyAttribute &self) { return mlirAttributeGetType(self); }) .def( "get_named", [](PyAttribute &self, std::string name) { return PyNamedAttribute(self, std::move(name)); }, - py::keep_alive<0, 1>(), "Binds a name to the attribute") + nb::keep_alive<0, 1>(), "Binds a name to the attribute") .def("__eq__", [](PyAttribute &self, PyAttribute &other) { return self == other; }) - .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) + .def("__eq__", [](PyAttribute &self, nb::object &other) { return false; }) .def("__hash__", [](PyAttribute &self) { return static_cast(llvm::hash_value(self.get().ptr)); @@ -3474,36 +3527,35 @@ void mlir::python::populateIRCore(py::module &m) { printAccum.parts.append(")"); return printAccum.join(); }) - .def_property_readonly( - "typeid", - [](PyAttribute &self) -> MlirTypeID { - MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); - assert(!mlirTypeIDIsNull(mlirTypeID) && - "mlirTypeID was expected to be non-null."); - return mlirTypeID; - }) + .def_prop_ro("typeid", + [](PyAttribute &self) -> MlirTypeID { + MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + return mlirTypeID; + }) .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) { MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); assert(!mlirTypeIDIsNull(mlirTypeID) && "mlirTypeID was expected to be non-null."); - std::optional typeCaster = + std::optional typeCaster = PyGlobals::get().lookupTypeCaster(mlirTypeID, mlirAttributeGetDialect(self)); if (!typeCaster) - return py::cast(self); + return nb::cast(self); return typeCaster.value()(self); }); 
//---------------------------------------------------------------------------- // Mapping of PyNamedAttribute //---------------------------------------------------------------------------- - py::class_(m, "NamedAttribute", py::module_local()) + nb::class_(m, "NamedAttribute") .def("__repr__", [](PyNamedAttribute &self) { PyPrintAccumulator printAccum; printAccum.parts.append("NamedAttribute("); printAccum.parts.append( - py::str(mlirIdentifierStr(self.namedAttr.name).data, + nb::str(mlirIdentifierStr(self.namedAttr.name).data, mlirIdentifierStr(self.namedAttr.name).length)); printAccum.parts.append("="); mlirAttributePrint(self.namedAttr.attribute, @@ -3512,28 +3564,28 @@ void mlir::python::populateIRCore(py::module &m) { printAccum.parts.append(")"); return printAccum.join(); }) - .def_property_readonly( + .def_prop_ro( "name", [](PyNamedAttribute &self) { - return py::str(mlirIdentifierStr(self.namedAttr.name).data, + return nb::str(mlirIdentifierStr(self.namedAttr.name).data, mlirIdentifierStr(self.namedAttr.name).length); }, "The name of the NamedAttribute binding") - .def_property_readonly( + .def_prop_ro( "attr", [](PyNamedAttribute &self) { return self.namedAttr.attribute; }, - py::keep_alive<0, 1>(), + nb::keep_alive<0, 1>(), "The underlying generic attribute of the NamedAttribute binding"); //---------------------------------------------------------------------------- // Mapping of PyType. //---------------------------------------------------------------------------- - py::class_(m, "Type", py::module_local()) + nb::class_(m, "Type") // Delegate to the PyType copy constructor, which will also lifetime // extend the backing context which owns the MlirType. 
- .def(py::init(), py::arg("cast_from_type"), + .def(nb::init(), nb::arg("cast_from_type"), "Casts the passed type to the generic Type") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) .def_static( "parse", @@ -3545,13 +3597,15 @@ void mlir::python::populateIRCore(py::module &m) { throw MLIRError("Unable to parse type", errors.take()); return type; }, - py::arg("asm"), py::arg("context") = py::none(), + nb::arg("asm"), nb::arg("context").none() = nb::none(), kContextParseTypeDocstring) - .def_property_readonly( + .def_prop_ro( "context", [](PyType &self) { return self.getContext().getObject(); }, "Context that owns the Type") .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) - .def("__eq__", [](PyType &self, py::object &other) { return false; }) + .def( + "__eq__", [](PyType &self, nb::object &other) { return false; }, + nb::arg("other").none()) .def("__hash__", [](PyType &self) { return static_cast(llvm::hash_value(self.get().ptr)); @@ -3585,28 +3639,27 @@ void mlir::python::populateIRCore(py::module &m) { MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); assert(!mlirTypeIDIsNull(mlirTypeID) && "mlirTypeID was expected to be non-null."); - std::optional typeCaster = + std::optional typeCaster = PyGlobals::get().lookupTypeCaster(mlirTypeID, mlirTypeGetDialect(self)); if (!typeCaster) - return py::cast(self); + return nb::cast(self); return typeCaster.value()(self); }) - .def_property_readonly("typeid", [](PyType &self) -> MlirTypeID { + .def_prop_ro("typeid", [](PyType &self) -> MlirTypeID { MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); if (!mlirTypeIDIsNull(mlirTypeID)) return mlirTypeID; - auto origRepr = - pybind11::repr(pybind11::cast(self)).cast(); - throw py::value_error( - (origRepr + llvm::Twine(" has no typeid.")).str()); + auto origRepr = nb::cast(nb::repr(nb::cast(self))); + 
throw nb::value_error( + (origRepr + llvm::Twine(" has no typeid.")).str().c_str()); }); //---------------------------------------------------------------------------- // Mapping of PyTypeID. //---------------------------------------------------------------------------- - py::class_(m, "TypeID", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule) + nb::class_(m, "TypeID") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule) // Note, this tests whether the underlying TypeIDs are the same, // not whether the wrapper MlirTypeIDs are the same, nor whether @@ -3614,7 +3667,7 @@ void mlir::python::populateIRCore(py::module &m) { .def("__eq__", [](PyTypeID &self, PyTypeID &other) { return self == other; }) .def("__eq__", - [](PyTypeID &self, const py::object &other) { return false; }) + [](PyTypeID &self, const nb::object &other) { return false; }) // Note, this gives the hash value of the underlying TypeID, not the // hash value of the Python object, nor the hash value of the // MlirTypeID wrapper. @@ -3625,20 +3678,20 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of Value. 
//---------------------------------------------------------------------------- - py::class_(m, "Value", py::module_local()) - .def(py::init(), py::keep_alive<0, 1>(), py::arg("value")) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) + nb::class_(m, "Value") + .def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value")) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) - .def_property_readonly( + .def_prop_ro( "context", [](PyValue &self) { return self.getParentOperation()->getContext(); }, "Context in which the value lives.") .def( "dump", [](PyValue &self) { mlirValueDump(self.get()); }, kDumpDocstring) - .def_property_readonly( + .def_prop_ro( "owner", - [](PyValue &self) -> py::object { + [](PyValue &self) -> nb::object { MlirValue v = self.get(); if (mlirValueIsAOpResult(v)) { assert( @@ -3651,22 +3704,22 @@ void mlir::python::populateIRCore(py::module &m) { if (mlirValueIsABlockArgument(v)) { MlirBlock block = mlirBlockArgumentGetOwner(self.get()); - return py::cast(PyBlock(self.getParentOperation(), block)); + return nb::cast(PyBlock(self.getParentOperation(), block)); } assert(false && "Value must be a block argument or an op result"); - return py::none(); + return nb::none(); }) - .def_property_readonly("uses", - [](PyValue &self) { - return PyOpOperandIterator( - mlirValueGetFirstUse(self.get())); - }) + .def_prop_ro("uses", + [](PyValue &self) { + return PyOpOperandIterator( + mlirValueGetFirstUse(self.get())); + }) .def("__eq__", [](PyValue &self, PyValue &other) { return self.get().ptr == other.get().ptr; }) - .def("__eq__", [](PyValue &self, py::object other) { return false; }) + .def("__eq__", [](PyValue &self, nb::object other) { return false; }) .def("__hash__", [](PyValue &self) { return static_cast(llvm::hash_value(self.get().ptr)); @@ -3698,26 +3751,26 @@ void mlir::python::populateIRCore(py::module &m) { mlirAsmStateDestroy(valueState); return 
printAccum.join(); }, - py::arg("use_local_scope") = false) + nb::arg("use_local_scope") = false) .def( "get_name", - [](PyValue &self, std::reference_wrapper state) { + [](PyValue &self, PyAsmState &state) { PyPrintAccumulator printAccum; - MlirAsmState valueState = state.get().get(); + MlirAsmState valueState = state.get(); mlirValuePrintAsOperand(self.get(), valueState, printAccum.getCallback(), printAccum.getUserData()); return printAccum.join(); }, - py::arg("state"), kGetNameAsOperand) - .def_property_readonly( - "type", [](PyValue &self) { return mlirValueGetType(self.get()); }) + nb::arg("state"), kGetNameAsOperand) + .def_prop_ro("type", + [](PyValue &self) { return mlirValueGetType(self.get()); }) .def( "set_type", [](PyValue &self, const PyType &type) { return mlirValueSetType(self.get(), type); }, - py::arg("type")) + nb::arg("type")) .def( "replace_all_uses_with", [](PyValue &self, PyValue &with) { @@ -3730,22 +3783,22 @@ void mlir::python::populateIRCore(py::module &m) { MlirOperation exceptedUser = exception.get(); mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser); }, - py::arg("with"), py::arg("exceptions"), + nb::arg("with"), nb::arg("exceptions"), kValueReplaceAllUsesExceptDocstring) .def( "replace_all_uses_except", - [](MlirValue self, MlirValue with, py::list exceptions) { + [](MlirValue self, MlirValue with, nb::list exceptions) { // Convert Python list to a SmallVector of MlirOperations llvm::SmallVector exceptionOps; - for (py::handle exception : exceptions) { - exceptionOps.push_back(exception.cast().get()); + for (nb::handle exception : exceptions) { + exceptionOps.push_back(nb::cast(exception).get()); } mlirValueReplaceAllUsesExcept( self, with, static_cast(exceptionOps.size()), exceptionOps.data()); }, - py::arg("with"), py::arg("exceptions"), + nb::arg("with"), nb::arg("exceptions"), kValueReplaceAllUsesExceptDocstring) .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyValue &self) { return self.maybeDownCast(); }); @@ -3753,20 
+3806,20 @@ void mlir::python::populateIRCore(py::module &m) { PyOpResult::bind(m); PyOpOperand::bind(m); - py::class_(m, "AsmState", py::module_local()) - .def(py::init(), py::arg("value"), - py::arg("use_local_scope") = false) - .def(py::init(), py::arg("op"), - py::arg("use_local_scope") = false); + nb::class_(m, "AsmState") + .def(nb::init(), nb::arg("value"), + nb::arg("use_local_scope") = false) + .def(nb::init(), nb::arg("op"), + nb::arg("use_local_scope") = false); //---------------------------------------------------------------------------- // Mapping of SymbolTable. //---------------------------------------------------------------------------- - py::class_(m, "SymbolTable", py::module_local()) - .def(py::init()) + nb::class_(m, "SymbolTable") + .def(nb::init()) .def("__getitem__", &PySymbolTable::dunderGetItem) - .def("insert", &PySymbolTable::insert, py::arg("operation")) - .def("erase", &PySymbolTable::erase, py::arg("operation")) + .def("insert", &PySymbolTable::insert, nb::arg("operation")) + .def("erase", &PySymbolTable::erase, nb::arg("operation")) .def("__delitem__", &PySymbolTable::dunderDel) .def("__contains__", [](PySymbolTable &table, const std::string &name) { @@ -3775,19 +3828,19 @@ void mlir::python::populateIRCore(py::module &m) { }) // Static helpers. 
.def_static("set_symbol_name", &PySymbolTable::setSymbolName, - py::arg("symbol"), py::arg("name")) + nb::arg("symbol"), nb::arg("name")) .def_static("get_symbol_name", &PySymbolTable::getSymbolName, - py::arg("symbol")) + nb::arg("symbol")) .def_static("get_visibility", &PySymbolTable::getVisibility, - py::arg("symbol")) + nb::arg("symbol")) .def_static("set_visibility", &PySymbolTable::setVisibility, - py::arg("symbol"), py::arg("visibility")) + nb::arg("symbol"), nb::arg("visibility")) .def_static("replace_all_symbol_uses", - &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"), - py::arg("new_symbol"), py::arg("from_op")) + &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"), + nb::arg("new_symbol"), nb::arg("from_op")) .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, - py::arg("from_op"), py::arg("all_sym_uses_visible"), - py::arg("callback")); + nb::arg("from_op"), nb::arg("all_sym_uses_visible"), + nb::arg("callback")); // Container bindings. PyBlockArgumentList::bind(m); @@ -3809,14 +3862,15 @@ void mlir::python::populateIRCore(py::module &m) { // Attribute builder getter. PyAttrBuilderMap::bind(m); - py::register_local_exception_translator([](std::exception_ptr p) { + nb::register_exception_translator([](const std::exception_ptr &p, + void *payload) { // We can't define exceptions with custom fields through pybind, so instead // the exception class is defined in python and imported here. 
try { if (p) std::rethrow_exception(p); } catch (const MLIRError &e) { - py::object obj = py::module_::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr("MLIRError")(e.message, e.errorDiagnostics); PyErr_SetObject(PyExc_Exception, obj.ptr()); } diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index 54cfa5606..c339a93e3 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -6,12 +6,12 @@ // //===----------------------------------------------------------------------===// +#include +#include +#include + #include #include -#include -#include -#include -#include #include #include #include @@ -24,7 +24,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -namespace py = pybind11; +namespace nb = nanobind; namespace mlir { namespace python { @@ -53,10 +53,10 @@ namespace { /// Takes in an optional ist of operands and converts them into a SmallVector /// of MlirVlaues. Returns an empty SmallVector if the list is empty. -llvm::SmallVector wrapOperands(std::optional operandList) { +llvm::SmallVector wrapOperands(std::optional operandList) { llvm::SmallVector mlirOperands; - if (!operandList || operandList->empty()) { + if (!operandList || operandList->size() == 0) { return mlirOperands; } @@ -68,40 +68,42 @@ llvm::SmallVector wrapOperands(std::optional operandList) { PyValue *val; try { - val = py::cast(it.value()); + val = nb::cast(it.value()); if (!val) - throw py::cast_error(); + throw nb::cast_error(); mlirOperands.push_back(val->get()); continue; - } catch (py::cast_error &err) { + } catch (nb::cast_error &err) { // Intentionally unhandled to try sequence below first. 
(void)err; } try { - auto vals = py::cast(it.value()); - for (py::object v : vals) { + auto vals = nb::cast(it.value()); + for (nb::handle v : vals) { try { - val = py::cast(v); + val = nb::cast(v); if (!val) - throw py::cast_error(); + throw nb::cast_error(); mlirOperands.push_back(val->get()); - } catch (py::cast_error &err) { - throw py::value_error( + } catch (nb::cast_error &err) { + throw nb::value_error( (llvm::Twine("Operand ") + llvm::Twine(it.index()) + " must be a Value or Sequence of Values (" + err.what() + ")") - .str()); + .str() + .c_str()); } } continue; - } catch (py::cast_error &err) { - throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + + } catch (nb::cast_error &err) { + throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + " must be a Value or Sequence of Values (" + err.what() + ")") - .str()); + .str() + .c_str()); } - throw py::cast_error(); + throw nb::cast_error(); } return mlirOperands; @@ -144,24 +146,24 @@ wrapRegions(std::optional> regions) { template class PyConcreteOpInterface { protected: - using ClassTy = py::class_; + using ClassTy = nb::class_; using GetTypeIDFunctionTy = MlirTypeID (*)(); public: /// Constructs an interface instance from an object that is either an /// operation or a subclass of OpView. In the latter case, only the static /// methods of the interface are accessible to the caller. - PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context) + PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context) : obj(std::move(object)) { try { - operation = &py::cast(obj); - } catch (py::cast_error &) { + operation = &nb::cast(obj); + } catch (nb::cast_error &) { // Do nothing. } try { - operation = &py::cast(obj).getOperation(); - } catch (py::cast_error &) { + operation = &nb::cast(obj).getOperation(); + } catch (nb::cast_error &) { // Do nothing. 
} @@ -169,7 +171,7 @@ class PyConcreteOpInterface { if (!mlirOperationImplementsInterface(*operation, ConcreteIface::getInterfaceID())) { std::string msg = "the operation does not implement "; - throw py::value_error(msg + ConcreteIface::pyClassName); + throw nb::value_error((msg + ConcreteIface::pyClassName).c_str()); } MlirIdentifier identifier = mlirOperationGetName(*operation); @@ -177,9 +179,9 @@ class PyConcreteOpInterface { opName = std::string(stringRef.data, stringRef.length); } else { try { - opName = obj.attr("OPERATION_NAME").template cast(); - } catch (py::cast_error &) { - throw py::type_error( + opName = nb::cast(obj.attr("OPERATION_NAME")); + } catch (nb::cast_error &) { + throw nb::type_error( "Op interface does not refer to an operation or OpView class"); } @@ -187,22 +189,19 @@ class PyConcreteOpInterface { mlirStringRefCreate(opName.data(), opName.length()), context.resolve().get(), ConcreteIface::getInterfaceID())) { std::string msg = "the operation does not implement "; - throw py::value_error(msg + ConcreteIface::pyClassName); + throw nb::value_error((msg + ConcreteIface::pyClassName).c_str()); } } } /// Creates the Python bindings for this class in the given module. 
- static void bind(py::module &m) { - py::class_ cls(m, ConcreteIface::pyClassName, - py::module_local()); - cls.def(py::init(), py::arg("object"), - py::arg("context") = py::none(), constructorDoc) - .def_property_readonly("operation", - &PyConcreteOpInterface::getOperationObject, - operationDoc) - .def_property_readonly("opview", &PyConcreteOpInterface::getOpView, - opviewDoc); + static void bind(nb::module_ &m) { + nb::class_ cls(m, ConcreteIface::pyClassName); + cls.def(nb::init(), nb::arg("object"), + nb::arg("context").none() = nb::none(), constructorDoc) + .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject, + operationDoc) + .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc); ConcreteIface::bindDerived(cls); } @@ -216,9 +215,9 @@ class PyConcreteOpInterface { /// Returns the operation instance from which this object was constructed. /// Throws a type error if this object was constructed from a subclass of /// OpView. - py::object getOperationObject() { + nb::object getOperationObject() { if (operation == nullptr) { - throw py::type_error("Cannot get an operation from a static interface"); + throw nb::type_error("Cannot get an operation from a static interface"); } return operation->getRef().releaseObject(); @@ -227,9 +226,9 @@ class PyConcreteOpInterface { /// Returns the opview of the operation instance from which this object was /// constructed. Throws a type error if this object was constructed form a /// subclass of OpView. - py::object getOpView() { + nb::object getOpView() { if (operation == nullptr) { - throw py::type_error("Cannot get an opview from a static interface"); + throw nb::type_error("Cannot get an opview from a static interface"); } return operation->createOpView(); @@ -242,7 +241,7 @@ class PyConcreteOpInterface { private: PyOperation *operation = nullptr; std::string opName; - py::object obj; + nb::object obj; }; /// Python wrapper for InferTypeOpInterface. 
This interface has only static @@ -276,7 +275,7 @@ class PyInferTypeOpInterface /// Given the arguments required to build an operation, attempts to infer its /// return types. Throws value_error on failure. std::vector - inferReturnTypes(std::optional operandList, + inferReturnTypes(std::optional operandList, std::optional attributes, void *properties, std::optional> regions, DefaultingPyMlirContext context, @@ -299,7 +298,7 @@ class PyInferTypeOpInterface mlirRegions.data(), &appendResultsCallback, &data); if (mlirLogicalResultIsFailure(result)) { - throw py::value_error("Failed to infer result types"); + throw nb::value_error("Failed to infer result types"); } return inferredTypes; @@ -307,11 +306,12 @@ class PyInferTypeOpInterface static void bindDerived(ClassTy &cls) { cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, - py::arg("operands") = py::none(), - py::arg("attributes") = py::none(), - py::arg("properties") = py::none(), py::arg("regions") = py::none(), - py::arg("context") = py::none(), py::arg("loc") = py::none(), - inferReturnTypesDoc); + nb::arg("operands").none() = nb::none(), + nb::arg("attributes").none() = nb::none(), + nb::arg("properties").none() = nb::none(), + nb::arg("regions").none() = nb::none(), + nb::arg("context").none() = nb::none(), + nb::arg("loc").none() = nb::none(), inferReturnTypesDoc); } }; @@ -319,9 +319,9 @@ class PyInferTypeOpInterface class PyShapedTypeComponents { public: PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {} - PyShapedTypeComponents(py::list shape, MlirType elementType) + PyShapedTypeComponents(nb::list shape, MlirType elementType) : shape(std::move(shape)), elementType(elementType), ranked(true) {} - PyShapedTypeComponents(py::list shape, MlirType elementType, + PyShapedTypeComponents(nb::list shape, MlirType elementType, MlirAttribute attribute) : shape(std::move(shape)), elementType(elementType), attribute(attribute), ranked(true) {} @@ -330,10 +330,9 @@ class 
PyShapedTypeComponents { : shape(other.shape), elementType(other.elementType), attribute(other.attribute), ranked(other.ranked) {} - static void bind(py::module &m) { - py::class_(m, "ShapedTypeComponents", - py::module_local()) - .def_property_readonly( + static void bind(nb::module_ &m) { + nb::class_(m, "ShapedTypeComponents") + .def_prop_ro( "element_type", [](PyShapedTypeComponents &self) { return self.elementType; }, "Returns the element type of the shaped type components.") @@ -342,57 +341,57 @@ class PyShapedTypeComponents { [](PyType &elementType) { return PyShapedTypeComponents(elementType); }, - py::arg("element_type"), + nb::arg("element_type"), "Create an shaped type components object with only the element " "type.") .def_static( "get", - [](py::list shape, PyType &elementType) { + [](nb::list shape, PyType &elementType) { return PyShapedTypeComponents(std::move(shape), elementType); }, - py::arg("shape"), py::arg("element_type"), + nb::arg("shape"), nb::arg("element_type"), "Create a ranked shaped type components object.") .def_static( "get", - [](py::list shape, PyType &elementType, PyAttribute &attribute) { + [](nb::list shape, PyType &elementType, PyAttribute &attribute) { return PyShapedTypeComponents(std::move(shape), elementType, attribute); }, - py::arg("shape"), py::arg("element_type"), py::arg("attribute"), + nb::arg("shape"), nb::arg("element_type"), nb::arg("attribute"), "Create a ranked shaped type components object with attribute.") - .def_property_readonly( + .def_prop_ro( "has_rank", [](PyShapedTypeComponents &self) -> bool { return self.ranked; }, "Returns whether the given shaped type component is ranked.") - .def_property_readonly( + .def_prop_ro( "rank", - [](PyShapedTypeComponents &self) -> py::object { + [](PyShapedTypeComponents &self) -> nb::object { if (!self.ranked) { - return py::none(); + return nb::none(); } - return py::int_(self.shape.size()); + return nb::int_(self.shape.size()); }, "Returns the rank of the given ranked 
shaped type components. If " "the shaped type components does not have a rank, None is " "returned.") - .def_property_readonly( + .def_prop_ro( "shape", - [](PyShapedTypeComponents &self) -> py::object { + [](PyShapedTypeComponents &self) -> nb::object { if (!self.ranked) { - return py::none(); + return nb::none(); } - return py::list(self.shape); + return nb::list(self.shape); }, "Returns the shape of the ranked shaped type components as a list " "of integers. Returns none if the shaped type component does not " "have a rank."); } - pybind11::object getCapsule(); - static PyShapedTypeComponents createFromCapsule(pybind11::object capsule); + nb::object getCapsule(); + static PyShapedTypeComponents createFromCapsule(nb::object capsule); private: - py::list shape; + nb::list shape; MlirType elementType; MlirAttribute attribute; bool ranked{false}; @@ -424,7 +423,7 @@ class PyInferShapedTypeOpInterface if (!hasRank) { data->inferredShapedTypeComponents.emplace_back(elementType); } else { - py::list shapeList; + nb::list shapeList; for (intptr_t i = 0; i < rank; ++i) { shapeList.append(shape[i]); } @@ -436,7 +435,7 @@ class PyInferShapedTypeOpInterface /// Given the arguments required to build an operation, attempts to infer the /// shaped type components. Throws value_error on failure. 
std::vector inferReturnTypeComponents( - std::optional operandList, + std::optional operandList, std::optional attributes, void *properties, std::optional> regions, DefaultingPyMlirContext context, DefaultingPyLocation location) { @@ -458,7 +457,7 @@ class PyInferShapedTypeOpInterface mlirRegions.data(), &appendResultsCallback, &data); if (mlirLogicalResultIsFailure(result)) { - throw py::value_error("Failed to infer result shape type components"); + throw nb::value_error("Failed to infer result shape type components"); } return inferredShapedTypeComponents; @@ -467,14 +466,16 @@ class PyInferShapedTypeOpInterface static void bindDerived(ClassTy &cls) { cls.def("inferReturnTypeComponents", &PyInferShapedTypeOpInterface::inferReturnTypeComponents, - py::arg("operands") = py::none(), - py::arg("attributes") = py::none(), py::arg("regions") = py::none(), - py::arg("properties") = py::none(), py::arg("context") = py::none(), - py::arg("loc") = py::none(), inferReturnTypeComponentsDoc); + nb::arg("operands").none() = nb::none(), + nb::arg("attributes").none() = nb::none(), + nb::arg("regions").none() = nb::none(), + nb::arg("properties").none() = nb::none(), + nb::arg("context").none() = nb::none(), + nb::arg("loc").none() = nb::none(), inferReturnTypeComponentsDoc); } }; -void populateIRInterfaces(py::module &m) { +void populateIRInterfaces(nb::module_ &m) { PyInferTypeOpInterface::bind(m); PyShapedTypeComponents::bind(m); PyInferShapedTypeOpInterface::bind(m); diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index 6727860c0..416a14218 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -7,16 +7,19 @@ //===----------------------------------------------------------------------===// #include "IRModule.h" -#include "Globals.h" -#include "PybindUtils.h" -#include "mlir-c/Bindings/Python/Interop.h" -#include "mlir-c/Support.h" +#include +#include #include #include -namespace py = 
pybind11; +#include "Globals.h" +#include "NanobindUtils.h" +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/Support.h" + +namespace nb = nanobind; using namespace mlir; using namespace mlir::python; @@ -41,14 +44,14 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { return true; // Since re-entrancy is possible, make a copy of the search prefixes. std::vector localSearchPrefixes = dialectSearchPrefixes; - py::object loaded = py::none(); + nb::object loaded = nb::none(); for (std::string moduleName : localSearchPrefixes) { moduleName.push_back('.'); moduleName.append(dialectNamespace.data(), dialectNamespace.size()); try { - loaded = py::module::import(moduleName.c_str()); - } catch (py::error_already_set &e) { + loaded = nb::module_::import_(moduleName.c_str()); + } catch (nb::python_error &e) { if (e.matches(PyExc_ModuleNotFoundError)) { continue; } @@ -66,41 +69,39 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { } void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, - py::function pyFunc, bool replace) { - py::object &found = attributeBuilderMap[attributeKind]; + nb::callable pyFunc, bool replace) { + nb::object &found = attributeBuilderMap[attributeKind]; if (found && !replace) { throw std::runtime_error((llvm::Twine("Attribute builder for '") + attributeKind + "' is already registered with func: " + - py::str(found).operator std::string()) + nb::cast(nb::str(found))) .str()); } found = std::move(pyFunc); } void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, - pybind11::function typeCaster, - bool replace) { - pybind11::object &found = typeCasterMap[mlirTypeID]; + nb::callable typeCaster, bool replace) { + nb::object &found = typeCasterMap[mlirTypeID]; if (found && !replace) throw std::runtime_error("Type caster is already registered with caster: " + - py::str(found).operator std::string()); + nb::cast(nb::str(found))); found = std::move(typeCaster); } void 
PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, - pybind11::function valueCaster, - bool replace) { - pybind11::object &found = valueCasterMap[mlirTypeID]; + nb::callable valueCaster, bool replace) { + nb::object &found = valueCasterMap[mlirTypeID]; if (found && !replace) throw std::runtime_error("Value caster is already registered: " + - py::repr(found).cast()); + nb::cast(nb::repr(found))); found = std::move(valueCaster); } void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, - py::object pyClass) { - py::object &found = dialectClassMap[dialectNamespace]; + nb::object pyClass) { + nb::object &found = dialectClassMap[dialectNamespace]; if (found) { throw std::runtime_error((llvm::Twine("Dialect namespace '") + dialectNamespace + "' is already registered.") @@ -110,8 +111,8 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, } void PyGlobals::registerOperationImpl(const std::string &operationName, - py::object pyClass, bool replace) { - py::object &found = operationClassMap[operationName]; + nb::object pyClass, bool replace) { + nb::object &found = operationClassMap[operationName]; if (found && !replace) { throw std::runtime_error((llvm::Twine("Operation '") + operationName + "' is already registered.") @@ -120,7 +121,7 @@ void PyGlobals::registerOperationImpl(const std::string &operationName, found = std::move(pyClass); } -std::optional +std::optional PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { const auto foundIt = attributeBuilderMap.find(attributeKind); if (foundIt != attributeBuilderMap.end()) { @@ -130,7 +131,7 @@ PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { return std::nullopt; } -std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, +std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { // Try to load dialect module. 
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); @@ -142,7 +143,7 @@ std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, return std::nullopt; } -std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, +std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { // Try to load dialect module. (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); @@ -154,7 +155,7 @@ std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, return std::nullopt; } -std::optional +std::optional PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { // Make sure dialect module is loaded. if (!loadDialectModule(dialectNamespace)) @@ -168,7 +169,7 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { return std::nullopt; } -std::optional +std::optional PyGlobals::lookupOperationClass(llvm::StringRef operationName) { // Make sure dialect module is loaded. auto split = operationName.split('.'); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 172898cfd..a242ff26b 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -10,20 +10,22 @@ #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H #define MLIR_BINDINGS_PYTHON_IRMODULES_H +#include +#include + #include #include #include #include "Globals.h" -#include "PybindUtils.h" - +#include "NanobindUtils.h" #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" #include "mlir-c/Transforms.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "llvm/ADT/DenseMap.h" namespace mlir { @@ -49,7 +51,7 @@ class PyValue; template class PyObjectRef { public: - PyObjectRef(T *referrent, pybind11::object object) + PyObjectRef(T *referrent, nanobind::object object) : referrent(referrent), object(std::move(object)) { 
assert(this->referrent && "cannot construct PyObjectRef with null referrent"); @@ -67,13 +69,13 @@ class PyObjectRef { int getRefCount() { if (!object) return 0; - return object.ref_count(); + return Py_REFCNT(object.ptr()); } /// Releases the object held by this instance, returning it. /// This is the proper thing to return from a function that wants to return /// the reference. Note that this does not work from initializers. - pybind11::object releaseObject() { + nanobind::object releaseObject() { assert(referrent && object); referrent = nullptr; auto stolen = std::move(object); @@ -85,7 +87,7 @@ class PyObjectRef { assert(referrent && object); return referrent; } - pybind11::object getObject() { + nanobind::object getObject() { assert(referrent && object); return object; } @@ -93,7 +95,7 @@ class PyObjectRef { private: T *referrent; - pybind11::object object; + nanobind::object object; }; /// Tracks an entry in the thread context stack. New entries are pushed onto @@ -112,9 +114,9 @@ class PyThreadContextEntry { Location, }; - PyThreadContextEntry(FrameKind frameKind, pybind11::object context, - pybind11::object insertionPoint, - pybind11::object location) + PyThreadContextEntry(FrameKind frameKind, nanobind::object context, + nanobind::object insertionPoint, + nanobind::object location) : context(std::move(context)), insertionPoint(std::move(insertionPoint)), location(std::move(location)), frameKind(frameKind) {} @@ -134,26 +136,26 @@ class PyThreadContextEntry { /// Stack management. 
static PyThreadContextEntry *getTopOfStack(); - static pybind11::object pushContext(PyMlirContext &context); + static nanobind::object pushContext(nanobind::object context); static void popContext(PyMlirContext &context); - static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint); + static nanobind::object pushInsertionPoint(nanobind::object insertionPoint); static void popInsertionPoint(PyInsertionPoint &insertionPoint); - static pybind11::object pushLocation(PyLocation &location); + static nanobind::object pushLocation(nanobind::object location); static void popLocation(PyLocation &location); /// Gets the thread local stack. static std::vector &getStack(); private: - static void push(FrameKind frameKind, pybind11::object context, - pybind11::object insertionPoint, pybind11::object location); + static void push(FrameKind frameKind, nanobind::object context, + nanobind::object insertionPoint, nanobind::object location); /// An object reference to the PyContext. - pybind11::object context; + nanobind::object context; /// An object reference to the current insertion point. - pybind11::object insertionPoint; + nanobind::object insertionPoint; /// An object reference to the current location. - pybind11::object location; + nanobind::object location; // The kind of push that was performed. FrameKind frameKind; }; @@ -163,14 +165,15 @@ using PyMlirContextRef = PyObjectRef; class PyMlirContext { public: PyMlirContext() = delete; + PyMlirContext(MlirContext context); PyMlirContext(const PyMlirContext &) = delete; PyMlirContext(PyMlirContext &&) = delete; - /// For the case of a python __init__ (py::init) method, pybind11 is quite - /// strict about needing to return a pointer that is not yet associated to - /// an py::object. Since the forContext() method acts like a pool, possibly - /// returning a recycled context, it does not satisfy this need. 
The usual - /// way in python to accomplish such a thing is to override __new__, but + /// For the case of a python __init__ (nanobind::init) method, pybind11 is + /// quite strict about needing to return a pointer that is not yet associated + /// to an nanobind::object. Since the forContext() method acts like a pool, + /// possibly returning a recycled context, it does not satisfy this need. The + /// usual way in python to accomplish such a thing is to override __new__, but /// that is also not supported by pybind11. Instead, we use this entry /// point which always constructs a fresh context (which cannot alias an /// existing one because it is fresh). @@ -187,17 +190,17 @@ class PyMlirContext { /// Gets a strong reference to this context, which will ensure it is kept /// alive for the life of the reference. PyMlirContextRef getRef() { - return PyMlirContextRef(this, pybind11::cast(this)); + return PyMlirContextRef(this, nanobind::cast(this)); } /// Gets a capsule wrapping the void* within the MlirContext. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. /// Note that PyMlirContext instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirContext /// is taken by calling this function. - static pybind11::object createFromCapsule(pybind11::object capsule); + static nanobind::object createFromCapsule(nanobind::object capsule); /// Gets the count of live context objects. Used for testing. static size_t getLiveCount(); @@ -237,14 +240,14 @@ class PyMlirContext { size_t getLiveModuleCount(); /// Enter and exit the context manager. 
- pybind11::object contextEnter(); - void contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb); + static nanobind::object contextEnter(nanobind::object context); + void contextExit(const nanobind::object &excType, + const nanobind::object &excVal, + const nanobind::object &excTb); /// Attaches a Python callback as a diagnostic handler, returning a /// registration object (internally a PyDiagnosticHandler). - pybind11::object attachDiagnosticHandler(pybind11::object callback); + nanobind::object attachDiagnosticHandler(nanobind::object callback); /// Controls whether error diagnostics should be propagated to diagnostic /// handlers, instead of being captured by `ErrorCapture`. @@ -252,8 +255,6 @@ class PyMlirContext { struct ErrorCapture; private: - PyMlirContext(MlirContext context); - // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, // preserving the relationship that an MlirContext maps to a single // PyMlirContext wrapper. This could be replaced in the future with an @@ -268,7 +269,7 @@ class PyMlirContext { // from this map, and while it still exists as an instance, any // attempt to access it will raise an error. using LiveModuleMap = - llvm::DenseMap>; + llvm::DenseMap>; LiveModuleMap liveModules; // Interns all live operations associated with this context. Operations @@ -276,7 +277,7 @@ class PyMlirContext { // removed from this map, and while it still exists as an instance, any // attempt to access it will raise an error. using LiveOperationMap = - llvm::DenseMap>; + llvm::DenseMap>; LiveOperationMap liveOperations; bool emitErrorDiagnostics = false; @@ -324,19 +325,19 @@ class PyLocation : public BaseContextObject { MlirLocation get() const { return loc; } /// Enter and exit the context manager. 
- pybind11::object contextEnter(); - void contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb); + static nanobind::object contextEnter(nanobind::object location); + void contextExit(const nanobind::object &excType, + const nanobind::object &excVal, + const nanobind::object &excTb); /// Gets a capsule wrapping the void* within the MlirLocation. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyLocation from the MlirLocation wrapped by a capsule. /// Note that PyLocation instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirLocation /// is taken by calling this function. - static PyLocation createFromCapsule(pybind11::object capsule); + static PyLocation createFromCapsule(nanobind::object capsule); private: MlirLocation loc; @@ -353,8 +354,8 @@ class PyDiagnostic { bool isValid() { return valid; } MlirDiagnosticSeverity getSeverity(); PyLocation getLocation(); - pybind11::str getMessage(); - pybind11::tuple getNotes(); + nanobind::str getMessage(); + nanobind::tuple getNotes(); /// Materialized diagnostic information. This is safe to access outside the /// diagnostic callback. @@ -373,7 +374,7 @@ class PyDiagnostic { /// If notes have been materialized from the diagnostic, then this will /// be populated with the corresponding objects (all castable to /// PyDiagnostic). - std::optional materializedNotes; + std::optional materializedNotes; bool valid = true; }; @@ -398,7 +399,7 @@ class PyDiagnostic { /// is no way to attach an existing handler object). class PyDiagnosticHandler { public: - PyDiagnosticHandler(MlirContext context, pybind11::object callback); + PyDiagnosticHandler(MlirContext context, nanobind::object callback); ~PyDiagnosticHandler(); bool isAttached() { return registeredID.has_value(); } @@ -407,16 +408,16 @@ class PyDiagnosticHandler { /// Detaches the handler. Does nothing if not attached. 
void detach(); - pybind11::object contextEnter() { return pybind11::cast(this); } - void contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb) { + nanobind::object contextEnter() { return nanobind::cast(this); } + void contextExit(const nanobind::object &excType, + const nanobind::object &excVal, + const nanobind::object &excTb) { detach(); } private: MlirContext context; - pybind11::object callback; + nanobind::object callback; std::optional registeredID; bool hadError = false; friend class PyMlirContext; @@ -477,12 +478,12 @@ class PyDialects : public BaseContextObject { /// objects of this type will be returned directly. class PyDialect { public: - PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {} + PyDialect(nanobind::object descriptor) : descriptor(std::move(descriptor)) {} - pybind11::object getDescriptor() { return descriptor; } + nanobind::object getDescriptor() { return descriptor; } private: - pybind11::object descriptor; + nanobind::object descriptor; }; /// Wrapper around an MlirDialectRegistry. @@ -505,8 +506,8 @@ class PyDialectRegistry { operator MlirDialectRegistry() const { return registry; } MlirDialectRegistry get() const { return registry; } - pybind11::object getCapsule(); - static PyDialectRegistry createFromCapsule(pybind11::object capsule); + nanobind::object getCapsule(); + static PyDialectRegistry createFromCapsule(nanobind::object capsule); private: MlirDialectRegistry registry; @@ -542,26 +543,25 @@ class PyModule : public BaseContextObject { /// Gets a strong reference to this module. PyModuleRef getRef() { - return PyModuleRef(this, - pybind11::reinterpret_borrow(handle)); + return PyModuleRef(this, nanobind::borrow(handle)); } /// Gets a capsule wrapping the void* within the MlirModule. 
/// Note that the module does not (yet) provide a corresponding factory for /// constructing from a capsule as that would require uniquing PyModule /// instances, which is not currently done. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyModule from the MlirModule wrapped by a capsule. /// Note that PyModule instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirModule /// is taken by calling this function. - static pybind11::object createFromCapsule(pybind11::object capsule); + static nanobind::object createFromCapsule(nanobind::object capsule); private: PyModule(PyMlirContextRef contextRef, MlirModule module); MlirModule module; - pybind11::handle handle; + nanobind::handle handle; }; class PyAsmState; @@ -574,18 +574,18 @@ class PyOperationBase { /// Implements the bound 'print' method and helps with others. void print(std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified, py::object fileObject, bool binary, + bool assumeVerified, nanobind::object fileObject, bool binary, bool skipRegions); - void print(PyAsmState &state, py::object fileObject, bool binary); + void print(PyAsmState &state, nanobind::object fileObject, bool binary); - pybind11::object getAsm(bool binary, + nanobind::object getAsm(bool binary, std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified, bool skipRegions); // Implement the bound 'writeBytecode' method. - void writeBytecode(const pybind11::object &fileObject, + void writeBytecode(const nanobind::object &fileObject, std::optional bytecodeVersion); // Implement the walk method. @@ -621,13 +621,13 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// it with a parentKeepAlive. 
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, - pybind11::object parentKeepAlive = pybind11::object()); + nanobind::object parentKeepAlive = nanobind::object()); /// Creates a detached operation. The operation must not be associated with /// any existing live operation. static PyOperationRef createDetached(PyMlirContextRef contextRef, MlirOperation operation, - pybind11::object parentKeepAlive = pybind11::object()); + nanobind::object parentKeepAlive = nanobind::object()); /// Parses a source string (either text assembly or bytecode), creating a /// detached operation. @@ -640,7 +640,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject { void detachFromParent() { mlirOperationRemoveFromParent(getOperation()); setDetached(); - parentKeepAlive = pybind11::object(); + parentKeepAlive = nanobind::object(); } /// Gets the backing operation. @@ -651,12 +651,11 @@ class PyOperation : public PyOperationBase, public BaseContextObject { } PyOperationRef getRef() { - return PyOperationRef( - this, pybind11::reinterpret_borrow(handle)); + return PyOperationRef(this, nanobind::borrow(handle)); } bool isAttached() { return attached; } - void setAttached(const pybind11::object &parent = pybind11::object()) { + void setAttached(const nanobind::object &parent = nanobind::object()) { assert(!attached && "operation already attached"); attached = true; } @@ -675,24 +674,24 @@ class PyOperation : public PyOperationBase, public BaseContextObject { std::optional getParentOperation(); /// Gets a capsule wrapping the void* within the MlirOperation. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyOperation from the MlirOperation wrapped by a capsule. /// Ownership of the underlying MlirOperation is taken by calling this /// function. 
- static pybind11::object createFromCapsule(pybind11::object capsule); + static nanobind::object createFromCapsule(nanobind::object capsule); /// Creates an operation. See corresponding python docstring. - static pybind11::object + static nanobind::object create(const std::string &name, std::optional> results, std::optional> operands, - std::optional attributes, + std::optional attributes, std::optional> successors, int regions, - DefaultingPyLocation location, const pybind11::object &ip, + DefaultingPyLocation location, const nanobind::object &ip, bool inferType); /// Creates an OpView suitable for this operation. - pybind11::object createOpView(); + nanobind::object createOpView(); /// Erases the underlying MlirOperation, removes its pointer from the /// parent context's live operations map, and sets the valid bit false. @@ -702,23 +701,23 @@ class PyOperation : public PyOperationBase, public BaseContextObject { void setInvalid() { valid = false; } /// Clones this operation. - pybind11::object clone(const pybind11::object &ip); + nanobind::object clone(const nanobind::object &ip); private: PyOperation(PyMlirContextRef contextRef, MlirOperation operation); static PyOperationRef createInstance(PyMlirContextRef contextRef, MlirOperation operation, - pybind11::object parentKeepAlive); + nanobind::object parentKeepAlive); MlirOperation operation; - pybind11::handle handle; + nanobind::handle handle; // Keeps the parent alive, regardless of whether it is an Operation or // Module. // TODO: As implemented, this facility is only sufficient for modeling the // trivial module parent back-reference. Generalize this to also account for // transitions from detached to attached and address TODOs in the // ir_operation.py regarding testing corresponding lifetime guarantees. 
- pybind11::object parentKeepAlive; + nanobind::object parentKeepAlive; bool attached = true; bool valid = true; @@ -733,17 +732,17 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// python types. class PyOpView : public PyOperationBase { public: - PyOpView(const pybind11::object &operationObject); + PyOpView(const nanobind::object &operationObject); PyOperation &getOperation() override { return operation; } - pybind11::object getOperationObject() { return operationObject; } + nanobind::object getOperationObject() { return operationObject; } - static pybind11::object buildGeneric( - const pybind11::object &cls, std::optional resultTypeList, - pybind11::list operandList, std::optional attributes, + static nanobind::object buildGeneric( + const nanobind::object &cls, std::optional resultTypeList, + nanobind::list operandList, std::optional attributes, std::optional> successors, std::optional regions, DefaultingPyLocation location, - const pybind11::object &maybeIp); + const nanobind::object &maybeIp); /// Construct an instance of a class deriving from OpView, bypassing its /// `__init__` method. The derived class will typically define a constructor @@ -752,12 +751,12 @@ class PyOpView : public PyOperationBase { /// /// The caller is responsible for verifying that `operation` is a valid /// operation to construct `cls` with. - static pybind11::object constructDerived(const pybind11::object &cls, - const PyOperation &operation); + static nanobind::object constructDerived(const nanobind::object &cls, + const nanobind::object &operation); private: PyOperation &operation; // For efficient, cast-free access from C++ - pybind11::object operationObject; // Holds the reference. + nanobind::object operationObject; // Holds the reference. }; /// Wrapper around an MlirRegion. @@ -830,7 +829,7 @@ class PyBlock { void checkValid() { return parentOperation->checkValid(); } /// Gets a capsule wrapping the void* within the MlirBlock. 
- pybind11::object getCapsule(); + nanobind::object getCapsule(); private: PyOperationRef parentOperation; @@ -858,10 +857,10 @@ class PyInsertionPoint { void insert(PyOperationBase &operationBase); /// Enter and exit the context manager. - pybind11::object contextEnter(); - void contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb); + static nanobind::object contextEnter(nanobind::object insertionPoint); + void contextExit(const nanobind::object &excType, + const nanobind::object &excVal, + const nanobind::object &excTb); PyBlock &getBlock() { return block; } std::optional &getRefOperation() { return refOperation; } @@ -886,13 +885,13 @@ class PyType : public BaseContextObject { MlirType get() const { return type; } /// Gets a capsule wrapping the void* within the MlirType. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyType from the MlirType wrapped by a capsule. /// Note that PyType instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirType /// is taken by calling this function. - static PyType createFromCapsule(pybind11::object capsule); + static PyType createFromCapsule(nanobind::object capsule); private: MlirType type; @@ -912,10 +911,10 @@ class PyTypeID { MlirTypeID get() { return typeID; } /// Gets a capsule wrapping the void* within the MlirTypeID. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyTypeID from the MlirTypeID wrapped by a capsule. 
- static PyTypeID createFromCapsule(pybind11::object capsule); + static PyTypeID createFromCapsule(nanobind::object capsule); private: MlirTypeID typeID; @@ -932,7 +931,7 @@ class PyConcreteType : public BaseTy { // Derived classes must define statics for: // IsAFunctionTy isaFunction // const char *pyClassName - using ClassTy = pybind11::class_; + using ClassTy = nanobind::class_; using IsAFunctionTy = bool (*)(MlirType); using GetTypeIDFunctionTy = MlirTypeID (*)(); static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; @@ -945,34 +944,38 @@ class PyConcreteType : public BaseTy { static MlirType castFrom(PyType &orig) { if (!DerivedTy::isaFunction(orig)) { - auto origRepr = pybind11::repr(pybind11::cast(orig)).cast(); - throw py::value_error((llvm::Twine("Cannot cast type to ") + - DerivedTy::pyClassName + " (from " + origRepr + - ")") - .str()); + auto origRepr = + nanobind::cast(nanobind::repr(nanobind::cast(orig))); + throw nanobind::value_error((llvm::Twine("Cannot cast type to ") + + DerivedTy::pyClassName + " (from " + + origRepr + ")") + .str() + .c_str()); } return orig; } - static void bind(pybind11::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local()); - cls.def(pybind11::init(), pybind11::keep_alive<0, 1>(), - pybind11::arg("cast_from_type")); + static void bind(nanobind::module_ &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName); + cls.def(nanobind::init(), nanobind::keep_alive<0, 1>(), + nanobind::arg("cast_from_type")); cls.def_static( "isinstance", [](PyType &otherType) -> bool { return DerivedTy::isaFunction(otherType); }, - pybind11::arg("other")); - cls.def_property_readonly_static( - "static_typeid", [](py::object & /*class*/) -> MlirTypeID { + nanobind::arg("other")); + cls.def_prop_ro_static( + "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID { if (DerivedTy::getTypeIdFunction) return DerivedTy::getTypeIdFunction(); - throw py::attribute_error( - (DerivedTy::pyClassName + 
llvm::Twine(" has no typeid.")).str()); + throw nanobind::attribute_error( + (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")) + .str() + .c_str()); }); - cls.def_property_readonly("typeid", [](PyType &self) { - return py::cast(self).attr("typeid").cast(); + cls.def_prop_ro("typeid", [](PyType &self) { + return nanobind::cast(nanobind::cast(self).attr("typeid")); }); cls.def("__repr__", [](DerivedTy &self) { PyPrintAccumulator printAccum; @@ -986,8 +989,8 @@ class PyConcreteType : public BaseTy { if (DerivedTy::getTypeIdFunction) { PyGlobals::get().registerTypeCaster( DerivedTy::getTypeIdFunction(), - pybind11::cpp_function( - [](PyType pyType) -> DerivedTy { return pyType; })); + nanobind::cast(nanobind::cpp_function( + [](PyType pyType) -> DerivedTy { return pyType; }))); } DerivedTy::bindDerived(cls); @@ -1008,13 +1011,13 @@ class PyAttribute : public BaseContextObject { MlirAttribute get() const { return attr; } /// Gets a capsule wrapping the void* within the MlirAttribute. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule. /// Note that PyAttribute instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirAttribute /// is taken by calling this function. 
- static PyAttribute createFromCapsule(pybind11::object capsule); + static PyAttribute createFromCapsule(nanobind::object capsule); private: MlirAttribute attr; @@ -1054,7 +1057,7 @@ class PyConcreteAttribute : public BaseTy { // Derived classes must define statics for: // IsAFunctionTy isaFunction // const char *pyClassName - using ClassTy = pybind11::class_; + using ClassTy = nanobind::class_; using IsAFunctionTy = bool (*)(MlirAttribute); using GetTypeIDFunctionTy = MlirTypeID (*)(); static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; @@ -1067,37 +1070,45 @@ class PyConcreteAttribute : public BaseTy { static MlirAttribute castFrom(PyAttribute &orig) { if (!DerivedTy::isaFunction(orig)) { - auto origRepr = pybind11::repr(pybind11::cast(orig)).cast(); - throw py::value_error((llvm::Twine("Cannot cast attribute to ") + - DerivedTy::pyClassName + " (from " + origRepr + - ")") - .str()); + auto origRepr = + nanobind::cast(nanobind::repr(nanobind::cast(orig))); + throw nanobind::value_error((llvm::Twine("Cannot cast attribute to ") + + DerivedTy::pyClassName + " (from " + + origRepr + ")") + .str() + .c_str()); } return orig; } - static void bind(pybind11::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(), - pybind11::module_local()); - cls.def(pybind11::init(), pybind11::keep_alive<0, 1>(), - pybind11::arg("cast_from_attr")); + static void bind(nanobind::module_ &m, PyType_Slot *slots = nullptr) { + ClassTy cls; + if (slots) { + cls = ClassTy(m, DerivedTy::pyClassName, nanobind::type_slots(slots)); + } else { + cls = ClassTy(m, DerivedTy::pyClassName); + } + cls.def(nanobind::init(), nanobind::keep_alive<0, 1>(), + nanobind::arg("cast_from_attr")); cls.def_static( "isinstance", [](PyAttribute &otherAttr) -> bool { return DerivedTy::isaFunction(otherAttr); }, - pybind11::arg("other")); - cls.def_property_readonly( + nanobind::arg("other")); + cls.def_prop_ro( "type", [](PyAttribute &attr) { return 
mlirAttributeGetType(attr); }); - cls.def_property_readonly_static( - "static_typeid", [](py::object & /*class*/) -> MlirTypeID { + cls.def_prop_ro_static( + "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID { if (DerivedTy::getTypeIdFunction) return DerivedTy::getTypeIdFunction(); - throw py::attribute_error( - (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str()); + throw nanobind::attribute_error( + (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")) + .str() + .c_str()); }); - cls.def_property_readonly("typeid", [](PyAttribute &self) { - return py::cast(self).attr("typeid").cast(); + cls.def_prop_ro("typeid", [](PyAttribute &self) { + return nanobind::cast(nanobind::cast(self).attr("typeid")); }); cls.def("__repr__", [](DerivedTy &self) { PyPrintAccumulator printAccum; @@ -1112,9 +1123,10 @@ class PyConcreteAttribute : public BaseTy { if (DerivedTy::getTypeIdFunction) { PyGlobals::get().registerTypeCaster( DerivedTy::getTypeIdFunction(), - pybind11::cpp_function([](PyAttribute pyAttribute) -> DerivedTy { - return pyAttribute; - })); + nanobind::cast( + nanobind::cpp_function([](PyAttribute pyAttribute) -> DerivedTy { + return pyAttribute; + }))); } DerivedTy::bindDerived(cls); @@ -1146,13 +1158,13 @@ class PyValue { void checkValid() { return parentOperation->checkValid(); } /// Gets a capsule wrapping the void* within the MlirValue. - pybind11::object getCapsule(); + nanobind::object getCapsule(); - pybind11::object maybeDownCast(); + nanobind::object maybeDownCast(); /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of /// the underlying MlirValue is still tied to the owning operation. 
- static PyValue createFromCapsule(pybind11::object capsule); + static PyValue createFromCapsule(nanobind::object capsule); private: PyOperationRef parentOperation; @@ -1169,13 +1181,13 @@ class PyAffineExpr : public BaseContextObject { MlirAffineExpr get() const { return affineExpr; } /// Gets a capsule wrapping the void* within the MlirAffineExpr. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule. /// Note that PyAffineExpr instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr /// is taken by calling this function. - static PyAffineExpr createFromCapsule(pybind11::object capsule); + static PyAffineExpr createFromCapsule(nanobind::object capsule); PyAffineExpr add(const PyAffineExpr &other) const; PyAffineExpr mul(const PyAffineExpr &other) const; @@ -1196,13 +1208,13 @@ class PyAffineMap : public BaseContextObject { MlirAffineMap get() const { return affineMap; } /// Gets a capsule wrapping the void* within the MlirAffineMap. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule. /// Note that PyAffineMap instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirAffineMap /// is taken by calling this function. - static PyAffineMap createFromCapsule(pybind11::object capsule); + static PyAffineMap createFromCapsule(nanobind::object capsule); private: MlirAffineMap affineMap; @@ -1217,12 +1229,12 @@ class PyIntegerSet : public BaseContextObject { MlirIntegerSet get() const { return integerSet; } /// Gets a capsule wrapping the void* within the MlirIntegerSet. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. 
/// Note that PyIntegerSet instances may be uniqued, so the returned object /// may be a pre-existing object. Integer sets are owned by the context. - static PyIntegerSet createFromCapsule(pybind11::object capsule); + static PyIntegerSet createFromCapsule(nanobind::object capsule); private: MlirIntegerSet integerSet; @@ -1239,7 +1251,7 @@ class PySymbolTable { /// Returns the symbol (opview) with the given name, throws if there is no /// such symbol in the table. - pybind11::object dunderGetItem(const std::string &name); + nanobind::object dunderGetItem(const std::string &name); /// Removes the given operation from the symbol table and erases it. void erase(PyOperationBase &symbol); @@ -1269,7 +1281,7 @@ class PySymbolTable { /// Walks all symbol tables under and including 'from'. static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, - pybind11::object callback); + nanobind::object callback); /// Casts the bindings class into the C API structure. operator MlirSymbolTable() { return symbolTable; } @@ -1289,16 +1301,16 @@ struct MLIRError { std::vector errorDiagnostics; }; -void populateIRAffine(pybind11::module &m); -void populateIRAttributes(pybind11::module &m); -void populateIRCore(pybind11::module &m); -void populateIRInterfaces(pybind11::module &m); -void populateIRTypes(pybind11::module &m); +void populateIRAffine(nanobind::module_ &m); +void populateIRAttributes(nanobind::module_ &m); +void populateIRCore(nanobind::module_ &m); +void populateIRInterfaces(nanobind::module_ &m); +void populateIRTypes(nanobind::module_ &m); } // namespace python } // namespace mlir -namespace pybind11 { +namespace nanobind { namespace detail { template <> @@ -1309,6 +1321,6 @@ struct type_caster : MlirDefaultingCaster {}; } // namespace detail -} // namespace pybind11 +} // namespace nanobind #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 6f192bc4b..5cfa51142 100644 
--- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -6,19 +6,26 @@ // //===----------------------------------------------------------------------===// +// clang-format off #include "IRModule.h" +#include "mlir/Bindings/Python/IRTypes.h" +// clang-format on -#include "PybindUtils.h" +#include +#include +#include +#include +#include -#include "mlir/Bindings/Python/IRTypes.h" +#include +#include "IRModule.h" +#include "NanobindUtils.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Support.h" -#include - -namespace py = pybind11; +namespace nb = nanobind; using namespace mlir; using namespace mlir::python; @@ -48,7 +55,7 @@ class PyIntegerType : public PyConcreteType { MlirType t = mlirIntegerTypeGet(context->get(), width); return PyIntegerType(context->getRef(), t); }, - py::arg("width"), py::arg("context") = py::none(), + nb::arg("width"), nb::arg("context").none() = nb::none(), "Create a signless integer type"); c.def_static( "get_signed", @@ -56,7 +63,7 @@ class PyIntegerType : public PyConcreteType { MlirType t = mlirIntegerTypeSignedGet(context->get(), width); return PyIntegerType(context->getRef(), t); }, - py::arg("width"), py::arg("context") = py::none(), + nb::arg("width"), nb::arg("context").none() = nb::none(), "Create a signed integer type"); c.def_static( "get_unsigned", @@ -64,25 +71,25 @@ class PyIntegerType : public PyConcreteType { MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); return PyIntegerType(context->getRef(), t); }, - py::arg("width"), py::arg("context") = py::none(), + nb::arg("width"), nb::arg("context").none() = nb::none(), "Create an unsigned integer type"); - c.def_property_readonly( + c.def_prop_ro( "width", [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, "Returns the width of the integer type"); - c.def_property_readonly( + c.def_prop_ro( "is_signless", [](PyIntegerType &self) -> bool { return 
mlirIntegerTypeIsSignless(self); }, "Returns whether this is a signless integer"); - c.def_property_readonly( + c.def_prop_ro( "is_signed", [](PyIntegerType &self) -> bool { return mlirIntegerTypeIsSigned(self); }, "Returns whether this is a signed integer"); - c.def_property_readonly( + c.def_prop_ro( "is_unsigned", [](PyIntegerType &self) -> bool { return mlirIntegerTypeIsUnsigned(self); @@ -107,7 +114,7 @@ class PyIndexType : public PyConcreteType { MlirType t = mlirIndexTypeGet(context->get()); return PyIndexType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a index type."); + nb::arg("context").none() = nb::none(), "Create a index type."); } }; @@ -118,7 +125,7 @@ class PyFloatType : public PyConcreteType { using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { - c.def_property_readonly( + c.def_prop_ro( "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); }, "Returns the width of the floating-point type"); } @@ -141,7 +148,7 @@ class PyFloat4E2M1FNType MlirType t = mlirFloat4E2M1FNTypeGet(context->get()); return PyFloat4E2M1FNType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float4_e2m1fn type."); + nb::arg("context").none() = nb::none(), "Create a float4_e2m1fn type."); } }; @@ -162,7 +169,7 @@ class PyFloat6E2M3FNType MlirType t = mlirFloat6E2M3FNTypeGet(context->get()); return PyFloat6E2M3FNType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float6_e2m3fn type."); + nb::arg("context").none() = nb::none(), "Create a float6_e2m3fn type."); } }; @@ -183,7 +190,7 @@ class PyFloat6E3M2FNType MlirType t = mlirFloat6E3M2FNTypeGet(context->get()); return PyFloat6E3M2FNType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float6_e3m2fn type."); + nb::arg("context").none() = nb::none(), "Create a float6_e3m2fn type."); } }; @@ -204,7 +211,7 @@ class PyFloat8E4M3FNType MlirType t = mlirFloat8E4M3FNTypeGet(context->get()); return 
PyFloat8E4M3FNType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e4m3fn type."); + nb::arg("context").none() = nb::none(), "Create a float8_e4m3fn type."); } }; @@ -224,7 +231,7 @@ class PyFloat8E5M2Type : public PyConcreteType { MlirType t = mlirFloat8E5M2TypeGet(context->get()); return PyFloat8E5M2Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e5m2 type."); + nb::arg("context").none() = nb::none(), "Create a float8_e5m2 type."); } }; @@ -244,7 +251,7 @@ class PyFloat8E4M3Type : public PyConcreteType { MlirType t = mlirFloat8E4M3TypeGet(context->get()); return PyFloat8E4M3Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e4m3 type."); + nb::arg("context").none() = nb::none(), "Create a float8_e4m3 type."); } }; @@ -265,7 +272,8 @@ class PyFloat8E4M3FNUZType MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); return PyFloat8E4M3FNUZType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e4m3fnuz type."); + nb::arg("context").none() = nb::none(), + "Create a float8_e4m3fnuz type."); } }; @@ -286,7 +294,8 @@ class PyFloat8E4M3B11FNUZType MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); return PyFloat8E4M3B11FNUZType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e4m3b11fnuz type."); + nb::arg("context").none() = nb::none(), + "Create a float8_e4m3b11fnuz type."); } }; @@ -307,7 +316,8 @@ class PyFloat8E5M2FNUZType MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); return PyFloat8E5M2FNUZType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e5m2fnuz type."); + nb::arg("context").none() = nb::none(), + "Create a float8_e5m2fnuz type."); } }; @@ -327,7 +337,7 @@ class PyFloat8E3M4Type : public PyConcreteType { MlirType t = mlirFloat8E3M4TypeGet(context->get()); return PyFloat8E3M4Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e3m4 
type."); + nb::arg("context").none() = nb::none(), "Create a float8_e3m4 type."); } }; @@ -348,7 +358,8 @@ class PyFloat8E8M0FNUType MlirType t = mlirFloat8E8M0FNUTypeGet(context->get()); return PyFloat8E8M0FNUType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e8m0fnu type."); + nb::arg("context").none() = nb::none(), + "Create a float8_e8m0fnu type."); } }; @@ -368,7 +379,7 @@ class PyBF16Type : public PyConcreteType { MlirType t = mlirBF16TypeGet(context->get()); return PyBF16Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a bf16 type."); + nb::arg("context").none() = nb::none(), "Create a bf16 type."); } }; @@ -388,7 +399,7 @@ class PyF16Type : public PyConcreteType { MlirType t = mlirF16TypeGet(context->get()); return PyF16Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a f16 type."); + nb::arg("context").none() = nb::none(), "Create a f16 type."); } }; @@ -408,7 +419,7 @@ class PyTF32Type : public PyConcreteType { MlirType t = mlirTF32TypeGet(context->get()); return PyTF32Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a tf32 type."); + nb::arg("context").none() = nb::none(), "Create a tf32 type."); } }; @@ -428,7 +439,7 @@ class PyF32Type : public PyConcreteType { MlirType t = mlirF32TypeGet(context->get()); return PyF32Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a f32 type."); + nb::arg("context").none() = nb::none(), "Create a f32 type."); } }; @@ -448,7 +459,7 @@ class PyF64Type : public PyConcreteType { MlirType t = mlirF64TypeGet(context->get()); return PyF64Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a f64 type."); + nb::arg("context").none() = nb::none(), "Create a f64 type."); } }; @@ -468,7 +479,7 @@ class PyNoneType : public PyConcreteType { MlirType t = mlirNoneTypeGet(context->get()); return PyNoneType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a none 
type."); + nb::arg("context").none() = nb::none(), "Create a none type."); } }; @@ -490,14 +501,15 @@ class PyComplexType : public PyConcreteType { MlirType t = mlirComplexTypeGet(elementType); return PyComplexType(elementType.getContext(), t); } - throw py::value_error( + throw nb::value_error( (Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + + nb::cast(nb::repr(nb::cast(elementType))) + "' and expected floating point or integer type.") - .str()); + .str() + .c_str()); }, "Create a complex type"); - c.def_property_readonly( + c.def_prop_ro( "element_type", [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); }, "Returns element type."); @@ -508,22 +520,22 @@ class PyComplexType : public PyConcreteType { // Shaped Type Interface - ShapedType void mlir::PyShapedType::bindDerived(ClassTy &c) { - c.def_property_readonly( + c.def_prop_ro( "element_type", [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); }, "Returns the element type of the shaped type."); - c.def_property_readonly( + c.def_prop_ro( "has_rank", [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, "Returns whether the given shaped type is ranked."); - c.def_property_readonly( + c.def_prop_ro( "rank", [](PyShapedType &self) { self.requireHasRank(); return mlirShapedTypeGetRank(self); }, "Returns the rank of the given ranked shaped type."); - c.def_property_readonly( + c.def_prop_ro( "has_static_shape", [](PyShapedType &self) -> bool { return mlirShapedTypeHasStaticShape(self); @@ -535,7 +547,7 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { self.requireHasRank(); return mlirShapedTypeIsDynamicDim(self, dim); }, - py::arg("dim"), + nb::arg("dim"), "Returns whether the dim-th dimension of the given shaped type is " "dynamic."); c.def( @@ -544,12 +556,12 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { self.requireHasRank(); return mlirShapedTypeGetDimSize(self, dim); }, - py::arg("dim"), + nb::arg("dim"), "Returns the dim-th 
dimension of the given ranked shaped type."); c.def_static( "is_dynamic_size", [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, - py::arg("dim_size"), + nb::arg("dim_size"), "Returns whether the given dimension size indicates a dynamic " "dimension."); c.def( @@ -558,10 +570,10 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { self.requireHasRank(); return mlirShapedTypeIsDynamicStrideOrOffset(val); }, - py::arg("dim_size"), + nb::arg("dim_size"), "Returns whether the given value is used as a placeholder for dynamic " "strides and offsets in shaped types."); - c.def_property_readonly( + c.def_prop_ro( "shape", [](PyShapedType &self) { self.requireHasRank(); @@ -587,7 +599,7 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { void mlir::PyShapedType::requireHasRank() { if (!mlirShapedTypeHasRank(*this)) { - throw py::value_error( + throw nb::value_error( "calling this method requires that the type has a rank."); } } @@ -607,15 +619,15 @@ class PyVectorType : public PyConcreteType { using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { - c.def_static("get", &PyVectorType::get, py::arg("shape"), - py::arg("element_type"), py::kw_only(), - py::arg("scalable") = py::none(), - py::arg("scalable_dims") = py::none(), - py::arg("loc") = py::none(), "Create a vector type") - .def_property_readonly( + c.def_static("get", &PyVectorType::get, nb::arg("shape"), + nb::arg("element_type"), nb::kw_only(), + nb::arg("scalable").none() = nb::none(), + nb::arg("scalable_dims").none() = nb::none(), + nb::arg("loc").none() = nb::none(), "Create a vector type") + .def_prop_ro( "scalable", [](MlirType self) { return mlirVectorTypeIsScalable(self); }) - .def_property_readonly("scalable_dims", [](MlirType self) { + .def_prop_ro("scalable_dims", [](MlirType self) { std::vector scalableDims; size_t rank = static_cast(mlirShapedTypeGetRank(self)); scalableDims.reserve(rank); @@ -627,11 +639,11 @@ class PyVectorType : public PyConcreteType { 
private: static PyVectorType get(std::vector shape, PyType &elementType, - std::optional scalable, + std::optional scalable, std::optional> scalableDims, DefaultingPyLocation loc) { if (scalable && scalableDims) { - throw py::value_error("'scalable' and 'scalable_dims' kwargs " + throw nb::value_error("'scalable' and 'scalable_dims' kwargs " "are mutually exclusive."); } @@ -639,10 +651,10 @@ class PyVectorType : public PyConcreteType { MlirType type; if (scalable) { if (scalable->size() != shape.size()) - throw py::value_error("Expected len(scalable) == len(shape)."); + throw nb::value_error("Expected len(scalable) == len(shape)."); SmallVector scalableDimFlags = llvm::to_vector(llvm::map_range( - *scalable, [](const py::handle &h) { return h.cast(); })); + *scalable, [](const nb::handle &h) { return nb::cast(h); })); type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), scalableDimFlags.data(), elementType); @@ -650,7 +662,7 @@ class PyVectorType : public PyConcreteType { SmallVector scalableDimFlags(shape.size(), false); for (int64_t dim : *scalableDims) { if (static_cast(dim) >= scalableDimFlags.size() || dim < 0) - throw py::value_error("Scalable dimension index out of bounds."); + throw nb::value_error("Scalable dimension index out of bounds."); scalableDimFlags[dim] = true; } type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), @@ -689,17 +701,17 @@ class PyRankedTensorType throw MLIRError("Invalid type", errors.take()); return PyRankedTensorType(elementType.getContext(), t); }, - py::arg("shape"), py::arg("element_type"), - py::arg("encoding") = py::none(), py::arg("loc") = py::none(), - "Create a ranked tensor type"); - c.def_property_readonly( - "encoding", - [](PyRankedTensorType &self) -> std::optional { - MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get()); - if (mlirAttributeIsNull(encoding)) - return std::nullopt; - return encoding; - }); + nb::arg("shape"), nb::arg("element_type"), + 
nb::arg("encoding").none() = nb::none(), + nb::arg("loc").none() = nb::none(), "Create a ranked tensor type"); + c.def_prop_ro("encoding", + [](PyRankedTensorType &self) -> std::optional { + MlirAttribute encoding = + mlirRankedTensorTypeGetEncoding(self.get()); + if (mlirAttributeIsNull(encoding)) + return std::nullopt; + return encoding; + }); } }; @@ -723,7 +735,7 @@ class PyUnrankedTensorType throw MLIRError("Invalid type", errors.take()); return PyUnrankedTensorType(elementType.getContext(), t); }, - py::arg("element_type"), py::arg("loc") = py::none(), + nb::arg("element_type"), nb::arg("loc").none() = nb::none(), "Create a unranked tensor type"); } }; @@ -754,10 +766,11 @@ class PyMemRefType : public PyConcreteType { throw MLIRError("Invalid type", errors.take()); return PyMemRefType(elementType.getContext(), t); }, - py::arg("shape"), py::arg("element_type"), - py::arg("layout") = py::none(), py::arg("memory_space") = py::none(), - py::arg("loc") = py::none(), "Create a memref type") - .def_property_readonly( + nb::arg("shape"), nb::arg("element_type"), + nb::arg("layout").none() = nb::none(), + nb::arg("memory_space").none() = nb::none(), + nb::arg("loc").none() = nb::none(), "Create a memref type") + .def_prop_ro( "layout", [](PyMemRefType &self) -> MlirAttribute { return mlirMemRefTypeGetLayout(self); @@ -775,14 +788,14 @@ class PyMemRefType : public PyConcreteType { return {strides, offset}; }, "The strides and offset of the MemRef type.") - .def_property_readonly( + .def_prop_ro( "affine_map", [](PyMemRefType &self) -> PyAffineMap { MlirAffineMap map = mlirMemRefTypeGetAffineMap(self); return PyAffineMap(self.getContext(), map); }, "The layout of the MemRef type as an affine map.") - .def_property_readonly( + .def_prop_ro( "memory_space", [](PyMemRefType &self) -> std::optional { MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); @@ -820,9 +833,9 @@ class PyUnrankedMemRefType throw MLIRError("Invalid type", errors.take()); return 
PyUnrankedMemRefType(elementType.getContext(), t); }, - py::arg("element_type"), py::arg("memory_space"), - py::arg("loc") = py::none(), "Create a unranked memref type") - .def_property_readonly( + nb::arg("element_type"), nb::arg("memory_space").none(), + nb::arg("loc").none() = nb::none(), "Create a unranked memref type") + .def_prop_ro( "memory_space", [](PyUnrankedMemRefType &self) -> std::optional { MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self); @@ -851,15 +864,15 @@ class PyTupleType : public PyConcreteType { elements.data()); return PyTupleType(context->getRef(), t); }, - py::arg("elements"), py::arg("context") = py::none(), + nb::arg("elements"), nb::arg("context").none() = nb::none(), "Create a tuple type"); c.def( "get_type", [](PyTupleType &self, intptr_t pos) { return mlirTupleTypeGetType(self, pos); }, - py::arg("pos"), "Returns the pos-th type in the tuple type."); - c.def_property_readonly( + nb::arg("pos"), "Returns the pos-th type in the tuple type."); + c.def_prop_ro( "num_types", [](PyTupleType &self) -> intptr_t { return mlirTupleTypeGetNumTypes(self); @@ -887,13 +900,14 @@ class PyFunctionType : public PyConcreteType { results.size(), results.data()); return PyFunctionType(context->getRef(), t); }, - py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(), + nb::arg("inputs"), nb::arg("results"), + nb::arg("context").none() = nb::none(), "Gets a FunctionType from a list of input and result types"); - c.def_property_readonly( + c.def_prop_ro( "inputs", [](PyFunctionType &self) { MlirType t = self; - py::list types; + nb::list types; for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; ++i) { types.append(mlirFunctionTypeGetInput(t, i)); @@ -901,10 +915,10 @@ class PyFunctionType : public PyConcreteType { return types; }, "Returns the list of input types in the FunctionType."); - c.def_property_readonly( + c.def_prop_ro( "results", [](PyFunctionType &self) { - py::list types; + nb::list types; for 
(intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; ++i) { types.append(mlirFunctionTypeGetResult(self, i)); @@ -938,21 +952,21 @@ class PyOpaqueType : public PyConcreteType { toMlirStringRef(typeData)); return PyOpaqueType(context->getRef(), type); }, - py::arg("dialect_namespace"), py::arg("buffer"), - py::arg("context") = py::none(), + nb::arg("dialect_namespace"), nb::arg("buffer"), + nb::arg("context").none() = nb::none(), "Create an unregistered (opaque) dialect type."); - c.def_property_readonly( + c.def_prop_ro( "dialect_namespace", [](PyOpaqueType &self) { MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self); - return py::str(stringRef.data, stringRef.length); + return nb::str(stringRef.data, stringRef.length); }, "Returns the dialect namespace for the Opaque type as a string."); - c.def_property_readonly( + c.def_prop_ro( "data", [](PyOpaqueType &self) { MlirStringRef stringRef = mlirOpaqueTypeGetData(self); - return py::str(stringRef.data, stringRef.length); + return nb::str(stringRef.data, stringRef.length); }, "Returns the data for the Opaque type as a string."); } @@ -960,7 +974,7 @@ class PyOpaqueType : public PyConcreteType { } // namespace -void mlir::python::populateIRTypes(py::module &m) { +void mlir::python::populateIRTypes(nb::module_ &m) { PyIntegerType::bind(m); PyFloatType::bind(m); PyIndexType::bind(m); diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 7c2702190..e5e64a921 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -6,29 +6,31 @@ // //===----------------------------------------------------------------------===// -#include "PybindUtils.h" +#include +#include #include "Globals.h" #include "IRModule.h" +#include "NanobindUtils.h" #include "Pass.h" #include "Rewrite.h" -namespace py = pybind11; +namespace nb = nanobind; using namespace mlir; -using namespace py::literals; +using namespace nb::literals; using 
namespace mlir::python; // ----------------------------------------------------------------------------- // Module initialization. // ----------------------------------------------------------------------------- -PYBIND11_MODULE(_mlir, m) { +NB_MODULE(_mlir, m) { m.doc() = "MLIR Python Native Extension"; - py::class_(m, "_Globals", py::module_local()) - .def_property("dialect_search_modules", - &PyGlobals::getDialectSearchPrefixes, - &PyGlobals::setDialectSearchPrefixes) + nb::class_(m, "_Globals") + .def_prop_rw("dialect_search_modules", + &PyGlobals::getDialectSearchPrefixes, + &PyGlobals::setDialectSearchPrefixes) .def( "append_dialect_search_prefix", [](PyGlobals &self, std::string moduleName) { @@ -45,22 +47,21 @@ PYBIND11_MODULE(_mlir, m) { "dialect_namespace"_a, "dialect_class"_a, "Testing hook for directly registering a dialect") .def("_register_operation_impl", &PyGlobals::registerOperationImpl, - "operation_name"_a, "operation_class"_a, py::kw_only(), + "operation_name"_a, "operation_class"_a, nb::kw_only(), "replace"_a = false, "Testing hook for directly registering an operation"); // Aside from making the globals accessible to python, having python manage // it is necessary to make sure it is destroyed (and releases its python // resources) properly. - m.attr("globals") = - py::cast(new PyGlobals, py::return_value_policy::take_ownership); + m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership); // Registration decorators. 
m.def( "register_dialect", - [](py::type pyClass) { + [](nb::type_object pyClass) { std::string dialectNamespace = - pyClass.attr("DIALECT_NAMESPACE").cast(); + nanobind::cast(pyClass.attr("DIALECT_NAMESPACE")); PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); return pyClass; }, @@ -68,45 +69,46 @@ PYBIND11_MODULE(_mlir, m) { "Class decorator for registering a custom Dialect wrapper"); m.def( "register_operation", - [](const py::type &dialectClass, bool replace) -> py::cpp_function { - return py::cpp_function( - [dialectClass, replace](py::type opClass) -> py::type { + [](const nb::type_object &dialectClass, bool replace) -> nb::object { + return nb::cpp_function( + [dialectClass, + replace](nb::type_object opClass) -> nb::type_object { std::string operationName = - opClass.attr("OPERATION_NAME").cast(); + nanobind::cast(opClass.attr("OPERATION_NAME")); PyGlobals::get().registerOperationImpl(operationName, opClass, replace); // Dict-stuff the new opClass by name onto the dialect class. 
- py::object opClassName = opClass.attr("__name__"); + nb::object opClassName = opClass.attr("__name__"); dialectClass.attr(opClassName) = opClass; return opClass; }); }, - "dialect_class"_a, py::kw_only(), "replace"_a = false, + "dialect_class"_a, nb::kw_only(), "replace"_a = false, "Produce a class decorator for registering an Operation class as part of " "a dialect"); m.def( MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR, - [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function { - return py::cpp_function([mlirTypeID, - replace](py::object typeCaster) -> py::object { + [](MlirTypeID mlirTypeID, bool replace) -> nb::object { + return nb::cpp_function([mlirTypeID, replace]( + nb::callable typeCaster) -> nb::object { PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace); return typeCaster; }); }, - "typeid"_a, py::kw_only(), "replace"_a = false, + "typeid"_a, nb::kw_only(), "replace"_a = false, "Register a type caster for casting MLIR types to custom user types."); m.def( MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR, - [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function { - return py::cpp_function( - [mlirTypeID, replace](py::object valueCaster) -> py::object { + [](MlirTypeID mlirTypeID, bool replace) -> nb::object { + return nb::cpp_function( + [mlirTypeID, replace](nb::callable valueCaster) -> nb::object { PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster, replace); return valueCaster; }); }, - "typeid"_a, py::kw_only(), "replace"_a = false, + "typeid"_a, nb::kw_only(), "replace"_a = false, "Register a value caster for casting MLIR values to custom user values."); // Define and populate IR submodule. 
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h similarity index 85% rename from mlir/lib/Bindings/Python/PybindUtils.h rename to mlir/lib/Bindings/Python/NanobindUtils.h index 38462ac8b..3b0f7f698 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/NanobindUtils.h @@ -1,4 +1,5 @@ -//===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===// +//===- NanobindUtils.h - Utilities for interop with nanobind ------*- C++ +//-*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -9,13 +10,21 @@ #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H +#include + #include "mlir-c/Support.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/DataTypes.h" -#include -#include +template <> +struct std::iterator_traits { + using value_type = nanobind::handle; + using reference = const value_type; + using pointer = void; + using difference_type = std::ptrdiff_t; + using iterator_category = std::forward_iterator_tag; +}; namespace mlir { namespace python { @@ -54,14 +63,14 @@ class Defaulting { } // namespace python } // namespace mlir -namespace pybind11 { +namespace nanobind { namespace detail { template struct MlirDefaultingCaster { - PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription)); + NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription)); - bool load(pybind11::handle src, bool) { + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { if (src.is_none()) { // Note that we do want an exception to propagate from here as it will be // the most informative. @@ -76,20 +85,20 @@ struct MlirDefaultingCaster { // code to produce nice error messages (other than "Cannot cast..."). 
try { value = DefaultingTy{ - pybind11::cast(src)}; + nanobind::cast(src)}; return true; } catch (std::exception &) { return false; } } - static handle cast(DefaultingTy src, return_value_policy policy, - handle parent) { - return pybind11::cast(src, policy); + static handle from_cpp(DefaultingTy src, rv_policy policy, + cleanup_list *cleanup) noexcept { + return nanobind::cast(src, policy); } }; } // namespace detail -} // namespace pybind11 +} // namespace nanobind //------------------------------------------------------------------------------ // Conversion utilities. @@ -100,7 +109,7 @@ namespace mlir { /// Accumulates into a python string from a method that accepts an /// MlirStringCallback. struct PyPrintAccumulator { - pybind11::list parts; + nanobind::list parts; void *getUserData() { return this; } @@ -108,15 +117,15 @@ struct PyPrintAccumulator { return [](MlirStringRef part, void *userData) { PyPrintAccumulator *printAccum = static_cast(userData); - pybind11::str pyPart(part.data, + nanobind::str pyPart(part.data, part.length); // Decodes as UTF-8 by default. printAccum->parts.append(std::move(pyPart)); }; } - pybind11::str join() { - pybind11::str delim("", 0); - return delim.attr("join")(parts); + nanobind::str join() { + nanobind::str delim("", 0); + return nanobind::cast(delim.attr("join")(parts)); } }; @@ -124,21 +133,21 @@ struct PyPrintAccumulator { /// or binary. class PyFileAccumulator { public: - PyFileAccumulator(const pybind11::object &fileObject, bool binary) + PyFileAccumulator(const nanobind::object &fileObject, bool binary) : pyWriteFunction(fileObject.attr("write")), binary(binary) {} void *getUserData() { return this; } MlirStringCallback getCallback() { return [](MlirStringRef part, void *userData) { - pybind11::gil_scoped_acquire acquire; + nanobind::gil_scoped_acquire acquire; PyFileAccumulator *accum = static_cast(userData); if (accum->binary) { // Note: Still has to copy and not avoidable with this API. 
- pybind11::bytes pyBytes(part.data, part.length); + nanobind::bytes pyBytes(part.data, part.length); accum->pyWriteFunction(pyBytes); } else { - pybind11::str pyStr(part.data, + nanobind::str pyStr(part.data, part.length); // Decodes as UTF-8 by default. accum->pyWriteFunction(pyStr); } @@ -146,7 +155,7 @@ class PyFileAccumulator { } private: - pybind11::object pyWriteFunction; + nanobind::object pyWriteFunction; bool binary; }; @@ -163,17 +172,17 @@ struct PySinglePartStringAccumulator { assert(!accum->invoked && "PySinglePartStringAccumulator called back multiple times"); accum->invoked = true; - accum->value = pybind11::str(part.data, part.length); + accum->value = nanobind::str(part.data, part.length); }; } - pybind11::str takeValue() { + nanobind::str takeValue() { assert(invoked && "PySinglePartStringAccumulator not called back"); return std::move(value); } private: - pybind11::str value; + nanobind::str value; bool invoked = false; }; @@ -208,7 +217,7 @@ struct PySinglePartStringAccumulator { template class Sliceable { protected: - using ClassTy = pybind11::class_; + using ClassTy = nanobind::class_; /// Transforms `index` into a legal value to access the underlying sequence. /// Returns <0 on failure. @@ -237,7 +246,7 @@ class Sliceable { /// Returns the element at the given slice index. Supports negative indices /// by taking elements in inverse order. Returns a nullptr object if out /// of bounds. - pybind11::object getItem(intptr_t index) { + nanobind::object getItem(intptr_t index) { // Negative indices mean we count from the end. index = wrapIndex(index); if (index < 0) { @@ -250,20 +259,20 @@ class Sliceable { ->getRawElement(linearizeIndex(index)) .maybeDownCast(); else - return pybind11::cast( + return nanobind::cast( static_cast(this)->getRawElement(linearizeIndex(index))); } /// Returns a new instance of the pseudo-container restricted to the given /// slice. Returns a nullptr object on failure. 
- pybind11::object getItemSlice(PyObject *slice) { + nanobind::object getItemSlice(PyObject *slice) { ssize_t start, stop, extraStep, sliceLength; if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep, &sliceLength) != 0) { PyErr_SetString(PyExc_IndexError, "index out of range"); return {}; } - return pybind11::cast(static_cast(this)->slice( + return nanobind::cast(static_cast(this)->slice( startIndex + start * step, sliceLength, step * extraStep)); } @@ -279,7 +288,7 @@ class Sliceable { // Negative indices mean we count from the end. index = wrapIndex(index); if (index < 0) { - throw pybind11::index_error("index out of range"); + throw nanobind::index_error("index out of range"); } return static_cast(this)->getRawElement(linearizeIndex(index)); @@ -304,39 +313,38 @@ class Sliceable { } /// Binds the indexing and length methods in the Python class. - static void bind(pybind11::module &m) { - auto clazz = pybind11::class_(m, Derived::pyClassName, - pybind11::module_local()) + static void bind(nanobind::module_ &m) { + auto clazz = nanobind::class_(m, Derived::pyClassName) .def("__add__", &Sliceable::dunderAdd); Derived::bindDerived(clazz); // Manually implement the sequence protocol via the C API. We do this - // because it is approx 4x faster than via pybind11, largely because that + // because it is approx 4x faster than via nanobind, largely because that // formulation requires a C++ exception to be thrown to detect end of // sequence. // Since we are in a C-context, any C++ exception that happens here // will terminate the program. There is nothing in this implementation // that should throw in a non-terminal way, so we forgo further // exception marshalling. 
- // See: https://github.com/pybind/pybind11/issues/2842 + // See: https://github.com/pybind/nanobind/issues/2842 auto heap_type = reinterpret_cast(clazz.ptr()); assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE && "must be heap type"); heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t { - auto self = pybind11::cast(rawSelf); + auto self = nanobind::cast(nanobind::handle(rawSelf)); return self->length; }; // sq_item is called as part of the sequence protocol for iteration, // list construction, etc. heap_type->as_sequence.sq_item = +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * { - auto self = pybind11::cast(rawSelf); + auto self = nanobind::cast(nanobind::handle(rawSelf)); return self->getItem(index).release().ptr(); }; // mp_subscript is used for both slices and integer lookups. heap_type->as_mapping.mp_subscript = +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * { - auto self = pybind11::cast(rawSelf); + auto self = nanobind::cast(nanobind::handle(rawSelf)); Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError); if (!PyErr_Occurred()) { // Integer indexing. 
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index e991deaae..b5dce4fe4 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -8,12 +8,16 @@ #include "Pass.h" +#include +#include +#include + #include "IRModule.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Pass.h" -namespace py = pybind11; -using namespace py::literals; +namespace nb = nanobind; +using namespace nb::literals; using namespace mlir; using namespace mlir::python; @@ -34,16 +38,15 @@ class PyPassManager { MlirPassManager get() { return passManager; } void release() { passManager.ptr = nullptr; } - pybind11::object getCapsule() { - return py::reinterpret_steal( - mlirPythonPassManagerToCapsule(get())); + nb::object getCapsule() { + return nb::steal(mlirPythonPassManagerToCapsule(get())); } - static pybind11::object createFromCapsule(pybind11::object capsule) { + static nb::object createFromCapsule(nb::object capsule) { MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr()); if (mlirPassManagerIsNull(rawPm)) - throw py::error_already_set(); - return py::cast(PyPassManager(rawPm), py::return_value_policy::move); + throw nb::python_error(); + return nb::cast(PyPassManager(rawPm), nb::rv_policy::move); } private: @@ -53,22 +56,23 @@ class PyPassManager { } // namespace /// Create the `mlir.passmanager` here. 
-void mlir::python::populatePassManagerSubmodule(py::module &m) { +void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of the top-level PassManager //---------------------------------------------------------------------------- - py::class_(m, "PassManager", py::module_local()) - .def(py::init<>([](const std::string &anchorOp, - DefaultingPyMlirContext context) { - MlirPassManager passManager = mlirPassManagerCreateOnOperation( - context->get(), - mlirStringRefCreate(anchorOp.data(), anchorOp.size())); - return new PyPassManager(passManager); - }), - "anchor_op"_a = py::str("any"), "context"_a = py::none(), - "Create a new PassManager for the current (or provided) Context.") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyPassManager::getCapsule) + nb::class_(m, "PassManager") + .def( + "__init__", + [](PyPassManager &self, const std::string &anchorOp, + DefaultingPyMlirContext context) { + MlirPassManager passManager = mlirPassManagerCreateOnOperation( + context->get(), + mlirStringRefCreate(anchorOp.data(), anchorOp.size())); + new (&self) PyPassManager(passManager); + }, + "anchor_op"_a = nb::str("any"), "context"_a.none() = nb::none(), + "Create a new PassManager for the current (or provided) Context.") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule) .def("_testing_release", &PyPassManager::release, "Releases (leaks) the backing pass manager (testing)") @@ -101,9 +105,9 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { "print_before_all"_a = false, "print_after_all"_a = true, "print_module_scope"_a = false, "print_after_change"_a = false, "print_after_failure"_a = false, - "large_elements_limit"_a = py::none(), "enable_debug_info"_a = false, - "print_generic_op_form"_a = false, - "tree_printing_dir_path"_a = py::none(), + 
"large_elements_limit"_a.none() = nb::none(), + "enable_debug_info"_a = false, "print_generic_op_form"_a = false, + "tree_printing_dir_path"_a.none() = nb::none(), "Enable IR printing, default as mlir-print-ir-after-all.") .def( "enable_verifier", @@ -121,10 +125,10 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { mlirStringRefCreate(pipeline.data(), pipeline.size()), errorMsg.getCallback(), errorMsg.getUserData()); if (mlirLogicalResultIsFailure(status)) - throw py::value_error(std::string(errorMsg.join())); + throw nb::value_error(errorMsg.join().c_str()); return new PyPassManager(passManager); }, - "pipeline"_a, "context"_a = py::none(), + "pipeline"_a, "context"_a.none() = nb::none(), "Parse a textual pass-pipeline and return a top-level PassManager " "that can be applied on a Module. Throw a ValueError if the pipeline " "can't be parsed") @@ -137,7 +141,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { mlirStringRefCreate(pipeline.data(), pipeline.size()), errorMsg.getCallback(), errorMsg.getUserData()); if (mlirLogicalResultIsFailure(status)) - throw py::value_error(std::string(errorMsg.join())); + throw nb::value_error(errorMsg.join().c_str()); }, "pipeline"_a, "Add textual pipeline elements to the pass manager. 
Throws a " diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h index 3a500d5e8..bc4094352 100644 --- a/mlir/lib/Bindings/Python/Pass.h +++ b/mlir/lib/Bindings/Python/Pass.h @@ -9,12 +9,12 @@ #ifndef MLIR_BINDINGS_PYTHON_PASS_H #define MLIR_BINDINGS_PYTHON_PASS_H -#include "PybindUtils.h" +#include "NanobindUtils.h" namespace mlir { namespace python { -void populatePassManagerSubmodule(pybind11::module &m); +void populatePassManagerSubmodule(nanobind::module_ &m); } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 1d8128be9..b2c1de4be 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -8,14 +8,16 @@ #include "Rewrite.h" +#include + #include "IRModule.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Rewrite.h" #include "mlir/Config/mlir-config.h" -namespace py = pybind11; +namespace nb = nanobind; using namespace mlir; -using namespace py::literals; +using namespace nb::literals; using namespace mlir::python; namespace { @@ -54,18 +56,17 @@ class PyFrozenRewritePatternSet { } MlirFrozenRewritePatternSet get() { return set; } - pybind11::object getCapsule() { - return py::reinterpret_steal( + nb::object getCapsule() { + return nb::steal( mlirPythonFrozenRewritePatternSetToCapsule(get())); } - static pybind11::object createFromCapsule(pybind11::object capsule) { + static nb::object createFromCapsule(nb::object capsule) { MlirFrozenRewritePatternSet rawPm = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); if (rawPm.ptr == nullptr) - throw py::error_already_set(); - return py::cast(PyFrozenRewritePatternSet(rawPm), - py::return_value_policy::move); + throw nb::python_error(); + return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move); } private: @@ -75,25 +76,27 @@ class PyFrozenRewritePatternSet { } // namespace /// Create the `mlir.rewrite` here. 
-void mlir::python::populateRewriteSubmodule(py::module &m) { +void mlir::python::populateRewriteSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of the top-level PassManager //---------------------------------------------------------------------------- #if MLIR_ENABLE_PDL_IN_PATTERNMATCH - py::class_(m, "PDLModule", py::module_local()) - .def(py::init<>([](MlirModule module) { - return mlirPDLPatternModuleFromModule(module); - }), - "module"_a, "Create a PDL module from the given module.") + nb::class_(m, "PDLModule") + .def( + "__init__", + [](PyPDLPatternModule &self, MlirModule module) { + new (&self) + PyPDLPatternModule(mlirPDLPatternModuleFromModule(module)); + }, + "module"_a, "Create a PDL module from the given module.") .def("freeze", [](PyPDLPatternModule &self) { return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( mlirRewritePatternSetFromPDLPatternModule(self.get()))); }); -#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCg - py::class_(m, "FrozenRewritePatternSet", - py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyFrozenRewritePatternSet::getCapsule) +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH + nb::class_(m, "FrozenRewritePatternSet") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyFrozenRewritePatternSet::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyFrozenRewritePatternSet::createFromCapsule); m.def( @@ -102,7 +105,7 @@ void mlir::python::populateRewriteSubmodule(py::module &m) { auto status = mlirApplyPatternsAndFoldGreedily(module, set, {}); if (mlirLogicalResultIsFailure(status)) // FIXME: Not sure this is the right error to throw here. 
- throw py::value_error("pattern application failed to converge"); + throw nb::value_error("pattern application failed to converge"); }, "module"_a, "set"_a, "Applys the given patterns to the given module greedily while folding " diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h index 997b80add..ae89e2b95 100644 --- a/mlir/lib/Bindings/Python/Rewrite.h +++ b/mlir/lib/Bindings/Python/Rewrite.h @@ -9,12 +9,12 @@ #ifndef MLIR_BINDINGS_PYTHON_REWRITE_H #define MLIR_BINDINGS_PYTHON_REWRITE_H -#include "PybindUtils.h" +#include "NanobindUtils.h" namespace mlir { namespace python { -void populateRewriteSubmodule(pybind11::module &m); +void populateRewriteSubmodule(nanobind::module_ &m); } // namespace python } // namespace mlir diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 10866c11b..b865c9032 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -448,6 +448,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core MODULE_NAME _mlir ADD_TO_PARENT MLIRPythonSources.Core ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES MainModule.cpp IRAffine.cpp @@ -463,7 +464,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core Globals.h IRModule.h Pass.h - PybindUtils.h + NanobindUtils.h Rewrite.h PRIVATE_LINK_LIBS LLVMSupport diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index ab8a91229..f240d6ef9 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,4 +1,4 @@ -nanobind>=2.0, <3.0 +nanobind>=2.4, <3.0 numpy>=1.19.5, <=2.1.2 pybind11>=2.10.0, <=2.13.6 PyYAML>=5.4.0, <=6.0.1 From 6849bb0d17d1bfd1bcda0ab01844bf8710f1ebdb Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 18 Dec 2024 19:31:32 +0000 Subject: [PATCH 809/915] Revert "[mlir python] Port Python core code to nanobind. (#118583)" This reverts commit 4c9b06adfac283c4027612ac5e010ef5d09a158b. Breakage detected, rolling back. 
--- mlir/include/mlir/Bindings/Python/IRTypes.h | 2 +- .../mlir/Bindings/Python/PybindAdaptors.h | 10 +- mlir/lib/Bindings/Python/Globals.h | 39 +- mlir/lib/Bindings/Python/IRAffine.cpp | 265 ++-- mlir/lib/Bindings/Python/IRAttributes.cpp | 672 +++----- mlir/lib/Bindings/Python/IRCore.cpp | 1412 ++++++++--------- mlir/lib/Bindings/Python/IRInterfaces.cpp | 171 +- mlir/lib/Bindings/Python/IRModule.cpp | 57 +- mlir/lib/Bindings/Python/IRModule.h | 332 ++-- mlir/lib/Bindings/Python/IRTypes.cpp | 200 ++- mlir/lib/Bindings/Python/MainModule.cpp | 56 +- mlir/lib/Bindings/Python/Pass.cpp | 58 +- mlir/lib/Bindings/Python/Pass.h | 4 +- .../Python/{NanobindUtils.h => PybindUtils.h} | 84 +- mlir/lib/Bindings/Python/Rewrite.cpp | 43 +- mlir/lib/Bindings/Python/Rewrite.h | 4 +- mlir/python/CMakeLists.txt | 3 +- mlir/python/requirements.txt | 2 +- 18 files changed, 1561 insertions(+), 1853 deletions(-) rename mlir/lib/Bindings/Python/{NanobindUtils.h => PybindUtils.h} (85%) diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h index ba9642cf2..9afad4c23 100644 --- a/mlir/include/mlir/Bindings/Python/IRTypes.h +++ b/mlir/include/mlir/Bindings/Python/IRTypes.h @@ -9,7 +9,7 @@ #ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H #define MLIR_BINDINGS_PYTHON_IRTYPES_H -#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" namespace mlir { diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index edc69774b..c8233355d 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -374,8 +374,9 @@ class pure_subclass { static_assert(!std::is_member_function_pointer::value, "def_staticmethod(...) 
called with a non-static member " "function pointer"); - py::cpp_function cf(std::forward(f), py::name(name), - py::scope(thisClass), extra...); + py::cpp_function cf( + std::forward(f), py::name(name), py::scope(thisClass), + py::sibling(py::getattr(thisClass, name, py::none())), extra...); thisClass.attr(cf.name()) = py::staticmethod(cf); return *this; } @@ -386,8 +387,9 @@ class pure_subclass { static_assert(!std::is_member_function_pointer::value, "def_classmethod(...) called with a non-static member " "function pointer"); - py::cpp_function cf(std::forward(f), py::name(name), - py::scope(thisClass), extra...); + py::cpp_function cf( + std::forward(f), py::name(name), py::scope(thisClass), + py::sibling(py::getattr(thisClass, name, py::none())), extra...); thisClass.attr(cf.name()) = py::reinterpret_borrow(PyClassMethod_New(cf.ptr())); return *this; diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 0ec522d14..a022067f5 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -9,17 +9,18 @@ #ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H #define MLIR_BINDINGS_PYTHON_GLOBALS_H -#include -#include -#include +#include "PybindUtils.h" -#include "NanobindUtils.h" #include "mlir-c/IR.h" #include "mlir/CAPI/Support.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" +#include +#include +#include + namespace mlir { namespace python { @@ -56,55 +57,55 @@ class PyGlobals { /// Raises an exception if the mapping already exists and replace == false. /// This is intended to be called by implementation code. void registerAttributeBuilder(const std::string &attributeKind, - nanobind::callable pyFunc, + pybind11::function pyFunc, bool replace = false); /// Adds a user-friendly type caster. Raises an exception if the mapping /// already exists and replace == false. This is intended to be called by /// implementation code. 
- void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster, + void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster, bool replace = false); /// Adds a user-friendly value caster. Raises an exception if the mapping /// already exists and replace == false. This is intended to be called by /// implementation code. void registerValueCaster(MlirTypeID mlirTypeID, - nanobind::callable valueCaster, + pybind11::function valueCaster, bool replace = false); /// Adds a concrete implementation dialect class. /// Raises an exception if the mapping already exists. /// This is intended to be called by implementation code. void registerDialectImpl(const std::string &dialectNamespace, - nanobind::object pyClass); + pybind11::object pyClass); /// Adds a concrete implementation operation class. /// Raises an exception if the mapping already exists and replace == false. /// This is intended to be called by implementation code. void registerOperationImpl(const std::string &operationName, - nanobind::object pyClass, bool replace = false); + pybind11::object pyClass, bool replace = false); /// Returns the custom Attribute builder for Attribute kind. - std::optional + std::optional lookupAttributeBuilder(const std::string &attributeKind); /// Returns the custom type caster for MlirTypeID mlirTypeID. - std::optional lookupTypeCaster(MlirTypeID mlirTypeID, + std::optional lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect); /// Returns the custom value caster for MlirTypeID mlirTypeID. - std::optional lookupValueCaster(MlirTypeID mlirTypeID, + std::optional lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect); /// Looks up a registered dialect class by namespace. Note that this may /// trigger loading of the defining module and can arbitrarily re-enter. - std::optional + std::optional lookupDialectClass(const std::string &dialectNamespace); /// Looks up a registered operation class (deriving from OpView) by operation /// name. 
Note that this may trigger a load of the dialect, which can /// arbitrarily re-enter. - std::optional + std::optional lookupOperationClass(llvm::StringRef operationName); private: @@ -112,15 +113,15 @@ class PyGlobals { /// Module name prefixes to search under for dialect implementation modules. std::vector dialectSearchPrefixes; /// Map of dialect namespace to external dialect class object. - llvm::StringMap dialectClassMap; + llvm::StringMap dialectClassMap; /// Map of full operation name to external operation class object. - llvm::StringMap operationClassMap; + llvm::StringMap operationClassMap; /// Map of attribute ODS name to custom builder. - llvm::StringMap attributeBuilderMap; + llvm::StringMap attributeBuilderMap; /// Map of MlirTypeID to custom type caster. - llvm::DenseMap typeCasterMap; + llvm::DenseMap typeCasterMap; /// Map of MlirTypeID to custom value caster. - llvm::DenseMap valueCasterMap; + llvm::DenseMap valueCasterMap; /// Set of dialect namespaces that we have attempted to import implementation /// modules for. 
llvm::StringSet<> loadedDialectModules; diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index 2db690309..b138e131e 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -6,19 +6,20 @@ // //===----------------------------------------------------------------------===// -#include -#include -#include - #include #include -#include +#include +#include +#include +#include #include #include #include #include "IRModule.h" -#include "NanobindUtils.h" + +#include "PybindUtils.h" + #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" #include "mlir-c/Bindings/Python/Interop.h" @@ -29,7 +30,7 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" -namespace nb = nanobind; +namespace py = pybind11; using namespace mlir; using namespace mlir::python; @@ -45,23 +46,23 @@ static const char kDumpDocstring[] = /// Throws errors in case of failure, using "action" to describe what the caller /// was attempting to do. template -static void pyListToVector(const nb::list &list, +static void pyListToVector(const py::list &list, llvm::SmallVectorImpl &result, StringRef action) { - result.reserve(nb::len(list)); - for (nb::handle item : list) { + result.reserve(py::len(list)); + for (py::handle item : list) { try { - result.push_back(nb::cast(item)); - } catch (nb::cast_error &err) { + result.push_back(item.cast()); + } catch (py::cast_error &err) { std::string msg = (llvm::Twine("Invalid expression when ") + action + " (" + err.what() + ")") .str(); - throw std::runtime_error(msg.c_str()); - } catch (std::runtime_error &err) { + throw py::cast_error(msg); + } catch (py::reference_cast_error &err) { std::string msg = (llvm::Twine("Invalid expression (None?) 
when ") + action + " (" + err.what() + ")") .str(); - throw std::runtime_error(msg.c_str()); + throw py::cast_error(msg); } } } @@ -93,7 +94,7 @@ class PyConcreteAffineExpr : public BaseTy { // IsAFunctionTy isaFunction // const char *pyClassName // and redefine bindDerived. - using ClassTy = nb::class_; + using ClassTy = py::class_; using IsAFunctionTy = bool (*)(MlirAffineExpr); PyConcreteAffineExpr() = default; @@ -104,25 +105,24 @@ class PyConcreteAffineExpr : public BaseTy { static MlirAffineExpr castFrom(PyAffineExpr &orig) { if (!DerivedTy::isaFunction(orig)) { - auto origRepr = nb::cast(nb::repr(nb::cast(orig))); - throw nb::value_error((Twine("Cannot cast affine expression to ") + + auto origRepr = py::repr(py::cast(orig)).cast(); + throw py::value_error((Twine("Cannot cast affine expression to ") + DerivedTy::pyClassName + " (from " + origRepr + ")") - .str() - .c_str()); + .str()); } return orig; } - static void bind(nb::module_ &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName); - cls.def(nb::init(), nb::arg("expr")); + static void bind(py::module &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); + cls.def(py::init(), py::arg("expr")); cls.def_static( "isinstance", [](PyAffineExpr &otherAffineExpr) -> bool { return DerivedTy::isaFunction(otherAffineExpr); }, - nb::arg("other")); + py::arg("other")); DerivedTy::bindDerived(cls); } @@ -144,9 +144,9 @@ class PyAffineConstantExpr : public PyConcreteAffineExpr { } static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineConstantExpr::get, nb::arg("value"), - nb::arg("context").none() = nb::none()); - c.def_prop_ro("value", [](PyAffineConstantExpr &self) { + c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"), + py::arg("context") = py::none()); + c.def_property_readonly("value", [](PyAffineConstantExpr &self) { return mlirAffineConstantExprGetValue(self); }); } @@ -164,9 +164,9 @@ class PyAffineDimExpr : public PyConcreteAffineExpr { } static void 
bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineDimExpr::get, nb::arg("position"), - nb::arg("context").none() = nb::none()); - c.def_prop_ro("position", [](PyAffineDimExpr &self) { + c.def_static("get", &PyAffineDimExpr::get, py::arg("position"), + py::arg("context") = py::none()); + c.def_property_readonly("position", [](PyAffineDimExpr &self) { return mlirAffineDimExprGetPosition(self); }); } @@ -184,9 +184,9 @@ class PyAffineSymbolExpr : public PyConcreteAffineExpr { } static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineSymbolExpr::get, nb::arg("position"), - nb::arg("context").none() = nb::none()); - c.def_prop_ro("position", [](PyAffineSymbolExpr &self) { + c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"), + py::arg("context") = py::none()); + c.def_property_readonly("position", [](PyAffineSymbolExpr &self) { return mlirAffineSymbolExprGetPosition(self); }); } @@ -209,8 +209,8 @@ class PyAffineBinaryExpr : public PyConcreteAffineExpr { } static void bindDerived(ClassTy &c) { - c.def_prop_ro("lhs", &PyAffineBinaryExpr::lhs); - c.def_prop_ro("rhs", &PyAffineBinaryExpr::rhs); + c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs); + c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs); } }; @@ -365,14 +365,15 @@ bool PyAffineExpr::operator==(const PyAffineExpr &other) const { return mlirAffineExprEqual(affineExpr, other.affineExpr); } -nb::object PyAffineExpr::getCapsule() { - return nb::steal(mlirPythonAffineExprToCapsule(*this)); +py::object PyAffineExpr::getCapsule() { + return py::reinterpret_steal( + mlirPythonAffineExprToCapsule(*this)); } -PyAffineExpr PyAffineExpr::createFromCapsule(nb::object capsule) { +PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) { MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); if (mlirAffineExprIsNull(rawAffineExpr)) - throw nb::python_error(); + throw py::error_already_set(); return PyAffineExpr( 
PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)), rawAffineExpr); @@ -423,14 +424,14 @@ bool PyAffineMap::operator==(const PyAffineMap &other) const { return mlirAffineMapEqual(affineMap, other.affineMap); } -nb::object PyAffineMap::getCapsule() { - return nb::steal(mlirPythonAffineMapToCapsule(*this)); +py::object PyAffineMap::getCapsule() { + return py::reinterpret_steal(mlirPythonAffineMapToCapsule(*this)); } -PyAffineMap PyAffineMap::createFromCapsule(nb::object capsule) { +PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) { MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); if (mlirAffineMapIsNull(rawAffineMap)) - throw nb::python_error(); + throw py::error_already_set(); return PyAffineMap( PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)), rawAffineMap); @@ -453,10 +454,11 @@ class PyIntegerSetConstraint { bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } - static void bind(nb::module_ &m) { - nb::class_(m, "IntegerSetConstraint") - .def_prop_ro("expr", &PyIntegerSetConstraint::getExpr) - .def_prop_ro("is_eq", &PyIntegerSetConstraint::isEq); + static void bind(py::module &m) { + py::class_(m, "IntegerSetConstraint", + py::module_local()) + .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr) + .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq); } private: @@ -499,25 +501,27 @@ bool PyIntegerSet::operator==(const PyIntegerSet &other) const { return mlirIntegerSetEqual(integerSet, other.integerSet); } -nb::object PyIntegerSet::getCapsule() { - return nb::steal(mlirPythonIntegerSetToCapsule(*this)); +py::object PyIntegerSet::getCapsule() { + return py::reinterpret_steal( + mlirPythonIntegerSetToCapsule(*this)); } -PyIntegerSet PyIntegerSet::createFromCapsule(nb::object capsule) { +PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) { MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); if 
(mlirIntegerSetIsNull(rawIntegerSet)) - throw nb::python_error(); + throw py::error_already_set(); return PyIntegerSet( PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), rawIntegerSet); } -void mlir::python::populateIRAffine(nb::module_ &m) { +void mlir::python::populateIRAffine(py::module &m) { //---------------------------------------------------------------------------- // Mapping of PyAffineExpr and derived classes. //---------------------------------------------------------------------------- - nb::class_(m, "AffineExpr") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineExpr::getCapsule) + py::class_(m, "AffineExpr", py::module_local()) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyAffineExpr::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) .def("__add__", &PyAffineAddExpr::get) .def("__add__", &PyAffineAddExpr::getRHSConstant) @@ -554,7 +558,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) { .def("__eq__", [](PyAffineExpr &self, PyAffineExpr &other) { return self == other; }) .def("__eq__", - [](PyAffineExpr &self, nb::object &other) { return false; }) + [](PyAffineExpr &self, py::object &other) { return false; }) .def("__str__", [](PyAffineExpr &self) { PyPrintAccumulator printAccum; @@ -575,7 +579,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) { [](PyAffineExpr &self) { return static_cast(llvm::hash_value(self.get().ptr)); }) - .def_prop_ro( + .def_property_readonly( "context", [](PyAffineExpr &self) { return self.getContext().getObject(); }) .def("compose", @@ -628,16 +632,16 @@ void mlir::python::populateIRAffine(nb::module_ &m) { .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant, "Gets an affine expression containing the rounded-up result " "of dividing an expression by a constant.") - .def_static("get_constant", &PyAffineConstantExpr::get, nb::arg("value"), - nb::arg("context").none() = nb::none(), + .def_static("get_constant", &PyAffineConstantExpr::get, 
py::arg("value"), + py::arg("context") = py::none(), "Gets a constant affine expression with the given value.") .def_static( - "get_dim", &PyAffineDimExpr::get, nb::arg("position"), - nb::arg("context").none() = nb::none(), + "get_dim", &PyAffineDimExpr::get, py::arg("position"), + py::arg("context") = py::none(), "Gets an affine expression of a dimension at the given position.") .def_static( - "get_symbol", &PyAffineSymbolExpr::get, nb::arg("position"), - nb::arg("context").none() = nb::none(), + "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"), + py::arg("context") = py::none(), "Gets an affine expression of a symbol at the given position.") .def( "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); }, @@ -655,12 +659,13 @@ void mlir::python::populateIRAffine(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of PyAffineMap. //---------------------------------------------------------------------------- - nb::class_(m, "AffineMap") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineMap::getCapsule) + py::class_(m, "AffineMap", py::module_local()) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyAffineMap::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) .def("__eq__", [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) - .def("__eq__", [](PyAffineMap &self, nb::object &other) { return false; }) + .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; }) .def("__str__", [](PyAffineMap &self) { PyPrintAccumulator printAccum; @@ -682,7 +687,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) { return static_cast(llvm::hash_value(self.get().ptr)); }) .def_static("compress_unused_symbols", - [](nb::list affineMaps, DefaultingPyMlirContext context) { + [](py::list affineMaps, DefaultingPyMlirContext context) { SmallVector maps; pyListToVector( affineMaps, maps, "attempting to create an AffineMap"); @@ -699,7 +704,7 @@ 
void mlir::python::populateIRAffine(nb::module_ &m) { res.emplace_back(context->getRef(), m); return res; }) - .def_prop_ro( + .def_property_readonly( "context", [](PyAffineMap &self) { return self.getContext().getObject(); }, "Context that owns the Affine Map") @@ -708,7 +713,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) { kDumpDocstring) .def_static( "get", - [](intptr_t dimCount, intptr_t symbolCount, nb::list exprs, + [](intptr_t dimCount, intptr_t symbolCount, py::list exprs, DefaultingPyMlirContext context) { SmallVector affineExprs; pyListToVector( @@ -718,8 +723,8 @@ void mlir::python::populateIRAffine(nb::module_ &m) { affineExprs.size(), affineExprs.data()); return PyAffineMap(context->getRef(), map); }, - nb::arg("dim_count"), nb::arg("symbol_count"), nb::arg("exprs"), - nb::arg("context").none() = nb::none(), + py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"), + py::arg("context") = py::none(), "Gets a map with the given expressions as results.") .def_static( "get_constant", @@ -728,7 +733,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) { mlirAffineMapConstantGet(context->get(), value); return PyAffineMap(context->getRef(), affineMap); }, - nb::arg("value"), nb::arg("context").none() = nb::none(), + py::arg("value"), py::arg("context") = py::none(), "Gets an affine map with a single constant result") .def_static( "get_empty", @@ -736,7 +741,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) { MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); return PyAffineMap(context->getRef(), affineMap); }, - nb::arg("context").none() = nb::none(), "Gets an empty affine map.") + py::arg("context") = py::none(), "Gets an empty affine map.") .def_static( "get_identity", [](intptr_t nDims, DefaultingPyMlirContext context) { @@ -744,7 +749,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) { mlirAffineMapMultiDimIdentityGet(context->get(), nDims); return PyAffineMap(context->getRef(), affineMap); }, - 
nb::arg("n_dims"), nb::arg("context").none() = nb::none(), + py::arg("n_dims"), py::arg("context") = py::none(), "Gets an identity map with the given number of dimensions.") .def_static( "get_minor_identity", @@ -754,8 +759,8 @@ void mlir::python::populateIRAffine(nb::module_ &m) { mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults); return PyAffineMap(context->getRef(), affineMap); }, - nb::arg("n_dims"), nb::arg("n_results"), - nb::arg("context").none() = nb::none(), + py::arg("n_dims"), py::arg("n_results"), + py::arg("context") = py::none(), "Gets a minor identity map with the given number of dimensions and " "results.") .def_static( @@ -763,13 +768,13 @@ void mlir::python::populateIRAffine(nb::module_ &m) { [](std::vector permutation, DefaultingPyMlirContext context) { if (!isPermutation(permutation)) - throw std::runtime_error("Invalid permutation when attempting to " - "create an AffineMap"); + throw py::cast_error("Invalid permutation when attempting to " + "create an AffineMap"); MlirAffineMap affineMap = mlirAffineMapPermutationGet( context->get(), permutation.size(), permutation.data()); return PyAffineMap(context->getRef(), affineMap); }, - nb::arg("permutation"), nb::arg("context").none() = nb::none(), + py::arg("permutation"), py::arg("context") = py::none(), "Gets an affine map that permutes its inputs.") .def( "get_submap", @@ -777,33 +782,33 @@ void mlir::python::populateIRAffine(nb::module_ &m) { intptr_t numResults = mlirAffineMapGetNumResults(self); for (intptr_t pos : resultPos) { if (pos < 0 || pos >= numResults) - throw nb::value_error("result position out of bounds"); + throw py::value_error("result position out of bounds"); } MlirAffineMap affineMap = mlirAffineMapGetSubMap( self, resultPos.size(), resultPos.data()); return PyAffineMap(self.getContext(), affineMap); }, - nb::arg("result_positions")) + py::arg("result_positions")) .def( "get_major_submap", [](PyAffineMap &self, intptr_t nResults) { if (nResults >= 
mlirAffineMapGetNumResults(self)) - throw nb::value_error("number of results out of bounds"); + throw py::value_error("number of results out of bounds"); MlirAffineMap affineMap = mlirAffineMapGetMajorSubMap(self, nResults); return PyAffineMap(self.getContext(), affineMap); }, - nb::arg("n_results")) + py::arg("n_results")) .def( "get_minor_submap", [](PyAffineMap &self, intptr_t nResults) { if (nResults >= mlirAffineMapGetNumResults(self)) - throw nb::value_error("number of results out of bounds"); + throw py::value_error("number of results out of bounds"); MlirAffineMap affineMap = mlirAffineMapGetMinorSubMap(self, nResults); return PyAffineMap(self.getContext(), affineMap); }, - nb::arg("n_results")) + py::arg("n_results")) .def( "replace", [](PyAffineMap &self, PyAffineExpr &expression, @@ -813,37 +818,39 @@ void mlir::python::populateIRAffine(nb::module_ &m) { self, expression, replacement, numResultDims, numResultSyms); return PyAffineMap(self.getContext(), affineMap); }, - nb::arg("expr"), nb::arg("replacement"), nb::arg("n_result_dims"), - nb::arg("n_result_syms")) - .def_prop_ro( + py::arg("expr"), py::arg("replacement"), py::arg("n_result_dims"), + py::arg("n_result_syms")) + .def_property_readonly( "is_permutation", [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) - .def_prop_ro("is_projected_permutation", - [](PyAffineMap &self) { - return mlirAffineMapIsProjectedPermutation(self); - }) - .def_prop_ro( + .def_property_readonly("is_projected_permutation", + [](PyAffineMap &self) { + return mlirAffineMapIsProjectedPermutation(self); + }) + .def_property_readonly( "n_dims", [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); }) - .def_prop_ro( + .def_property_readonly( "n_inputs", [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); }) - .def_prop_ro( + .def_property_readonly( "n_symbols", [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); }) - .def_prop_ro("results", - [](PyAffineMap &self) { return 
PyAffineMapExprList(self); }); + .def_property_readonly("results", [](PyAffineMap &self) { + return PyAffineMapExprList(self); + }); PyAffineMapExprList::bind(m); //---------------------------------------------------------------------------- // Mapping of PyIntegerSet. //---------------------------------------------------------------------------- - nb::class_(m, "IntegerSet") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyIntegerSet::getCapsule) + py::class_(m, "IntegerSet", py::module_local()) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyIntegerSet::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) .def("__eq__", [](PyIntegerSet &self, PyIntegerSet &other) { return self == other; }) - .def("__eq__", [](PyIntegerSet &self, nb::object other) { return false; }) + .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; }) .def("__str__", [](PyIntegerSet &self) { PyPrintAccumulator printAccum; @@ -864,7 +871,7 @@ void mlir::python::populateIRAffine(nb::module_ &m) { [](PyIntegerSet &self) { return static_cast(llvm::hash_value(self.get().ptr)); }) - .def_prop_ro( + .def_property_readonly( "context", [](PyIntegerSet &self) { return self.getContext().getObject(); }) .def( @@ -872,14 +879,14 @@ void mlir::python::populateIRAffine(nb::module_ &m) { kDumpDocstring) .def_static( "get", - [](intptr_t numDims, intptr_t numSymbols, nb::list exprs, + [](intptr_t numDims, intptr_t numSymbols, py::list exprs, std::vector eqFlags, DefaultingPyMlirContext context) { if (exprs.size() != eqFlags.size()) - throw nb::value_error( + throw py::value_error( "Expected the number of constraints to match " "that of equality flags"); - if (exprs.size() == 0) - throw nb::value_error("Expected non-empty list of constraints"); + if (exprs.empty()) + throw py::value_error("Expected non-empty list of constraints"); // Copy over to a SmallVector because std::vector has a // specialization for booleans that packs data and does not @@ -894,8 
+901,8 @@ void mlir::python::populateIRAffine(nb::module_ &m) { affineExprs.data(), flags.data()); return PyIntegerSet(context->getRef(), set); }, - nb::arg("num_dims"), nb::arg("num_symbols"), nb::arg("exprs"), - nb::arg("eq_flags"), nb::arg("context").none() = nb::none()) + py::arg("num_dims"), py::arg("num_symbols"), py::arg("exprs"), + py::arg("eq_flags"), py::arg("context") = py::none()) .def_static( "get_empty", [](intptr_t numDims, intptr_t numSymbols, @@ -904,20 +911,20 @@ void mlir::python::populateIRAffine(nb::module_ &m) { mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); return PyIntegerSet(context->getRef(), set); }, - nb::arg("num_dims"), nb::arg("num_symbols"), - nb::arg("context").none() = nb::none()) + py::arg("num_dims"), py::arg("num_symbols"), + py::arg("context") = py::none()) .def( "get_replaced", - [](PyIntegerSet &self, nb::list dimExprs, nb::list symbolExprs, + [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, intptr_t numResultDims, intptr_t numResultSymbols) { if (static_cast(dimExprs.size()) != mlirIntegerSetGetNumDims(self)) - throw nb::value_error( + throw py::value_error( "Expected the number of dimension replacement expressions " "to match that of dimensions"); if (static_cast(symbolExprs.size()) != mlirIntegerSetGetNumSymbols(self)) - throw nb::value_error( + throw py::value_error( "Expected the number of symbol replacement expressions " "to match that of symbols"); @@ -933,30 +940,30 @@ void mlir::python::populateIRAffine(nb::module_ &m) { numResultDims, numResultSymbols); return PyIntegerSet(self.getContext(), set); }, - nb::arg("dim_exprs"), nb::arg("symbol_exprs"), - nb::arg("num_result_dims"), nb::arg("num_result_symbols")) - .def_prop_ro("is_canonical_empty", - [](PyIntegerSet &self) { - return mlirIntegerSetIsCanonicalEmpty(self); - }) - .def_prop_ro( + py::arg("dim_exprs"), py::arg("symbol_exprs"), + py::arg("num_result_dims"), py::arg("num_result_symbols")) + 
.def_property_readonly("is_canonical_empty", + [](PyIntegerSet &self) { + return mlirIntegerSetIsCanonicalEmpty(self); + }) + .def_property_readonly( "n_dims", [](PyIntegerSet &self) { return mlirIntegerSetGetNumDims(self); }) - .def_prop_ro( + .def_property_readonly( "n_symbols", [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) - .def_prop_ro( + .def_property_readonly( "n_inputs", [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) - .def_prop_ro("n_equalities", - [](PyIntegerSet &self) { - return mlirIntegerSetGetNumEqualities(self); - }) - .def_prop_ro("n_inequalities", - [](PyIntegerSet &self) { - return mlirIntegerSetGetNumInequalities(self); - }) - .def_prop_ro("constraints", [](PyIntegerSet &self) { + .def_property_readonly("n_equalities", + [](PyIntegerSet &self) { + return mlirIntegerSetGetNumEqualities(self); + }) + .def_property_readonly("n_inequalities", + [](PyIntegerSet &self) { + return mlirIntegerSetGetNumInequalities(self); + }) + .def_property_readonly("constraints", [](PyIntegerSet &self) { return PyIntegerSetConstraintList(self); }); PyIntegerSetConstraint::bind(m); diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index f9656eb23..cc9532f4e 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -6,29 +6,23 @@ // //===----------------------------------------------------------------------===// -#include -#include -#include -#include -#include -#include - -#include #include -#include #include #include #include "IRModule.h" -#include "NanobindUtils.h" -#include "mlir-c/BuiltinAttributes.h" -#include "mlir-c/BuiltinTypes.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" + +#include "PybindUtils.h" +#include + #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/raw_ostream.h" -namespace nb = nanobind; -using namespace nanobind::literals; +#include "mlir-c/BuiltinAttributes.h" +#include 
"mlir-c/BuiltinTypes.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" + +namespace py = pybind11; using namespace mlir; using namespace mlir::python; @@ -129,119 +123,10 @@ subsequent processing. namespace { -struct nb_buffer_info { - void *ptr = nullptr; - ssize_t itemsize = 0; - ssize_t size = 0; - const char *format = nullptr; - ssize_t ndim = 0; - SmallVector shape; - SmallVector strides; - bool readonly = false; - - nb_buffer_info( - void *ptr, ssize_t itemsize, const char *format, ssize_t ndim, - SmallVector shape_in, SmallVector strides_in, - bool readonly = false, - std::unique_ptr owned_view_in = - std::unique_ptr(nullptr, nullptr)) - : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim), - shape(std::move(shape_in)), strides(std::move(strides_in)), - readonly(readonly), owned_view(std::move(owned_view_in)) { - size = 1; - for (ssize_t i = 0; i < ndim; ++i) { - size *= shape[i]; - } - } - - explicit nb_buffer_info(Py_buffer *view) - : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim, - {view->shape, view->shape + view->ndim}, - // TODO(phawkins): check for null strides - {view->strides, view->strides + view->ndim}, - view->readonly != 0, - std::unique_ptr( - view, PyBuffer_Release)) {} - - nb_buffer_info(const nb_buffer_info &) = delete; - nb_buffer_info(nb_buffer_info &&) = default; - nb_buffer_info &operator=(const nb_buffer_info &) = delete; - nb_buffer_info &operator=(nb_buffer_info &&) = default; - -private: - std::unique_ptr owned_view; -}; - -class nb_buffer : public nb::object { - NB_OBJECT_DEFAULT(nb_buffer, object, "buffer", PyObject_CheckBuffer); - - nb_buffer_info request() const { - int flags = PyBUF_STRIDES | PyBUF_FORMAT; - auto *view = new Py_buffer(); - if (PyObject_GetBuffer(ptr(), view, flags) != 0) { - delete view; - throw nb::python_error(); - } - return nb_buffer_info(view); - } -}; - -template -struct nb_format_descriptor {}; - -template <> -struct nb_format_descriptor { - static const char *format() { 
return "?"; } -}; -template <> -struct nb_format_descriptor { - static const char *format() { return "b"; } -}; -template <> -struct nb_format_descriptor { - static const char *format() { return "B"; } -}; -template <> -struct nb_format_descriptor { - static const char *format() { return "h"; } -}; -template <> -struct nb_format_descriptor { - static const char *format() { return "H"; } -}; -template <> -struct nb_format_descriptor { - static const char *format() { return "i"; } -}; -template <> -struct nb_format_descriptor { - static const char *format() { return "I"; } -}; -template <> -struct nb_format_descriptor { - static const char *format() { return "q"; } -}; -template <> -struct nb_format_descriptor { - static const char *format() { return "Q"; } -}; -template <> -struct nb_format_descriptor { - static const char *format() { return "f"; } -}; -template <> -struct nb_format_descriptor { - static const char *format() { return "d"; } -}; - static MlirStringRef toMlirStringRef(const std::string &s) { return mlirStringRefCreate(s.data(), s.size()); } -static MlirStringRef toMlirStringRef(const nb::bytes &s) { - return mlirStringRefCreate(static_cast(s.data()), s.size()); -} - class PyAffineMapAttribute : public PyConcreteAttribute { public: static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; @@ -257,9 +142,9 @@ class PyAffineMapAttribute : public PyConcreteAttribute { MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); return PyAffineMapAttribute(affineMap.getContext(), attr); }, - nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); - c.def_prop_ro("value", mlirAffineMapAttrGetValue, - "Returns the value of the AffineMap attribute"); + py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); + c.def_property_readonly("value", mlirAffineMapAttrGetValue, + "Returns the value of the AffineMap attribute"); } }; @@ -279,24 +164,25 @@ class PyIntegerSetAttribute MlirAttribute attr = 
mlirIntegerSetAttrGet(integerSet.get()); return PyIntegerSetAttribute(integerSet.getContext(), attr); }, - nb::arg("integer_set"), "Gets an attribute wrapping an IntegerSet."); + py::arg("integer_set"), "Gets an attribute wrapping an IntegerSet."); } }; template -static T pyTryCast(nb::handle object) { +static T pyTryCast(py::handle object) { try { - return nb::cast(object); - } catch (nb::cast_error &err) { - std::string msg = std::string("Invalid attribute when attempting to " - "create an ArrayAttribute (") + - err.what() + ")"; - throw std::runtime_error(msg.c_str()); - } catch (std::runtime_error &err) { + return object.cast(); + } catch (py::cast_error &err) { + std::string msg = + std::string( + "Invalid attribute when attempting to create an ArrayAttribute (") + + err.what() + ")"; + throw py::cast_error(msg); + } catch (py::reference_cast_error &err) { std::string msg = std::string("Invalid attribute (None?) when attempting " "to create an ArrayAttribute (") + err.what() + ")"; - throw std::runtime_error(msg.c_str()); + throw py::cast_error(msg); } } @@ -319,13 +205,14 @@ class PyDenseArrayAttribute : public PyConcreteAttribute { EltTy dunderNext() { // Throw if the index has reached the end. if (nextIndex >= mlirDenseArrayGetNumElements(attr.get())) - throw nb::stop_iteration(); + throw py::stop_iteration(); return DerivedT::getElement(attr.get(), nextIndex++); } /// Bind the iterator class. - static void bind(nb::module_ &m) { - nb::class_(m, DerivedT::pyIteratorName) + static void bind(py::module &m) { + py::class_(m, DerivedT::pyIteratorName, + py::module_local()) .def("__iter__", &PyDenseArrayIterator::dunderIter) .def("__next__", &PyDenseArrayIterator::dunderNext); } @@ -343,35 +230,17 @@ class PyDenseArrayAttribute : public PyConcreteAttribute { /// Bind the attribute class. static void bindDerived(typename PyConcreteAttribute::ClassTy &c) { // Bind the constructor. 
- if constexpr (std::is_same_v) { - c.def_static( - "get", - [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) { - std::vector values; - for (nb::handle py_value : py_values) { - int is_true = PyObject_IsTrue(py_value.ptr()); - if (is_true < 0) { - throw nb::python_error(); - } - values.push_back(is_true); - } - return getAttribute(values, ctx->getRef()); - }, - nb::arg("values"), nb::arg("context").none() = nb::none(), - "Gets a uniqued dense array attribute"); - } else { - c.def_static( - "get", - [](const std::vector &values, DefaultingPyMlirContext ctx) { - return getAttribute(values, ctx->getRef()); - }, - nb::arg("values"), nb::arg("context").none() = nb::none(), - "Gets a uniqued dense array attribute"); - } + c.def_static( + "get", + [](const std::vector &values, DefaultingPyMlirContext ctx) { + return getAttribute(values, ctx->getRef()); + }, + py::arg("values"), py::arg("context") = py::none(), + "Gets a uniqued dense array attribute"); // Bind the array methods. c.def("__getitem__", [](DerivedT &arr, intptr_t i) { if (i >= mlirDenseArrayGetNumElements(arr)) - throw nb::index_error("DenseArray index out of range"); + throw py::index_error("DenseArray index out of range"); return arr.getItem(i); }); c.def("__len__", [](const DerivedT &arr) { @@ -379,13 +248,13 @@ class PyDenseArrayAttribute : public PyConcreteAttribute { }); c.def("__iter__", [](const DerivedT &arr) { return PyDenseArrayIterator(arr); }); - c.def("__add__", [](DerivedT &arr, const nb::list &extras) { + c.def("__add__", [](DerivedT &arr, const py::list &extras) { std::vector values; intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); - values.reserve(numOldElements + nb::len(extras)); + values.reserve(numOldElements + py::len(extras)); for (intptr_t i = 0; i < numOldElements; ++i) values.push_back(arr.getItem(i)); - for (nb::handle attr : extras) + for (py::handle attr : extras) values.push_back(pyTryCast(attr)); return getAttribute(values, arr.getContext()); }); @@ 
-489,12 +358,13 @@ class PyArrayAttribute : public PyConcreteAttribute { MlirAttribute dunderNext() { // TODO: Throw is an inefficient way to stop iteration. if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) - throw nb::stop_iteration(); + throw py::stop_iteration(); return mlirArrayAttrGetElement(attr.get(), nextIndex++); } - static void bind(nb::module_ &m) { - nb::class_(m, "ArrayAttributeIterator") + static void bind(py::module &m) { + py::class_(m, "ArrayAttributeIterator", + py::module_local()) .def("__iter__", &PyArrayAttributeIterator::dunderIter) .def("__next__", &PyArrayAttributeIterator::dunderNext); } @@ -511,9 +381,9 @@ class PyArrayAttribute : public PyConcreteAttribute { static void bindDerived(ClassTy &c) { c.def_static( "get", - [](nb::list attributes, DefaultingPyMlirContext context) { + [](py::list attributes, DefaultingPyMlirContext context) { SmallVector mlirAttributes; - mlirAttributes.reserve(nb::len(attributes)); + mlirAttributes.reserve(py::len(attributes)); for (auto attribute : attributes) { mlirAttributes.push_back(pyTryCast(attribute)); } @@ -521,12 +391,12 @@ class PyArrayAttribute : public PyConcreteAttribute { context->get(), mlirAttributes.size(), mlirAttributes.data()); return PyArrayAttribute(context->getRef(), attr); }, - nb::arg("attributes"), nb::arg("context").none() = nb::none(), + py::arg("attributes"), py::arg("context") = py::none(), "Gets a uniqued Array attribute"); c.def("__getitem__", [](PyArrayAttribute &arr, intptr_t i) { if (i >= mlirArrayAttrGetNumElements(arr)) - throw nb::index_error("ArrayAttribute index out of range"); + throw py::index_error("ArrayAttribute index out of range"); return arr.getItem(i); }) .def("__len__", @@ -536,13 +406,13 @@ class PyArrayAttribute : public PyConcreteAttribute { .def("__iter__", [](const PyArrayAttribute &arr) { return PyArrayAttributeIterator(arr); }); - c.def("__add__", [](PyArrayAttribute arr, nb::list extras) { + c.def("__add__", [](PyArrayAttribute arr, py::list 
extras) { std::vector attributes; intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); - attributes.reserve(numOldElements + nb::len(extras)); + attributes.reserve(numOldElements + py::len(extras)); for (intptr_t i = 0; i < numOldElements; ++i) attributes.push_back(arr.getItem(i)); - for (nb::handle attr : extras) + for (py::handle attr : extras) attributes.push_back(pyTryCast(attr)); MlirAttribute arrayAttr = mlirArrayAttrGet( arr.getContext()->get(), attributes.size(), attributes.data()); @@ -570,7 +440,7 @@ class PyFloatAttribute : public PyConcreteAttribute { throw MLIRError("Invalid attribute", errors.take()); return PyFloatAttribute(type.getContext(), attr); }, - nb::arg("type"), nb::arg("value"), nb::arg("loc").none() = nb::none(), + py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), "Gets an uniqued float point attribute associated to a type"); c.def_static( "get_f32", @@ -579,7 +449,7 @@ class PyFloatAttribute : public PyConcreteAttribute { context->get(), mlirF32TypeGet(context->get()), value); return PyFloatAttribute(context->getRef(), attr); }, - nb::arg("value"), nb::arg("context").none() = nb::none(), + py::arg("value"), py::arg("context") = py::none(), "Gets an uniqued float point attribute associated to a f32 type"); c.def_static( "get_f64", @@ -588,10 +458,10 @@ class PyFloatAttribute : public PyConcreteAttribute { context->get(), mlirF64TypeGet(context->get()), value); return PyFloatAttribute(context->getRef(), attr); }, - nb::arg("value"), nb::arg("context").none() = nb::none(), + py::arg("value"), py::arg("context") = py::none(), "Gets an uniqued float point attribute associated to a f64 type"); - c.def_prop_ro("value", mlirFloatAttrGetValueDouble, - "Returns the value of the float attribute"); + c.def_property_readonly("value", mlirFloatAttrGetValueDouble, + "Returns the value of the float attribute"); c.def("__float__", mlirFloatAttrGetValueDouble, "Converts the value of the float attribute to a Python float"); } @@ -611,20 
+481,20 @@ class PyIntegerAttribute : public PyConcreteAttribute { MlirAttribute attr = mlirIntegerAttrGet(type, value); return PyIntegerAttribute(type.getContext(), attr); }, - nb::arg("type"), nb::arg("value"), + py::arg("type"), py::arg("value"), "Gets an uniqued integer attribute associated to a type"); - c.def_prop_ro("value", toPyInt, - "Returns the value of the integer attribute"); + c.def_property_readonly("value", toPyInt, + "Returns the value of the integer attribute"); c.def("__int__", toPyInt, "Converts the value of the integer attribute to a Python int"); - c.def_prop_ro_static("static_typeid", - [](nb::object & /*class*/) -> MlirTypeID { - return mlirIntegerAttrGetTypeID(); - }); + c.def_property_readonly_static("static_typeid", + [](py::object & /*class*/) -> MlirTypeID { + return mlirIntegerAttrGetTypeID(); + }); } private: - static int64_t toPyInt(PyIntegerAttribute &self) { + static py::int_ toPyInt(PyIntegerAttribute &self) { MlirType type = mlirAttributeGetType(self); if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) return mlirIntegerAttrGetValueInt(self); @@ -648,10 +518,10 @@ class PyBoolAttribute : public PyConcreteAttribute { MlirAttribute attr = mlirBoolAttrGet(context->get(), value); return PyBoolAttribute(context->getRef(), attr); }, - nb::arg("value"), nb::arg("context").none() = nb::none(), + py::arg("value"), py::arg("context") = py::none(), "Gets an uniqued bool attribute"); - c.def_prop_ro("value", mlirBoolAttrGetValue, - "Returns the value of the bool attribute"); + c.def_property_readonly("value", mlirBoolAttrGetValue, + "Returns the value of the bool attribute"); c.def("__bool__", mlirBoolAttrGetValue, "Converts the value of the bool attribute to a Python bool"); } @@ -685,9 +555,9 @@ class PySymbolRefAttribute : public PyConcreteAttribute { DefaultingPyMlirContext context) { return PySymbolRefAttribute::fromList(symbols, context.resolve()); }, - nb::arg("symbols"), nb::arg("context").none() = nb::none(), + 
py::arg("symbols"), py::arg("context") = py::none(), "Gets a uniqued SymbolRef attribute from a list of symbol names"); - c.def_prop_ro( + c.def_property_readonly( "value", [](PySymbolRefAttribute &self) { std::vector symbols = { @@ -719,13 +589,13 @@ class PyFlatSymbolRefAttribute mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); return PyFlatSymbolRefAttribute(context->getRef(), attr); }, - nb::arg("value"), nb::arg("context").none() = nb::none(), + py::arg("value"), py::arg("context") = py::none(), "Gets a uniqued FlatSymbolRef attribute"); - c.def_prop_ro( + c.def_property_readonly( "value", [](PyFlatSymbolRefAttribute &self) { MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); - return nb::str(stringRef.data, stringRef.length); + return py::str(stringRef.data, stringRef.length); }, "Returns the value of the FlatSymbolRef attribute as a string"); } @@ -742,29 +612,29 @@ class PyOpaqueAttribute : public PyConcreteAttribute { static void bindDerived(ClassTy &c) { c.def_static( "get", - [](std::string dialectNamespace, nb_buffer buffer, PyType &type, + [](std::string dialectNamespace, py::buffer buffer, PyType &type, DefaultingPyMlirContext context) { - const nb_buffer_info bufferInfo = buffer.request(); + const py::buffer_info bufferInfo = buffer.request(); intptr_t bufferSize = bufferInfo.size; MlirAttribute attr = mlirOpaqueAttrGet( context->get(), toMlirStringRef(dialectNamespace), bufferSize, static_cast(bufferInfo.ptr), type); return PyOpaqueAttribute(context->getRef(), attr); }, - nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"), - nb::arg("context").none() = nb::none(), "Gets an Opaque attribute."); - c.def_prop_ro( + py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"), + py::arg("context") = py::none(), "Gets an Opaque attribute."); + c.def_property_readonly( "dialect_namespace", [](PyOpaqueAttribute &self) { MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); - return 
nb::str(stringRef.data, stringRef.length); + return py::str(stringRef.data, stringRef.length); }, "Returns the dialect namespace for the Opaque attribute as a string"); - c.def_prop_ro( + c.def_property_readonly( "data", [](PyOpaqueAttribute &self) { MlirStringRef stringRef = mlirOpaqueAttrGetData(self); - return nb::bytes(stringRef.data, stringRef.length); + return py::bytes(stringRef.data, stringRef.length); }, "Returns the data for the Opaqued attributes as `bytes`"); } @@ -786,16 +656,7 @@ class PyStringAttribute : public PyConcreteAttribute { mlirStringAttrGet(context->get(), toMlirStringRef(value)); return PyStringAttribute(context->getRef(), attr); }, - nb::arg("value"), nb::arg("context").none() = nb::none(), - "Gets a uniqued string attribute"); - c.def_static( - "get", - [](nb::bytes value, DefaultingPyMlirContext context) { - MlirAttribute attr = - mlirStringAttrGet(context->get(), toMlirStringRef(value)); - return PyStringAttribute(context->getRef(), attr); - }, - nb::arg("value"), nb::arg("context").none() = nb::none(), + py::arg("value"), py::arg("context") = py::none(), "Gets a uniqued string attribute"); c.def_static( "get_typed", @@ -804,20 +665,20 @@ class PyStringAttribute : public PyConcreteAttribute { mlirStringAttrTypedGet(type, toMlirStringRef(value)); return PyStringAttribute(type.getContext(), attr); }, - nb::arg("type"), nb::arg("value"), + py::arg("type"), py::arg("value"), "Gets a uniqued string attribute associated to a type"); - c.def_prop_ro( + c.def_property_readonly( "value", [](PyStringAttribute &self) { MlirStringRef stringRef = mlirStringAttrGetValue(self); - return nb::str(stringRef.data, stringRef.length); + return py::str(stringRef.data, stringRef.length); }, "Returns the value of the string attribute"); - c.def_prop_ro( + c.def_property_readonly( "value_bytes", [](PyStringAttribute &self) { MlirStringRef stringRef = mlirStringAttrGetValue(self); - return nb::bytes(stringRef.data, stringRef.length); + return 
py::bytes(stringRef.data, stringRef.length); }, "Returns the value of the string attribute as `bytes`"); } @@ -832,11 +693,12 @@ class PyDenseElementsAttribute using PyConcreteAttribute::PyConcreteAttribute; static PyDenseElementsAttribute - getFromList(nb::list attributes, std::optional explicitType, + getFromList(py::list attributes, std::optional explicitType, DefaultingPyMlirContext contextWrapper) { - const size_t numAttributes = nb::len(attributes); + + const size_t numAttributes = py::len(attributes); if (numAttributes == 0) - throw nb::value_error("Attributes list must be non-empty."); + throw py::value_error("Attributes list must be non-empty."); MlirType shapedType; if (explicitType) { @@ -846,8 +708,8 @@ class PyDenseElementsAttribute std::string message; llvm::raw_string_ostream os(message); os << "Expected a static ShapedType for the shaped_type parameter: " - << nb::cast(nb::repr(nb::cast(*explicitType))); - throw nb::value_error(message.c_str()); + << py::repr(py::cast(*explicitType)); + throw py::value_error(message); } shapedType = *explicitType; } else { @@ -860,7 +722,7 @@ class PyDenseElementsAttribute SmallVector mlirAttributes; mlirAttributes.reserve(numAttributes); - for (const nb::handle &attribute : attributes) { + for (const py::handle &attribute : attributes) { MlirAttribute mlirAttribute = pyTryCast(attribute); MlirType attrType = mlirAttributeGetType(mlirAttribute); mlirAttributes.push_back(mlirAttribute); @@ -869,11 +731,9 @@ class PyDenseElementsAttribute std::string message; llvm::raw_string_ostream os(message); os << "All attributes must be of the same type and match " - << "the type parameter: expected=" - << nb::cast(nb::repr(nb::cast(shapedType))) - << ", but got=" - << nb::cast(nb::repr(nb::cast(attrType))); - throw nb::value_error(message.c_str()); + << "the type parameter: expected=" << py::repr(py::cast(shapedType)) + << ", but got=" << py::repr(py::cast(attrType)); + throw py::value_error(message); } } @@ -884,7 +744,7 @@ 
class PyDenseElementsAttribute } static PyDenseElementsAttribute - getFromBuffer(nb_buffer array, bool signless, + getFromBuffer(py::buffer array, bool signless, std::optional explicitType, std::optional> explicitShape, DefaultingPyMlirContext contextWrapper) { @@ -895,7 +755,7 @@ class PyDenseElementsAttribute } Py_buffer view; if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) { - throw nb::python_error(); + throw py::error_already_set(); } auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); @@ -918,25 +778,25 @@ class PyDenseElementsAttribute if (!mlirAttributeIsAInteger(elementAttr) && !mlirAttributeIsAFloat(elementAttr)) { std::string message = "Illegal element type for DenseElementsAttr: "; - message.append(nb::cast(nb::repr(nb::cast(elementAttr)))); - throw nb::value_error(message.c_str()); + message.append(py::repr(py::cast(elementAttr))); + throw py::value_error(message); } if (!mlirTypeIsAShaped(shapedType) || !mlirShapedTypeHasStaticShape(shapedType)) { std::string message = "Expected a static ShapedType for the shaped_type parameter: "; - message.append(nb::cast(nb::repr(nb::cast(shapedType)))); - throw nb::value_error(message.c_str()); + message.append(py::repr(py::cast(shapedType))); + throw py::value_error(message); } MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); MlirType attrType = mlirAttributeGetType(elementAttr); if (!mlirTypeEqual(shapedElementType, attrType)) { std::string message = "Shaped element type and attribute type must be equal: shaped="; - message.append(nb::cast(nb::repr(nb::cast(shapedType)))); + message.append(py::repr(py::cast(shapedType))); message.append(", element="); - message.append(nb::cast(nb::repr(nb::cast(elementAttr)))); - throw nb::value_error(message.c_str()); + message.append(py::repr(py::cast(elementAttr))); + throw py::value_error(message); } MlirAttribute elements = @@ -946,7 +806,7 @@ class PyDenseElementsAttribute intptr_t dunderLen() { return 
mlirElementsAttrGetNumElements(*this); } - std::unique_ptr accessBuffer() { + py::buffer_info accessBuffer() { MlirType shapedType = mlirAttributeGetType(*this); MlirType elementType = mlirShapedTypeGetElementType(shapedType); std::string format; @@ -1029,36 +889,32 @@ class PyDenseElementsAttribute static void bindDerived(ClassTy &c) { c.def("__len__", &PyDenseElementsAttribute::dunderLen) .def_static("get", PyDenseElementsAttribute::getFromBuffer, - nb::arg("array"), nb::arg("signless") = true, - nb::arg("type").none() = nb::none(), - nb::arg("shape").none() = nb::none(), - nb::arg("context").none() = nb::none(), + py::arg("array"), py::arg("signless") = true, + py::arg("type") = py::none(), py::arg("shape") = py::none(), + py::arg("context") = py::none(), kDenseElementsAttrGetDocstring) .def_static("get", PyDenseElementsAttribute::getFromList, - nb::arg("attrs"), nb::arg("type").none() = nb::none(), - nb::arg("context").none() = nb::none(), + py::arg("attrs"), py::arg("type") = py::none(), + py::arg("context") = py::none(), kDenseElementsAttrGetFromListDocstring) .def_static("get_splat", PyDenseElementsAttribute::getSplat, - nb::arg("shaped_type"), nb::arg("element_attr"), + py::arg("shaped_type"), py::arg("element_attr"), "Gets a DenseElementsAttr where all values are the same") - .def_prop_ro("is_splat", - [](PyDenseElementsAttribute &self) -> bool { - return mlirDenseElementsAttrIsSplat(self); - }) - .def("get_splat_value", [](PyDenseElementsAttribute &self) { - if (!mlirDenseElementsAttrIsSplat(self)) - throw nb::value_error( - "get_splat_value called on a non-splat attribute"); - return mlirDenseElementsAttrGetSplatValue(self); - }); + .def_property_readonly("is_splat", + [](PyDenseElementsAttribute &self) -> bool { + return mlirDenseElementsAttrIsSplat(self); + }) + .def("get_splat_value", + [](PyDenseElementsAttribute &self) { + if (!mlirDenseElementsAttrIsSplat(self)) + throw py::value_error( + "get_splat_value called on a non-splat attribute"); + return 
mlirDenseElementsAttrGetSplatValue(self); + }) + .def_buffer(&PyDenseElementsAttribute::accessBuffer); } - static PyType_Slot slots[]; - private: - static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags); - static void bf_releasebuffer(PyObject *, Py_buffer *buffer); - static bool isUnsignedIntegerFormat(std::string_view format) { if (format.empty()) return false; @@ -1183,27 +1039,27 @@ class PyDenseElementsAttribute return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf); } - // There is a complication for boolean numpy arrays, as numpy represents - // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 - // booleans per byte. + // There is a complication for boolean numpy arrays, as numpy represents them + // as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans + // per byte. static MlirAttribute getBitpackedAttributeFromBooleanBuffer( Py_buffer &view, std::optional> explicitShape, MlirContext &context) { if (llvm::endianness::native != llvm::endianness::little) { - // Given we have no good way of testing the behavior on big-endian - // systems we will throw - throw nb::type_error("Constructing a bit-packed MLIR attribute is " + // Given we have no good way of testing the behavior on big-endian systems + // we will throw + throw py::type_error("Constructing a bit-packed MLIR attribute is " "unsupported on big-endian systems"); } - nb::ndarray, nb::c_contig> unpackedArray( - /*data=*/static_cast(view.buf), - /*shape=*/{static_cast(view.len)}); - nb::module_ numpy = nb::module_::import_("numpy"); - nb::object packbitsFunc = numpy.attr("packbits"); - nb::object packedBooleans = - packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little"); - nb_buffer_info pythonBuffer = nb::cast(packedBooleans).request(); + py::array_t unpackedArray(view.len, + static_cast(view.buf)); + + py::module numpy = py::module::import("numpy"); + py::object packbitsFunc = numpy.attr("packbits"); + py::object 
packedBooleans = + packbitsFunc(unpackedArray, "bitorder"_a = "little"); + py::buffer_info pythonBuffer = packedBooleans.cast().request(); MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view); @@ -1217,11 +1073,11 @@ class PyDenseElementsAttribute // This does the opposite transformation of // `getBitpackedAttributeFromBooleanBuffer` - std::unique_ptr getBooleanBufferFromBitpackedAttribute() { + py::buffer_info getBooleanBufferFromBitpackedAttribute() { if (llvm::endianness::native != llvm::endianness::little) { - // Given we have no good way of testing the behavior on big-endian - // systems we will throw - throw nb::type_error("Constructing a numpy array from a MLIR attribute " + // Given we have no good way of testing the behavior on big-endian systems + // we will throw + throw py::type_error("Constructing a numpy array from a MLIR attribute " "is unsupported on big-endian systems"); } @@ -1229,24 +1085,21 @@ class PyDenseElementsAttribute int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8); uint8_t *bitpackedData = static_cast( const_cast(mlirDenseElementsAttrGetRawData(*this))); - nb::ndarray, nb::c_contig> packedArray( - /*data=*/bitpackedData, - /*shape=*/{static_cast(numBitpackedBytes)}); + py::array_t packedArray(numBitpackedBytes, bitpackedData); - nb::module_ numpy = nb::module_::import_("numpy"); - nb::object unpackbitsFunc = numpy.attr("unpackbits"); - nb::object equalFunc = numpy.attr("equal"); - nb::object reshapeFunc = numpy.attr("reshape"); - nb::object unpackedBooleans = - unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little"); + py::module numpy = py::module::import("numpy"); + py::object unpackbitsFunc = numpy.attr("unpackbits"); + py::object equalFunc = numpy.attr("equal"); + py::object reshapeFunc = numpy.attr("reshape"); + py::array unpackedBooleans = + unpackbitsFunc(packedArray, "bitorder"_a = "little"); // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array. 
// We need to: // 1. Slice away the padded bits // 2. Make the boolean array have the correct shape // 3. Convert the array to a boolean array - unpackedBooleans = unpackedBooleans[nb::slice( - nb::int_(0), nb::int_(numBooleans), nb::int_(1))]; + unpackedBooleans = unpackedBooleans[py::slice(0, numBooleans, 1)]; unpackedBooleans = equalFunc(unpackedBooleans, 1); MlirType shapedType = mlirAttributeGetType(*this); @@ -1257,15 +1110,15 @@ class PyDenseElementsAttribute } unpackedBooleans = reshapeFunc(unpackedBooleans, shape); - // Make sure the returned nb::buffer_view claims ownership of the data in + // Make sure the returned py::buffer_view claims ownership of the data in // `pythonBuffer` so it remains valid when Python reads it - nb_buffer pythonBuffer = nb::cast(unpackedBooleans); - return std::make_unique(pythonBuffer.request()); + py::buffer pythonBuffer = unpackedBooleans.cast(); + return pythonBuffer.request(); } template - std::unique_ptr - bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) { + py::buffer_info bufferInfo(MlirType shapedType, + const char *explicitFormat = nullptr) { intptr_t rank = mlirShapedTypeGetRank(shapedType); // Prepare the data for the buffer_info. // Buffer is configured for read-only access below. 
@@ -1289,69 +1142,19 @@ class PyDenseElementsAttribute } strides.push_back(sizeof(Type)); } - const char *format; + std::string format; if (explicitFormat) { format = explicitFormat; } else { - format = nb_format_descriptor::format(); + format = py::format_descriptor::format(); } - return std::make_unique( - data, sizeof(Type), format, rank, std::move(shape), std::move(strides), - /*readonly=*/true); + return py::buffer_info(data, sizeof(Type), format, rank, shape, strides, + /*readonly=*/true); } }; // namespace -PyType_Slot PyDenseElementsAttribute::slots[] = { - {Py_bf_getbuffer, - reinterpret_cast(PyDenseElementsAttribute::bf_getbuffer)}, - {Py_bf_releasebuffer, - reinterpret_cast(PyDenseElementsAttribute::bf_releasebuffer)}, - {0, nullptr}, -}; - -/*static*/ int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj, - Py_buffer *view, - int flags) { - view->obj = nullptr; - std::unique_ptr info; - try { - auto *attr = nb::cast(nb::handle(obj)); - info = attr->accessBuffer(); - } catch (nb::python_error &e) { - e.restore(); - nb::chain_error(PyExc_BufferError, "Error converting attribute to buffer"); - return -1; - } - view->obj = obj; - view->ndim = 1; - view->buf = info->ptr; - view->itemsize = info->itemsize; - view->len = info->itemsize; - for (auto s : info->shape) { - view->len *= s; - } - view->readonly = info->readonly; - if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { - view->format = const_cast(info->format); - } - if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { - view->ndim = static_cast(info->ndim); - view->strides = info->strides.data(); - view->shape = info->shape.data(); - } - view->suboffsets = nullptr; - view->internal = info.release(); - Py_INCREF(obj); - return 0; -} - -/*static*/ void PyDenseElementsAttribute::bf_releasebuffer(PyObject *, - Py_buffer *view) { - delete reinterpret_cast(view->internal); -} - -/// Refinement of the PyDenseElementsAttribute for attributes containing -/// integer (and boolean) values. Supports element access. 
+/// Refinement of the PyDenseElementsAttribute for attributes containing integer +/// (and boolean) values. Supports element access. class PyDenseIntElementsAttribute : public PyConcreteAttribute { @@ -1360,11 +1163,11 @@ class PyDenseIntElementsAttribute static constexpr const char *pyClassName = "DenseIntElementsAttr"; using PyConcreteAttribute::PyConcreteAttribute; - /// Returns the element at the given linear position. Asserts if the index - /// is out of range. - nb::object dunderGetItem(intptr_t pos) { + /// Returns the element at the given linear position. Asserts if the index is + /// out of range. + py::int_ dunderGetItem(intptr_t pos) { if (pos < 0 || pos >= dunderLen()) { - throw nb::index_error("attempt to access out of bounds element"); + throw py::index_error("attempt to access out of bounds element"); } MlirType type = mlirAttributeGetType(*this); @@ -1372,7 +1175,7 @@ class PyDenseIntElementsAttribute assert(mlirTypeIsAInteger(type) && "expected integer element type in dense int elements attribute"); // Dispatch element extraction to an appropriate C function based on the - // elemental type of the attribute. nb::int_ is implicitly constructible + // elemental type of the attribute. py::int_ is implicitly constructible // from any C++ integral type and handles bitwidth correctly. // TODO: consider caching the type properties in the constructor to avoid // querying them on each element access. 
@@ -1380,38 +1183,38 @@ class PyDenseIntElementsAttribute bool isUnsigned = mlirIntegerTypeIsUnsigned(type); if (isUnsigned) { if (width == 1) { - return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); + return mlirDenseElementsAttrGetBoolValue(*this, pos); } if (width == 8) { - return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos)); + return mlirDenseElementsAttrGetUInt8Value(*this, pos); } if (width == 16) { - return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos)); + return mlirDenseElementsAttrGetUInt16Value(*this, pos); } if (width == 32) { - return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos)); + return mlirDenseElementsAttrGetUInt32Value(*this, pos); } if (width == 64) { - return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos)); + return mlirDenseElementsAttrGetUInt64Value(*this, pos); } } else { if (width == 1) { - return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); + return mlirDenseElementsAttrGetBoolValue(*this, pos); } if (width == 8) { - return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos)); + return mlirDenseElementsAttrGetInt8Value(*this, pos); } if (width == 16) { - return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos)); + return mlirDenseElementsAttrGetInt16Value(*this, pos); } if (width == 32) { - return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos)); + return mlirDenseElementsAttrGetInt32Value(*this, pos); } if (width == 64) { - return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos)); + return mlirDenseElementsAttrGetInt64Value(*this, pos); } } - throw nb::type_error("Unsupported integer type"); + throw py::type_error("Unsupported integer type"); } static void bindDerived(ClassTy &c) { @@ -1428,7 +1231,7 @@ class PyDenseResourceElementsAttribute using PyConcreteAttribute::PyConcreteAttribute; static PyDenseResourceElementsAttribute - getFromBuffer(nb_buffer buffer, const std::string &name, const PyType &type, + getFromBuffer(py::buffer 
buffer, const std::string &name, const PyType &type, std::optional alignment, bool isMutable, DefaultingPyMlirContext contextWrapper) { if (!mlirTypeIsAShaped(type)) { @@ -1441,7 +1244,7 @@ class PyDenseResourceElementsAttribute int flags = PyBUF_STRIDES; std::unique_ptr view = std::make_unique(); if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) { - throw nb::python_error(); + throw py::error_already_set(); } // This scope releaser will only release if we haven't yet transferred @@ -1486,12 +1289,12 @@ class PyDenseResourceElementsAttribute } static void bindDerived(ClassTy &c) { - c.def_static( - "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer, - nb::arg("array"), nb::arg("name"), nb::arg("type"), - nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false, - nb::arg("context").none() = nb::none(), - kDenseResourceElementsAttrGetFromBufferDocstring); + c.def_static("get_from_buffer", + PyDenseResourceElementsAttribute::getFromBuffer, + py::arg("array"), py::arg("name"), py::arg("type"), + py::arg("alignment") = py::none(), + py::arg("is_mutable") = false, py::arg("context") = py::none(), + kDenseResourceElementsAttrGetFromBufferDocstring); } }; @@ -1515,12 +1318,12 @@ class PyDictAttribute : public PyConcreteAttribute { c.def("__len__", &PyDictAttribute::dunderLen); c.def_static( "get", - [](nb::dict attributes, DefaultingPyMlirContext context) { + [](py::dict attributes, DefaultingPyMlirContext context) { SmallVector mlirNamedAttributes; mlirNamedAttributes.reserve(attributes.size()); - for (std::pair it : attributes) { - auto &mlirAttr = nb::cast(it.second); - auto name = nb::cast(it.first); + for (auto &it : attributes) { + auto &mlirAttr = it.second.cast(); + auto name = it.first.cast(); mlirNamedAttributes.push_back(mlirNamedAttributeGet( mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), toMlirStringRef(name)), @@ -1531,18 +1334,18 @@ class PyDictAttribute : public PyConcreteAttribute { 
mlirNamedAttributes.data()); return PyDictAttribute(context->getRef(), attr); }, - nb::arg("value") = nb::dict(), nb::arg("context").none() = nb::none(), + py::arg("value") = py::dict(), py::arg("context") = py::none(), "Gets an uniqued dict attribute"); c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { MlirAttribute attr = mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); if (mlirAttributeIsNull(attr)) - throw nb::key_error("attempt to access a non-existent attribute"); + throw py::key_error("attempt to access a non-existent attribute"); return attr; }); c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { if (index < 0 || index >= self.dunderLen()) { - throw nb::index_error("attempt to access out of bounds attribute"); + throw py::index_error("attempt to access out of bounds attribute"); } MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); return PyNamedAttribute( @@ -1562,25 +1365,25 @@ class PyDenseFPElementsAttribute static constexpr const char *pyClassName = "DenseFPElementsAttr"; using PyConcreteAttribute::PyConcreteAttribute; - nb::float_ dunderGetItem(intptr_t pos) { + py::float_ dunderGetItem(intptr_t pos) { if (pos < 0 || pos >= dunderLen()) { - throw nb::index_error("attempt to access out of bounds element"); + throw py::index_error("attempt to access out of bounds element"); } MlirType type = mlirAttributeGetType(*this); type = mlirShapedTypeGetElementType(type); // Dispatch element extraction to an appropriate C function based on the - // elemental type of the attribute. nb::float_ is implicitly constructible + // elemental type of the attribute. py::float_ is implicitly constructible // from float and double. // TODO: consider caching the type properties in the constructor to avoid // querying them on each element access. 
if (mlirTypeIsAF32(type)) { - return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos)); + return mlirDenseElementsAttrGetFloatValue(*this, pos); } if (mlirTypeIsAF64(type)) { - return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos)); + return mlirDenseElementsAttrGetDoubleValue(*this, pos); } - throw nb::type_error("Unsupported floating-point type"); + throw py::type_error("Unsupported floating-point type"); } static void bindDerived(ClassTy &c) { @@ -1603,9 +1406,9 @@ class PyTypeAttribute : public PyConcreteAttribute { MlirAttribute attr = mlirTypeAttrGet(value.get()); return PyTypeAttribute(context->getRef(), attr); }, - nb::arg("value"), nb::arg("context").none() = nb::none(), + py::arg("value"), py::arg("context") = py::none(), "Gets a uniqued Type attribute"); - c.def_prop_ro("value", [](PyTypeAttribute &self) { + c.def_property_readonly("value", [](PyTypeAttribute &self) { return mlirTypeAttrGetValue(self.get()); }); } @@ -1627,7 +1430,7 @@ class PyUnitAttribute : public PyConcreteAttribute { return PyUnitAttribute(context->getRef(), mlirUnitAttrGet(context->get())); }, - nb::arg("context").none() = nb::none(), "Create a Unit attribute."); + py::arg("context") = py::none(), "Create a Unit attribute."); } }; @@ -1650,8 +1453,7 @@ class PyStridedLayoutAttribute ctx->get(), offset, strides.size(), strides.data()); return PyStridedLayoutAttribute(ctx->getRef(), attr); }, - nb::arg("offset"), nb::arg("strides"), - nb::arg("context").none() = nb::none(), + py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(), "Gets a strided layout attribute."); c.def_static( "get_fully_dynamic", @@ -1663,17 +1465,16 @@ class PyStridedLayoutAttribute ctx->get(), dynamic, strides.size(), strides.data()); return PyStridedLayoutAttribute(ctx->getRef(), attr); }, - nb::arg("rank"), nb::arg("context").none() = nb::none(), - "Gets a strided layout attribute with dynamic offset and strides of " - "a " + py::arg("rank"), py::arg("context") = 
py::none(), + "Gets a strided layout attribute with dynamic offset and strides of a " "given rank."); - c.def_prop_ro( + c.def_property_readonly( "offset", [](PyStridedLayoutAttribute &self) { return mlirStridedLayoutAttrGetOffset(self); }, "Returns the value of the float point attribute"); - c.def_prop_ro( + c.def_property_readonly( "strides", [](PyStridedLayoutAttribute &self) { intptr_t size = mlirStridedLayoutAttrGetNumStrides(self); @@ -1687,64 +1488,63 @@ class PyStridedLayoutAttribute } }; -nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { +py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute)) - return nb::cast(PyDenseBoolArrayAttribute(pyAttribute)); + return py::cast(PyDenseBoolArrayAttribute(pyAttribute)); if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute)) - return nb::cast(PyDenseI8ArrayAttribute(pyAttribute)); + return py::cast(PyDenseI8ArrayAttribute(pyAttribute)); if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute)) - return nb::cast(PyDenseI16ArrayAttribute(pyAttribute)); + return py::cast(PyDenseI16ArrayAttribute(pyAttribute)); if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute)) - return nb::cast(PyDenseI32ArrayAttribute(pyAttribute)); + return py::cast(PyDenseI32ArrayAttribute(pyAttribute)); if (PyDenseI64ArrayAttribute::isaFunction(pyAttribute)) - return nb::cast(PyDenseI64ArrayAttribute(pyAttribute)); + return py::cast(PyDenseI64ArrayAttribute(pyAttribute)); if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute)) - return nb::cast(PyDenseF32ArrayAttribute(pyAttribute)); + return py::cast(PyDenseF32ArrayAttribute(pyAttribute)); if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute)) - return nb::cast(PyDenseF64ArrayAttribute(pyAttribute)); + return py::cast(PyDenseF64ArrayAttribute(pyAttribute)); std::string msg = std::string("Can't cast unknown element type DenseArrayAttr (") + - nb::cast(nb::repr(nb::cast(pyAttribute))) + ")"; - throw 
nb::type_error(msg.c_str()); + std::string(py::repr(py::cast(pyAttribute))) + ")"; + throw py::cast_error(msg); } -nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { +py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { if (PyDenseFPElementsAttribute::isaFunction(pyAttribute)) - return nb::cast(PyDenseFPElementsAttribute(pyAttribute)); + return py::cast(PyDenseFPElementsAttribute(pyAttribute)); if (PyDenseIntElementsAttribute::isaFunction(pyAttribute)) - return nb::cast(PyDenseIntElementsAttribute(pyAttribute)); + return py::cast(PyDenseIntElementsAttribute(pyAttribute)); std::string msg = std::string( "Can't cast unknown element type DenseIntOrFPElementsAttr (") + - nb::cast(nb::repr(nb::cast(pyAttribute))) + ")"; - throw nb::type_error(msg.c_str()); + std::string(py::repr(py::cast(pyAttribute))) + ")"; + throw py::cast_error(msg); } -nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { +py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { if (PyBoolAttribute::isaFunction(pyAttribute)) - return nb::cast(PyBoolAttribute(pyAttribute)); + return py::cast(PyBoolAttribute(pyAttribute)); if (PyIntegerAttribute::isaFunction(pyAttribute)) - return nb::cast(PyIntegerAttribute(pyAttribute)); + return py::cast(PyIntegerAttribute(pyAttribute)); std::string msg = std::string("Can't cast unknown element type DenseArrayAttr (") + - nb::cast(nb::repr(nb::cast(pyAttribute))) + ")"; - throw nb::type_error(msg.c_str()); + std::string(py::repr(py::cast(pyAttribute))) + ")"; + throw py::cast_error(msg); } -nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { +py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute)) - return nb::cast(PyFlatSymbolRefAttribute(pyAttribute)); + return py::cast(PyFlatSymbolRefAttribute(pyAttribute)); if (PySymbolRefAttribute::isaFunction(pyAttribute)) - return 
nb::cast(PySymbolRefAttribute(pyAttribute)); + return py::cast(PySymbolRefAttribute(pyAttribute)); std::string msg = std::string("Can't cast unknown SymbolRef attribute (") + - nb::cast(nb::repr(nb::cast(pyAttribute))) + - ")"; - throw nb::type_error(msg.c_str()); + std::string(py::repr(py::cast(pyAttribute))) + ")"; + throw py::cast_error(msg); } } // namespace -void mlir::python::populateIRAttributes(nb::module_ &m) { +void mlir::python::populateIRAttributes(py::module &m) { PyAffineMapAttribute::bind(m); PyDenseBoolArrayAttribute::bind(m); PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m); @@ -1762,26 +1562,24 @@ void mlir::python::populateIRAttributes(nb::module_ &m) { PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m); PyGlobals::get().registerTypeCaster( mlirDenseArrayAttrGetTypeID(), - nb::cast(nb::cpp_function(denseArrayAttributeCaster))); + pybind11::cpp_function(denseArrayAttributeCaster)); PyArrayAttribute::bind(m); PyArrayAttribute::PyArrayAttributeIterator::bind(m); PyBoolAttribute::bind(m); - PyDenseElementsAttribute::bind(m, PyDenseElementsAttribute::slots); + PyDenseElementsAttribute::bind(m); PyDenseFPElementsAttribute::bind(m); PyDenseIntElementsAttribute::bind(m); PyGlobals::get().registerTypeCaster( mlirDenseIntOrFPElementsAttrGetTypeID(), - nb::cast( - nb::cpp_function(denseIntOrFPElementsAttributeCaster))); + pybind11::cpp_function(denseIntOrFPElementsAttributeCaster)); PyDenseResourceElementsAttribute::bind(m); PyDictAttribute::bind(m); PySymbolRefAttribute::bind(m); PyGlobals::get().registerTypeCaster( mlirSymbolRefAttrGetTypeID(), - nb::cast( - nb::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster))); + pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster)); PyFlatSymbolRefAttribute::bind(m); PyOpaqueAttribute::bind(m); @@ -1792,7 +1590,7 @@ void mlir::python::populateIRAttributes(nb::module_ &m) { PyTypeAttribute::bind(m); PyGlobals::get().registerTypeCaster( mlirIntegerAttrGetTypeID(), - 
nb::cast(nb::cpp_function(integerOrBoolAttributeCaster))); + pybind11::cpp_function(integerOrBoolAttributeCaster)); PyUnitAttribute::bind(m); PyStridedLayoutAttribute::bind(m); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index e1c56a398..3e96f8c60 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -6,31 +6,26 @@ // //===----------------------------------------------------------------------===// -#include -#include -#include -#include -#include -#include - -#include -#include +#include "IRModule.h" #include "Globals.h" -#include "IRModule.h" -#include "NanobindUtils.h" +#include "PybindUtils.h" + #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/Debug.h" #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -namespace nb = nanobind; -using namespace nb::literals; +#include +#include + +namespace py = pybind11; +using namespace py::literals; using namespace mlir; using namespace mlir::python; @@ -195,18 +190,18 @@ operations. /// Helper for creating an @classmethod. template -nb::object classmethod(Func f, Args... args) { - nb::object cf = nb::cpp_function(f, args...); - return nb::borrow((PyClassMethod_New(cf.ptr()))); +py::object classmethod(Func f, Args... args) { + py::object cf = py::cpp_function(f, args...); + return py::reinterpret_borrow((PyClassMethod_New(cf.ptr()))); } -static nb::object +static py::object createCustomDialectWrapper(const std::string &dialectNamespace, - nb::object dialectDescriptor) { + py::object dialectDescriptor) { auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); if (!dialectClass) { // Use the base class. 
- return nb::cast(PyDialect(std::move(dialectDescriptor))); + return py::cast(PyDialect(std::move(dialectDescriptor))); } // Create the custom implementation. @@ -217,47 +212,42 @@ static MlirStringRef toMlirStringRef(const std::string &s) { return mlirStringRefCreate(s.data(), s.size()); } -static MlirStringRef toMlirStringRef(const nb::bytes &s) { - return mlirStringRefCreate(static_cast(s.data()), s.size()); -} - /// Create a block, using the current location context if no locations are /// specified. -static MlirBlock createBlock(const nb::sequence &pyArgTypes, - const std::optional &pyArgLocs) { +static MlirBlock createBlock(const py::sequence &pyArgTypes, + const std::optional &pyArgLocs) { SmallVector argTypes; - argTypes.reserve(nb::len(pyArgTypes)); + argTypes.reserve(pyArgTypes.size()); for (const auto &pyType : pyArgTypes) - argTypes.push_back(nb::cast(pyType)); + argTypes.push_back(pyType.cast()); SmallVector argLocs; if (pyArgLocs) { - argLocs.reserve(nb::len(*pyArgLocs)); + argLocs.reserve(pyArgLocs->size()); for (const auto &pyLoc : *pyArgLocs) - argLocs.push_back(nb::cast(pyLoc)); + argLocs.push_back(pyLoc.cast()); } else if (!argTypes.empty()) { argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve()); } if (argTypes.size() != argLocs.size()) - throw nb::value_error(("Expected " + Twine(argTypes.size()) + + throw py::value_error(("Expected " + Twine(argTypes.size()) + " locations, got: " + Twine(argLocs.size())) - .str() - .c_str()); + .str()); return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); } /// Wrapper for the global LLVM debugging flag. 
struct PyGlobalDebugFlag { - static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); } + static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } - static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); } + static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); } - static void bind(nb::module_ &m) { + static void bind(py::module &m) { // Debug flags. - nb::class_(m, "_GlobalDebug") - .def_prop_rw_static("flag", &PyGlobalDebugFlag::get, - &PyGlobalDebugFlag::set, "LLVM-wide debug flag") + py::class_(m, "_GlobalDebug", py::module_local()) + .def_property_static("flag", &PyGlobalDebugFlag::get, + &PyGlobalDebugFlag::set, "LLVM-wide debug flag") .def_static( "set_types", [](const std::string &type) { @@ -278,20 +268,20 @@ struct PyAttrBuilderMap { static bool dunderContains(const std::string &attributeKind) { return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value(); } - static nb::callable dundeGetItemNamed(const std::string &attributeKind) { + static py::function dundeGetItemNamed(const std::string &attributeKind) { auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind); if (!builder) - throw nb::key_error(attributeKind.c_str()); + throw py::key_error(attributeKind); return *builder; } static void dundeSetItemNamed(const std::string &attributeKind, - nb::callable func, bool replace) { + py::function func, bool replace) { PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), replace); } - static void bind(nb::module_ &m) { - nb::class_(m, "AttrBuilder") + static void bind(py::module &m) { + py::class_(m, "AttrBuilder", py::module_local()) .def_static("contains", &PyAttrBuilderMap::dunderContains) .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed) .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed, @@ -305,8 +295,8 @@ struct PyAttrBuilderMap { // PyBlock 
//------------------------------------------------------------------------------ -nb::object PyBlock::getCapsule() { - return nb::steal(mlirPythonBlockToCapsule(get())); +py::object PyBlock::getCapsule() { + return py::reinterpret_steal(mlirPythonBlockToCapsule(get())); } //------------------------------------------------------------------------------ @@ -325,14 +315,14 @@ class PyRegionIterator { PyRegion dunderNext() { operation->checkValid(); if (nextIndex >= mlirOperationGetNumRegions(operation->get())) { - throw nb::stop_iteration(); + throw py::stop_iteration(); } MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); return PyRegion(operation, region); } - static void bind(nb::module_ &m) { - nb::class_(m, "RegionIterator") + static void bind(py::module &m) { + py::class_(m, "RegionIterator", py::module_local()) .def("__iter__", &PyRegionIterator::dunderIter) .def("__next__", &PyRegionIterator::dunderNext); } @@ -361,14 +351,14 @@ class PyRegionList { PyRegion dunderGetItem(intptr_t index) { // dunderLen checks validity. 
if (index < 0 || index >= dunderLen()) { - throw nb::index_error("attempt to access out of bounds region"); + throw py::index_error("attempt to access out of bounds region"); } MlirRegion region = mlirOperationGetRegion(operation->get(), index); return PyRegion(operation, region); } - static void bind(nb::module_ &m) { - nb::class_(m, "RegionSequence") + static void bind(py::module &m) { + py::class_(m, "RegionSequence", py::module_local()) .def("__len__", &PyRegionList::dunderLen) .def("__iter__", &PyRegionList::dunderIter) .def("__getitem__", &PyRegionList::dunderGetItem); @@ -388,7 +378,7 @@ class PyBlockIterator { PyBlock dunderNext() { operation->checkValid(); if (mlirBlockIsNull(next)) { - throw nb::stop_iteration(); + throw py::stop_iteration(); } PyBlock returnBlock(operation, next); @@ -396,8 +386,8 @@ class PyBlockIterator { return returnBlock; } - static void bind(nb::module_ &m) { - nb::class_(m, "BlockIterator") + static void bind(py::module &m) { + py::class_(m, "BlockIterator", py::module_local()) .def("__iter__", &PyBlockIterator::dunderIter) .def("__next__", &PyBlockIterator::dunderNext); } @@ -434,7 +424,7 @@ class PyBlockList { PyBlock dunderGetItem(intptr_t index) { operation->checkValid(); if (index < 0) { - throw nb::index_error("attempt to access out of bounds block"); + throw py::index_error("attempt to access out of bounds block"); } MlirBlock block = mlirRegionGetFirstBlock(region); while (!mlirBlockIsNull(block)) { @@ -444,26 +434,24 @@ class PyBlockList { block = mlirBlockGetNextInRegion(block); index -= 1; } - throw nb::index_error("attempt to access out of bounds block"); + throw py::index_error("attempt to access out of bounds block"); } - PyBlock appendBlock(const nb::args &pyArgTypes, - const std::optional &pyArgLocs) { + PyBlock appendBlock(const py::args &pyArgTypes, + const std::optional &pyArgLocs) { operation->checkValid(); - MlirBlock block = - createBlock(nb::cast(pyArgTypes), pyArgLocs); + MlirBlock block = 
createBlock(pyArgTypes, pyArgLocs); mlirRegionAppendOwnedBlock(region, block); return PyBlock(operation, block); } - static void bind(nb::module_ &m) { - nb::class_(m, "BlockList") + static void bind(py::module &m) { + py::class_(m, "BlockList", py::module_local()) .def("__getitem__", &PyBlockList::dunderGetItem) .def("__iter__", &PyBlockList::dunderIter) .def("__len__", &PyBlockList::dunderLen) .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring, - nb::arg("args"), nb::kw_only(), - nb::arg("arg_locs") = std::nullopt); + py::arg("arg_locs") = std::nullopt); } private: @@ -478,10 +466,10 @@ class PyOperationIterator { PyOperationIterator &dunderIter() { return *this; } - nb::object dunderNext() { + py::object dunderNext() { parentOperation->checkValid(); if (mlirOperationIsNull(next)) { - throw nb::stop_iteration(); + throw py::stop_iteration(); } PyOperationRef returnOperation = @@ -490,8 +478,8 @@ class PyOperationIterator { return returnOperation->createOpView(); } - static void bind(nb::module_ &m) { - nb::class_(m, "OperationIterator") + static void bind(py::module &m) { + py::class_(m, "OperationIterator", py::module_local()) .def("__iter__", &PyOperationIterator::dunderIter) .def("__next__", &PyOperationIterator::dunderNext); } @@ -527,10 +515,10 @@ class PyOperationList { return count; } - nb::object dunderGetItem(intptr_t index) { + py::object dunderGetItem(intptr_t index) { parentOperation->checkValid(); if (index < 0) { - throw nb::index_error("attempt to access out of bounds operation"); + throw py::index_error("attempt to access out of bounds operation"); } MlirOperation childOp = mlirBlockGetFirstOperation(block); while (!mlirOperationIsNull(childOp)) { @@ -541,11 +529,11 @@ class PyOperationList { childOp = mlirOperationGetNextInBlock(childOp); index -= 1; } - throw nb::index_error("attempt to access out of bounds operation"); + throw py::index_error("attempt to access out of bounds operation"); } - static void bind(nb::module_ &m) { - 
nb::class_(m, "OperationList") + static void bind(py::module &m) { + py::class_(m, "OperationList", py::module_local()) .def("__getitem__", &PyOperationList::dunderGetItem) .def("__iter__", &PyOperationList::dunderIter) .def("__len__", &PyOperationList::dunderLen); @@ -560,7 +548,7 @@ class PyOpOperand { public: PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {} - nb::object getOwner() { + py::object getOwner() { MlirOperation owner = mlirOpOperandGetOwner(opOperand); PyMlirContextRef context = PyMlirContext::forContext(mlirOperationGetContext(owner)); @@ -569,10 +557,11 @@ class PyOpOperand { size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); } - static void bind(nb::module_ &m) { - nb::class_(m, "OpOperand") - .def_prop_ro("owner", &PyOpOperand::getOwner) - .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber); + static void bind(py::module &m) { + py::class_(m, "OpOperand", py::module_local()) + .def_property_readonly("owner", &PyOpOperand::getOwner) + .def_property_readonly("operand_number", + &PyOpOperand::getOperandNumber); } private: @@ -587,15 +576,15 @@ class PyOpOperandIterator { PyOpOperand dunderNext() { if (mlirOpOperandIsNull(opOperand)) - throw nb::stop_iteration(); + throw py::stop_iteration(); PyOpOperand returnOpOperand(opOperand); opOperand = mlirOpOperandGetNextUse(opOperand); return returnOpOperand; } - static void bind(nb::module_ &m) { - nb::class_(m, "OpOperandIterator") + static void bind(py::module &m) { + py::class_(m, "OpOperandIterator", py::module_local()) .def("__iter__", &PyOpOperandIterator::dunderIter) .def("__next__", &PyOpOperandIterator::dunderNext); } @@ -611,7 +600,7 @@ class PyOpOperandIterator { //------------------------------------------------------------------------------ PyMlirContext::PyMlirContext(MlirContext context) : context(context) { - nb::gil_scoped_acquire acquire; + py::gil_scoped_acquire acquire; auto &liveContexts = getLiveContexts(); liveContexts[context.ptr] = 
this; } @@ -620,36 +609,41 @@ PyMlirContext::~PyMlirContext() { // Note that the only public way to construct an instance is via the // forContext method, which always puts the associated handle into // liveContexts. - nb::gil_scoped_acquire acquire; + py::gil_scoped_acquire acquire; getLiveContexts().erase(context.ptr); mlirContextDestroy(context); } -nb::object PyMlirContext::getCapsule() { - return nb::steal(mlirPythonContextToCapsule(get())); +py::object PyMlirContext::getCapsule() { + return py::reinterpret_steal(mlirPythonContextToCapsule(get())); } -nb::object PyMlirContext::createFromCapsule(nb::object capsule) { +py::object PyMlirContext::createFromCapsule(py::object capsule) { MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); if (mlirContextIsNull(rawContext)) - throw nb::python_error(); + throw py::error_already_set(); return forContext(rawContext).releaseObject(); } +PyMlirContext *PyMlirContext::createNewContextForInit() { + MlirContext context = mlirContextCreateWithThreading(false); + return new PyMlirContext(context); +} + PyMlirContextRef PyMlirContext::forContext(MlirContext context) { - nb::gil_scoped_acquire acquire; + py::gil_scoped_acquire acquire; auto &liveContexts = getLiveContexts(); auto it = liveContexts.find(context.ptr); if (it == liveContexts.end()) { // Create. PyMlirContext *unownedContextWrapper = new PyMlirContext(context); - nb::object pyRef = nb::cast(unownedContextWrapper); - assert(pyRef && "cast to nb::object failed"); + py::object pyRef = py::cast(unownedContextWrapper); + assert(pyRef && "cast to py::object failed"); liveContexts[context.ptr] = unownedContextWrapper; return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); } // Use existing. 
- nb::object pyRef = nb::cast(it->second); + py::object pyRef = py::cast(it->second); return PyMlirContextRef(it->second, std::move(pyRef)); } @@ -723,23 +717,23 @@ void PyMlirContext::clearOperationAndInside(PyOperationBase &op) { size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } -nb::object PyMlirContext::contextEnter(nb::object context) { - return PyThreadContextEntry::pushContext(context); +pybind11::object PyMlirContext::contextEnter() { + return PyThreadContextEntry::pushContext(*this); } -void PyMlirContext::contextExit(const nb::object &excType, - const nb::object &excVal, - const nb::object &excTb) { +void PyMlirContext::contextExit(const pybind11::object &excType, + const pybind11::object &excVal, + const pybind11::object &excTb) { PyThreadContextEntry::popContext(*this); } -nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) { +py::object PyMlirContext::attachDiagnosticHandler(py::object callback) { // Note that ownership is transferred to the delete callback below by way of // an explicit inc_ref (borrow). 
PyDiagnosticHandler *pyHandler = new PyDiagnosticHandler(get(), std::move(callback)); - nb::object pyHandlerObject = - nb::cast(pyHandler, nb::rv_policy::take_ownership); + py::object pyHandlerObject = + py::cast(pyHandler, py::return_value_policy::take_ownership); pyHandlerObject.inc_ref(); // In these C callbacks, the userData is a PyDiagnosticHandler* that is @@ -747,17 +741,17 @@ nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) { auto handlerCallback = +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult { PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic); - nb::object pyDiagnosticObject = - nb::cast(pyDiagnostic, nb::rv_policy::take_ownership); + py::object pyDiagnosticObject = + py::cast(pyDiagnostic, py::return_value_policy::take_ownership); auto *pyHandler = static_cast(userData); bool result = false; { // Since this can be called from arbitrary C++ contexts, always get the // gil. - nb::gil_scoped_acquire gil; + py::gil_scoped_acquire gil; try { - result = nb::cast(pyHandler->callback(pyDiagnostic)); + result = py::cast(pyHandler->callback(pyDiagnostic)); } catch (std::exception &e) { fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n", e.what()); @@ -774,7 +768,8 @@ nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) { pyHandler->registeredID.reset(); // Decrement reference, balancing the inc_ref() above. 
- nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference); + py::object pyHandlerObject = + py::cast(pyHandler, py::return_value_policy::reference); pyHandlerObject.dec_ref(); }; @@ -824,9 +819,9 @@ PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { return &stack.back(); } -void PyThreadContextEntry::push(FrameKind frameKind, nb::object context, - nb::object insertionPoint, - nb::object location) { +void PyThreadContextEntry::push(FrameKind frameKind, py::object context, + py::object insertionPoint, + py::object location) { auto &stack = getStack(); stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), std::move(location)); @@ -849,19 +844,19 @@ void PyThreadContextEntry::push(FrameKind frameKind, nb::object context, PyMlirContext *PyThreadContextEntry::getContext() { if (!context) return nullptr; - return nb::cast(context); + return py::cast(context); } PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { if (!insertionPoint) return nullptr; - return nb::cast(insertionPoint); + return py::cast(insertionPoint); } PyLocation *PyThreadContextEntry::getLocation() { if (!location) return nullptr; - return nb::cast(location); + return py::cast(location); } PyMlirContext *PyThreadContextEntry::getDefaultContext() { @@ -879,11 +874,12 @@ PyLocation *PyThreadContextEntry::getDefaultLocation() { return tos ? 
tos->getLocation() : nullptr; } -nb::object PyThreadContextEntry::pushContext(nb::object context) { - push(FrameKind::Context, /*context=*/context, - /*insertionPoint=*/nb::object(), - /*location=*/nb::object()); - return context; +py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { + py::object contextObj = py::cast(context); + push(FrameKind::Context, /*context=*/contextObj, + /*insertionPoint=*/py::object(), + /*location=*/py::object()); + return contextObj; } void PyThreadContextEntry::popContext(PyMlirContext &context) { @@ -896,16 +892,15 @@ void PyThreadContextEntry::popContext(PyMlirContext &context) { stack.pop_back(); } -nb::object -PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) { - PyInsertionPoint &insertionPoint = - nb::cast(insertionPointObj); - nb::object contextObj = +py::object +PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { + py::object contextObj = insertionPoint.getBlock().getParentOperation()->getContext().getObject(); + py::object insertionPointObj = py::cast(insertionPoint); push(FrameKind::InsertionPoint, /*context=*/contextObj, /*insertionPoint=*/insertionPointObj, - /*location=*/nb::object()); + /*location=*/py::object()); return insertionPointObj; } @@ -920,11 +915,11 @@ void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { stack.pop_back(); } -nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) { - PyLocation &location = nb::cast(locationObj); - nb::object contextObj = location.getContext().getObject(); +py::object PyThreadContextEntry::pushLocation(PyLocation &location) { + py::object contextObj = location.getContext().getObject(); + py::object locationObj = py::cast(location); push(FrameKind::Location, /*context=*/contextObj, - /*insertionPoint=*/nb::object(), + /*insertionPoint=*/py::object(), /*location=*/locationObj); return locationObj; } @@ -946,15 +941,15 @@ void PyThreadContextEntry::popLocation(PyLocation 
&location) { void PyDiagnostic::invalidate() { valid = false; if (materializedNotes) { - for (nb::handle noteObject : *materializedNotes) { - PyDiagnostic *note = nb::cast(noteObject); + for (auto ¬eObject : *materializedNotes) { + PyDiagnostic *note = py::cast(noteObject); note->invalidate(); } } } PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context, - nb::object callback) + py::object callback) : context(context), callback(std::move(callback)) {} PyDiagnosticHandler::~PyDiagnosticHandler() = default; @@ -989,36 +984,32 @@ PyLocation PyDiagnostic::getLocation() { return PyLocation(PyMlirContext::forContext(context), loc); } -nb::str PyDiagnostic::getMessage() { +py::str PyDiagnostic::getMessage() { checkValid(); - nb::object fileObject = nb::module_::import_("io").attr("StringIO")(); + py::object fileObject = py::module::import("io").attr("StringIO")(); PyFileAccumulator accum(fileObject, /*binary=*/false); mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData()); - return nb::cast(fileObject.attr("getvalue")()); + return fileObject.attr("getvalue")(); } -nb::tuple PyDiagnostic::getNotes() { +py::tuple PyDiagnostic::getNotes() { checkValid(); if (materializedNotes) return *materializedNotes; intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic); - nb::tuple notes = nb::steal(PyTuple_New(numNotes)); + materializedNotes = py::tuple(numNotes); for (intptr_t i = 0; i < numNotes; ++i) { MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i); - nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag)); - PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr()); + (*materializedNotes)[i] = PyDiagnostic(noteDiag); } - materializedNotes = std::move(notes); - return *materializedNotes; } PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() { std::vector notes; - for (nb::handle n : getNotes()) - notes.emplace_back(nb::cast(n).getInfo()); - return {getSeverity(), getLocation(), nb::cast(getMessage()), - std::move(notes)}; + for 
(py::handle n : getNotes()) + notes.emplace_back(n.cast().getInfo()); + return {getSeverity(), getLocation(), getMessage(), std::move(notes)}; } //------------------------------------------------------------------------------ @@ -1032,21 +1023,22 @@ MlirDialect PyDialects::getDialectForKey(const std::string &key, if (mlirDialectIsNull(dialect)) { std::string msg = (Twine("Dialect '") + key + "' not found").str(); if (attrError) - throw nb::attribute_error(msg.c_str()); - throw nb::index_error(msg.c_str()); + throw py::attribute_error(msg); + throw py::index_error(msg); } return dialect; } -nb::object PyDialectRegistry::getCapsule() { - return nb::steal(mlirPythonDialectRegistryToCapsule(*this)); +py::object PyDialectRegistry::getCapsule() { + return py::reinterpret_steal( + mlirPythonDialectRegistryToCapsule(*this)); } -PyDialectRegistry PyDialectRegistry::createFromCapsule(nb::object capsule) { +PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) { MlirDialectRegistry rawRegistry = mlirPythonCapsuleToDialectRegistry(capsule.ptr()); if (mlirDialectRegistryIsNull(rawRegistry)) - throw nb::python_error(); + throw py::error_already_set(); return PyDialectRegistry(rawRegistry); } @@ -1054,25 +1046,25 @@ PyDialectRegistry PyDialectRegistry::createFromCapsule(nb::object capsule) { // PyLocation //------------------------------------------------------------------------------ -nb::object PyLocation::getCapsule() { - return nb::steal(mlirPythonLocationToCapsule(*this)); +py::object PyLocation::getCapsule() { + return py::reinterpret_steal(mlirPythonLocationToCapsule(*this)); } -PyLocation PyLocation::createFromCapsule(nb::object capsule) { +PyLocation PyLocation::createFromCapsule(py::object capsule) { MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); if (mlirLocationIsNull(rawLoc)) - throw nb::python_error(); + throw py::error_already_set(); return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), rawLoc); } 
-nb::object PyLocation::contextEnter(nb::object locationObj) { - return PyThreadContextEntry::pushLocation(locationObj); +py::object PyLocation::contextEnter() { + return PyThreadContextEntry::pushLocation(*this); } -void PyLocation::contextExit(const nb::object &excType, - const nb::object &excVal, - const nb::object &excTb) { +void PyLocation::contextExit(const pybind11::object &excType, + const pybind11::object &excVal, + const pybind11::object &excTb) { PyThreadContextEntry::popLocation(*this); } @@ -1095,7 +1087,7 @@ PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) : BaseContextObject(std::move(contextRef)), module(module) {} PyModule::~PyModule() { - nb::gil_scoped_acquire acquire; + py::gil_scoped_acquire acquire; auto &liveModules = getContext()->liveModules; assert(liveModules.count(module.ptr) == 1 && "destroying module not in live map"); @@ -1107,7 +1099,7 @@ PyModuleRef PyModule::forModule(MlirModule module) { MlirContext context = mlirModuleGetContext(module); PyMlirContextRef contextRef = PyMlirContext::forContext(context); - nb::gil_scoped_acquire acquire; + py::gil_scoped_acquire acquire; auto &liveModules = contextRef->liveModules; auto it = liveModules.find(module.ptr); if (it == liveModules.end()) { @@ -1116,7 +1108,8 @@ PyModuleRef PyModule::forModule(MlirModule module) { // Note that the default return value policy on cast is automatic_reference, // which does not take ownership (delete will not be called). // Just be explicit. - nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership); + py::object pyRef = + py::cast(unownedModule, py::return_value_policy::take_ownership); unownedModule->handle = pyRef; liveModules[module.ptr] = std::make_pair(unownedModule->handle, unownedModule); @@ -1124,19 +1117,19 @@ PyModuleRef PyModule::forModule(MlirModule module) { } // Use existing. 
PyModule *existing = it->second.second; - nb::object pyRef = nb::borrow(it->second.first); + py::object pyRef = py::reinterpret_borrow(it->second.first); return PyModuleRef(existing, std::move(pyRef)); } -nb::object PyModule::createFromCapsule(nb::object capsule) { +py::object PyModule::createFromCapsule(py::object capsule) { MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); if (mlirModuleIsNull(rawModule)) - throw nb::python_error(); + throw py::error_already_set(); return forModule(rawModule).releaseObject(); } -nb::object PyModule::getCapsule() { - return nb::steal(mlirPythonModuleToCapsule(get())); +py::object PyModule::getCapsule() { + return py::reinterpret_steal(mlirPythonModuleToCapsule(get())); } //------------------------------------------------------------------------------ @@ -1165,7 +1158,7 @@ PyOperation::~PyOperation() { PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, MlirOperation operation, - nb::object parentKeepAlive) { + py::object parentKeepAlive) { auto &liveOperations = contextRef->liveOperations; // Create. PyOperation *unownedOperation = @@ -1173,7 +1166,8 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, // Note that the default return value policy on cast is automatic_reference, // which does not take ownership (delete will not be called). // Just be explicit. 
- nb::object pyRef = nb::cast(unownedOperation, nb::rv_policy::take_ownership); + py::object pyRef = + py::cast(unownedOperation, py::return_value_policy::take_ownership); unownedOperation->handle = pyRef; if (parentKeepAlive) { unownedOperation->parentKeepAlive = std::move(parentKeepAlive); @@ -1184,7 +1178,7 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, MlirOperation operation, - nb::object parentKeepAlive) { + py::object parentKeepAlive) { auto &liveOperations = contextRef->liveOperations; auto it = liveOperations.find(operation.ptr); if (it == liveOperations.end()) { @@ -1194,13 +1188,13 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, } // Use existing. PyOperation *existing = it->second.second; - nb::object pyRef = nb::borrow(it->second.first); + py::object pyRef = py::reinterpret_borrow(it->second.first); return PyOperationRef(existing, std::move(pyRef)); } PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, MlirOperation operation, - nb::object parentKeepAlive) { + py::object parentKeepAlive) { auto &liveOperations = contextRef->liveOperations; assert(liveOperations.count(operation.ptr) == 0 && "cannot create detached operation that already exists"); @@ -1233,12 +1227,12 @@ void PyOperation::checkValid() const { void PyOperationBase::print(std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified, nb::object fileObject, + bool assumeVerified, py::object fileObject, bool binary, bool skipRegions) { PyOperation &operation = getOperation(); operation.checkValid(); if (fileObject.is_none()) - fileObject = nb::module_::import_("sys").attr("stdout"); + fileObject = py::module::import("sys").attr("stdout"); MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); if (largeElementsLimit) @@ -1261,18 +1255,18 @@ void 
PyOperationBase::print(std::optional largeElementsLimit, mlirOpPrintingFlagsDestroy(flags); } -void PyOperationBase::print(PyAsmState &state, nb::object fileObject, +void PyOperationBase::print(PyAsmState &state, py::object fileObject, bool binary) { PyOperation &operation = getOperation(); operation.checkValid(); if (fileObject.is_none()) - fileObject = nb::module_::import_("sys").attr("stdout"); + fileObject = py::module::import("sys").attr("stdout"); PyFileAccumulator accum(fileObject, binary); mlirOperationPrintWithState(operation, state.get(), accum.getCallback(), accum.getUserData()); } -void PyOperationBase::writeBytecode(const nb::object &fileObject, +void PyOperationBase::writeBytecode(const py::object &fileObject, std::optional bytecodeVersion) { PyOperation &operation = getOperation(); operation.checkValid(); @@ -1288,10 +1282,9 @@ void PyOperationBase::writeBytecode(const nb::object &fileObject, operation, config, accum.getCallback(), accum.getUserData()); mlirBytecodeWriterConfigDestroy(config); if (mlirLogicalResultIsFailure(res)) - throw nb::value_error((Twine("Unable to honor desired bytecode version ") + + throw py::value_error((Twine("Unable to honor desired bytecode version ") + Twine(*bytecodeVersion)) - .str() - .c_str()); + .str()); } void PyOperationBase::walk( @@ -1303,7 +1296,7 @@ void PyOperationBase::walk( std::function callback; bool gotException; std::string exceptionWhat; - nb::object exceptionType; + py::object exceptionType; }; UserData userData{callback, false, {}, {}}; MlirOperationWalkCallback walkCallback = [](MlirOperation op, @@ -1311,10 +1304,10 @@ void PyOperationBase::walk( UserData *calleeUserData = static_cast(userData); try { return (calleeUserData->callback)(op); - } catch (nb::python_error &e) { + } catch (py::error_already_set &e) { calleeUserData->gotException = true; - calleeUserData->exceptionWhat = std::string(e.what()); - calleeUserData->exceptionType = nb::borrow(e.type()); + calleeUserData->exceptionWhat = 
e.what(); + calleeUserData->exceptionType = e.type(); return MlirWalkResult::MlirWalkResultInterrupt; } }; @@ -1326,16 +1319,16 @@ void PyOperationBase::walk( } } -nb::object PyOperationBase::getAsm(bool binary, +py::object PyOperationBase::getAsm(bool binary, std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified, bool skipRegions) { - nb::object fileObject; + py::object fileObject; if (binary) { - fileObject = nb::module_::import_("io").attr("BytesIO")(); + fileObject = py::module::import("io").attr("BytesIO")(); } else { - fileObject = nb::module_::import_("io").attr("StringIO")(); + fileObject = py::module::import("io").attr("StringIO")(); } print(/*largeElementsLimit=*/largeElementsLimit, /*enableDebugInfo=*/enableDebugInfo, @@ -1379,7 +1372,7 @@ bool PyOperationBase::verify() { std::optional PyOperation::getParentOperation() { checkValid(); if (!isAttached()) - throw nb::value_error("Detached operations have no parent"); + throw py::value_error("Detached operations have no parent"); MlirOperation operation = mlirOperationGetParentOperation(get()); if (mlirOperationIsNull(operation)) return {}; @@ -1395,42 +1388,42 @@ PyBlock PyOperation::getBlock() { return PyBlock{std::move(*parentOperation), block}; } -nb::object PyOperation::getCapsule() { +py::object PyOperation::getCapsule() { checkValid(); - return nb::steal(mlirPythonOperationToCapsule(get())); + return py::reinterpret_steal(mlirPythonOperationToCapsule(get())); } -nb::object PyOperation::createFromCapsule(nb::object capsule) { +py::object PyOperation::createFromCapsule(py::object capsule) { MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); if (mlirOperationIsNull(rawOperation)) - throw nb::python_error(); + throw py::error_already_set(); MlirContext rawCtxt = mlirOperationGetContext(rawOperation); return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) .releaseObject(); } 
static void maybeInsertOperation(PyOperationRef &op, - const nb::object &maybeIp) { + const py::object &maybeIp) { // InsertPoint active? - if (!maybeIp.is(nb::cast(false))) { + if (!maybeIp.is(py::cast(false))) { PyInsertionPoint *ip; if (maybeIp.is_none()) { ip = PyThreadContextEntry::getDefaultInsertionPoint(); } else { - ip = nb::cast(maybeIp); + ip = py::cast(maybeIp); } if (ip) ip->insert(*op.get()); } } -nb::object PyOperation::create(const std::string &name, +py::object PyOperation::create(const std::string &name, std::optional> results, std::optional> operands, - std::optional attributes, + std::optional attributes, std::optional> successors, int regions, DefaultingPyLocation location, - const nb::object &maybeIp, bool inferType) { + const py::object &maybeIp, bool inferType) { llvm::SmallVector mlirOperands; llvm::SmallVector mlirResults; llvm::SmallVector mlirSuccessors; @@ -1438,14 +1431,14 @@ nb::object PyOperation::create(const std::string &name, // General parameter validation. if (regions < 0) - throw nb::value_error("number of regions must be >= 0"); + throw py::value_error("number of regions must be >= 0"); // Unpack/validate operands. if (operands) { mlirOperands.reserve(operands->size()); for (PyValue *operand : *operands) { if (!operand) - throw nb::value_error("operand value cannot be None"); + throw py::value_error("operand value cannot be None"); mlirOperands.push_back(operand->get()); } } @@ -1456,38 +1449,38 @@ nb::object PyOperation::create(const std::string &name, for (PyType *result : *results) { // TODO: Verify result type originate from the same context. if (!result) - throw nb::value_error("result type cannot be None"); + throw py::value_error("result type cannot be None"); mlirResults.push_back(*result); } } // Unpack/validate attributes. 
if (attributes) { mlirAttributes.reserve(attributes->size()); - for (std::pair it : *attributes) { + for (auto &it : *attributes) { std::string key; try { - key = nb::cast(it.first); - } catch (nb::cast_error &err) { + key = it.first.cast(); + } catch (py::cast_error &err) { std::string msg = "Invalid attribute key (not a string) when " "attempting to create the operation \"" + name + "\" (" + err.what() + ")"; - throw nb::type_error(msg.c_str()); + throw py::cast_error(msg); } try { - auto &attribute = nb::cast(it.second); + auto &attribute = it.second.cast(); // TODO: Verify attribute originates from the same context. mlirAttributes.emplace_back(std::move(key), attribute); - } catch (nb::cast_error &err) { - std::string msg = "Invalid attribute value for the key \"" + key + - "\" when attempting to create the operation \"" + - name + "\" (" + err.what() + ")"; - throw nb::type_error(msg.c_str()); - } catch (std::runtime_error &) { + } catch (py::reference_cast_error &) { // This exception seems thrown when the value is "None". std::string msg = "Found an invalid (`None`?) attribute value for the key \"" + key + "\" when attempting to create the operation \"" + name + "\""; - throw std::runtime_error(msg); + throw py::cast_error(msg); + } catch (py::cast_error &err) { + std::string msg = "Invalid attribute value for the key \"" + key + + "\" when attempting to create the operation \"" + + name + "\" (" + err.what() + ")"; + throw py::cast_error(msg); } } } @@ -1497,7 +1490,7 @@ nb::object PyOperation::create(const std::string &name, for (auto *successor : *successors) { // TODO: Verify successor originate from the same context. if (!successor) - throw nb::value_error("successor block cannot be None"); + throw py::value_error("successor block cannot be None"); mlirSuccessors.push_back(successor->get()); } } @@ -1542,7 +1535,7 @@ nb::object PyOperation::create(const std::string &name, // Construct the operation. 
MlirOperation operation = mlirOperationCreate(&state); if (!operation.ptr) - throw nb::value_error("Operation creation failed"); + throw py::value_error("Operation creation failed"); PyOperationRef created = PyOperation::createDetached(location->getContext(), operation); maybeInsertOperation(created, maybeIp); @@ -1550,7 +1543,7 @@ nb::object PyOperation::create(const std::string &name, return created.getObject(); } -nb::object PyOperation::clone(const nb::object &maybeIp) { +py::object PyOperation::clone(const py::object &maybeIp) { MlirOperation clonedOperation = mlirOperationClone(operation); PyOperationRef cloned = PyOperation::createDetached(getContext(), clonedOperation); @@ -1559,15 +1552,15 @@ nb::object PyOperation::clone(const nb::object &maybeIp) { return cloned->createOpView(); } -nb::object PyOperation::createOpView() { +py::object PyOperation::createOpView() { checkValid(); MlirIdentifier ident = mlirOperationGetName(get()); MlirStringRef identStr = mlirIdentifierStr(ident); auto operationCls = PyGlobals::get().lookupOperationClass( StringRef(identStr.data, identStr.length)); if (operationCls) - return PyOpView::constructDerived(*operationCls, getRef().getObject()); - return nb::cast(PyOpView(getRef().getObject())); + return PyOpView::constructDerived(*operationCls, *getRef().get()); + return py::cast(PyOpView(getRef().getObject())); } void PyOperation::erase() { @@ -1580,8 +1573,8 @@ void PyOperation::erase() { // PyOpView //------------------------------------------------------------------------------ -static void populateResultTypes(StringRef name, nb::list resultTypeList, - const nb::object &resultSegmentSpecObj, +static void populateResultTypes(StringRef name, py::list resultTypeList, + const py::object &resultSegmentSpecObj, std::vector &resultSegmentLengths, std::vector &resultTypes) { resultTypes.reserve(resultTypeList.size()); @@ -1589,28 +1582,26 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList, // Non-variadic 
result unpacking. for (const auto &it : llvm::enumerate(resultTypeList)) { try { - resultTypes.push_back(nb::cast(it.value())); + resultTypes.push_back(py::cast(it.value())); if (!resultTypes.back()) - throw nb::cast_error(); - } catch (nb::cast_error &err) { - throw nb::value_error((llvm::Twine("Result ") + + throw py::cast_error(); + } catch (py::cast_error &err) { + throw py::value_error((llvm::Twine("Result ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Type (" + err.what() + ")") - .str() - .c_str()); + .str()); } } } else { // Sized result unpacking. - auto resultSegmentSpec = nb::cast>(resultSegmentSpecObj); + auto resultSegmentSpec = py::cast>(resultSegmentSpecObj); if (resultSegmentSpec.size() != resultTypeList.size()) { - throw nb::value_error((llvm::Twine("Operation \"") + name + + throw py::value_error((llvm::Twine("Operation \"") + name + "\" requires " + llvm::Twine(resultSegmentSpec.size()) + " result segments but was provided " + llvm::Twine(resultTypeList.size())) - .str() - .c_str()); + .str()); } resultSegmentLengths.reserve(resultTypeList.size()); for (const auto &it : @@ -1619,7 +1610,7 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList, if (segmentSpec == 1 || segmentSpec == 0) { // Unpack unary element. try { - auto *resultType = nb::cast(std::get<0>(it.value())); + auto *resultType = py::cast(std::get<0>(it.value())); if (resultType) { resultTypes.push_back(resultType); resultSegmentLengths.push_back(1); @@ -1627,20 +1618,14 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList, // Allowed to be optional. 
resultSegmentLengths.push_back(0); } else { - throw nb::value_error( - (llvm::Twine("Result ") + llvm::Twine(it.index()) + - " of operation \"" + name + - "\" must be a Type (was None and result is not optional)") - .str() - .c_str()); + throw py::cast_error("was None and result is not optional"); } - } catch (nb::cast_error &err) { - throw nb::value_error((llvm::Twine("Result ") + + } catch (py::cast_error &err) { + throw py::value_error((llvm::Twine("Result ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Type (" + err.what() + ")") - .str() - .c_str()); + .str()); } } else if (segmentSpec == -1) { // Unpack sequence by appending. @@ -1650,75 +1635,72 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList, resultSegmentLengths.push_back(0); } else { // Unpack the list. - auto segment = nb::cast(std::get<0>(it.value())); - for (nb::handle segmentItem : segment) { - resultTypes.push_back(nb::cast(segmentItem)); + auto segment = py::cast(std::get<0>(it.value())); + for (py::object segmentItem : segment) { + resultTypes.push_back(py::cast(segmentItem)); if (!resultTypes.back()) { - throw nb::type_error("contained a None item"); + throw py::cast_error("contained a None item"); } } - resultSegmentLengths.push_back(nb::len(segment)); + resultSegmentLengths.push_back(segment.size()); } } catch (std::exception &err) { // NOTE: Sloppy to be using a catch-all here, but there are at least // three different unrelated exceptions that can be thrown in the // above "casts". Just keep the scope above small and catch them all. 
- throw nb::value_error((llvm::Twine("Result ") + + throw py::value_error((llvm::Twine("Result ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Sequence of Types (" + err.what() + ")") - .str() - .c_str()); + .str()); } } else { - throw nb::value_error("Unexpected segment spec"); + throw py::value_error("Unexpected segment spec"); } } } } -nb::object PyOpView::buildGeneric( - const nb::object &cls, std::optional resultTypeList, - nb::list operandList, std::optional attributes, +py::object PyOpView::buildGeneric( + const py::object &cls, std::optional resultTypeList, + py::list operandList, std::optional attributes, std::optional> successors, std::optional regions, DefaultingPyLocation location, - const nb::object &maybeIp) { + const py::object &maybeIp) { PyMlirContextRef context = location->getContext(); // Class level operation construction metadata. - std::string name = nb::cast(cls.attr("OPERATION_NAME")); + std::string name = py::cast(cls.attr("OPERATION_NAME")); // Operand and result segment specs are either none, which does no // variadic unpacking, or a list of ints with segment sizes, where each // element is either a positive number (typically 1 for a scalar) or -1 to // indicate that it is derived from the length of the same-indexed operand // or result (implying that it is a list at that position). - nb::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); - nb::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); + py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); + py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); std::vector operandSegmentLengths; std::vector resultSegmentLengths; // Validate/determine region count. 
- auto opRegionSpec = nb::cast>(cls.attr("_ODS_REGIONS")); + auto opRegionSpec = py::cast>(cls.attr("_ODS_REGIONS")); int opMinRegionCount = std::get<0>(opRegionSpec); bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); if (!regions) { regions = opMinRegionCount; } if (*regions < opMinRegionCount) { - throw nb::value_error( + throw py::value_error( (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + llvm::Twine(opMinRegionCount) + " regions but was built with regions=" + llvm::Twine(*regions)) - .str() - .c_str()); + .str()); } if (opHasNoVariadicRegions && *regions > opMinRegionCount) { - throw nb::value_error( + throw py::value_error( (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + llvm::Twine(opMinRegionCount) + " regions but was built with regions=" + llvm::Twine(*regions)) - .str() - .c_str()); + .str()); } // Unpack results. @@ -1735,28 +1717,26 @@ nb::object PyOpView::buildGeneric( // Non-sized operand unpacking. for (const auto &it : llvm::enumerate(operandList)) { try { - operands.push_back(nb::cast(it.value())); + operands.push_back(py::cast(it.value())); if (!operands.back()) - throw nb::cast_error(); - } catch (nb::cast_error &err) { - throw nb::value_error((llvm::Twine("Operand ") + + throw py::cast_error(); + } catch (py::cast_error &err) { + throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Value (" + err.what() + ")") - .str() - .c_str()); + .str()); } } } else { // Sized operand unpacking. 
- auto operandSegmentSpec = nb::cast>(operandSegmentSpecObj); + auto operandSegmentSpec = py::cast>(operandSegmentSpecObj); if (operandSegmentSpec.size() != operandList.size()) { - throw nb::value_error((llvm::Twine("Operation \"") + name + + throw py::value_error((llvm::Twine("Operation \"") + name + "\" requires " + llvm::Twine(operandSegmentSpec.size()) + "operand segments but was provided " + llvm::Twine(operandList.size())) - .str() - .c_str()); + .str()); } operandSegmentLengths.reserve(operandList.size()); for (const auto &it : @@ -1765,7 +1745,7 @@ nb::object PyOpView::buildGeneric( if (segmentSpec == 1 || segmentSpec == 0) { // Unpack unary element. try { - auto *operandValue = nb::cast(std::get<0>(it.value())); + auto *operandValue = py::cast(std::get<0>(it.value())); if (operandValue) { operands.push_back(operandValue); operandSegmentLengths.push_back(1); @@ -1773,20 +1753,14 @@ nb::object PyOpView::buildGeneric( // Allowed to be optional. operandSegmentLengths.push_back(0); } else { - throw nb::value_error( - (llvm::Twine("Operand ") + llvm::Twine(it.index()) + - " of operation \"" + name + - "\" must be a Value (was None and operand is not optional)") - .str() - .c_str()); + throw py::cast_error("was None and operand is not optional"); } - } catch (nb::cast_error &err) { - throw nb::value_error((llvm::Twine("Operand ") + + } catch (py::cast_error &err) { + throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Value (" + err.what() + ")") - .str() - .c_str()); + .str()); } } else if (segmentSpec == -1) { // Unpack sequence by appending. @@ -1796,28 +1770,27 @@ nb::object PyOpView::buildGeneric( operandSegmentLengths.push_back(0); } else { // Unpack the list. 
- auto segment = nb::cast(std::get<0>(it.value())); - for (nb::handle segmentItem : segment) { - operands.push_back(nb::cast(segmentItem)); + auto segment = py::cast(std::get<0>(it.value())); + for (py::object segmentItem : segment) { + operands.push_back(py::cast(segmentItem)); if (!operands.back()) { - throw nb::type_error("contained a None item"); + throw py::cast_error("contained a None item"); } } - operandSegmentLengths.push_back(nb::len(segment)); + operandSegmentLengths.push_back(segment.size()); } } catch (std::exception &err) { // NOTE: Sloppy to be using a catch-all here, but there are at least // three different unrelated exceptions that can be thrown in the // above "casts". Just keep the scope above small and catch them all. - throw nb::value_error((llvm::Twine("Operand ") + + throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Sequence of Values (" + err.what() + ")") - .str() - .c_str()); + .str()); } } else { - throw nb::value_error("Unexpected segment spec"); + throw py::value_error("Unexpected segment spec"); } } } @@ -1826,13 +1799,13 @@ nb::object PyOpView::buildGeneric( if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { // Dup. if (attributes) { - attributes = nb::dict(*attributes); + attributes = py::dict(*attributes); } else { - attributes = nb::dict(); + attributes = py::dict(); } if (attributes->contains("resultSegmentSizes") || attributes->contains("operandSegmentSizes")) { - throw nb::value_error("Manually setting a 'resultSegmentSizes' or " + throw py::value_error("Manually setting a 'resultSegmentSizes' or " "'operandSegmentSizes' attribute is unsupported. 
" "Use Operation.create for such low-level access."); } @@ -1866,18 +1839,21 @@ nb::object PyOpView::buildGeneric( !resultTypeList); } -nb::object PyOpView::constructDerived(const nb::object &cls, - const nb::object &operation) { - nb::handle opViewType = nb::type(); - nb::object instance = cls.attr("__new__")(cls); +pybind11::object PyOpView::constructDerived(const pybind11::object &cls, + const PyOperation &operation) { + // TODO: pybind11 2.6 supports a more direct form. + // Upgrade many years from now. + // auto opViewType = py::type::of(); + py::handle opViewType = py::detail::get_type_handle(typeid(PyOpView), true); + py::object instance = cls.attr("__new__")(cls); opViewType.attr("__init__")(instance, operation); return instance; } -PyOpView::PyOpView(const nb::object &operationObject) +PyOpView::PyOpView(const py::object &operationObject) // Casting through the PyOperationBase base-class and then back to the // Operation lets us accept any PyOperationBase subclass. - : operation(nb::cast(operationObject).getOperation()), + : operation(py::cast(operationObject).getOperation()), operationObject(operation.getRef().getObject()) {} //------------------------------------------------------------------------------ @@ -1893,7 +1869,7 @@ PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) void PyInsertionPoint::insert(PyOperationBase &operationBase) { PyOperation &operation = operationBase.getOperation(); if (operation.isAttached()) - throw nb::value_error( + throw py::value_error( "Attempt to insert operation that is already attached"); block.getParentOperation()->checkValid(); MlirOperation beforeOp = {nullptr}; @@ -1906,7 +1882,7 @@ void PyInsertionPoint::insert(PyOperationBase &operationBase) { // already end in a known terminator (violating this will cause assertion // failures later). 
if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { - throw nb::index_error("Cannot insert operation at the end of a block " + throw py::index_error("Cannot insert operation at the end of a block " "that already has a terminator. Did you mean to " "use 'InsertionPoint.at_block_terminator(block)' " "versus 'InsertionPoint(block)'?"); @@ -1932,19 +1908,19 @@ PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { MlirOperation terminator = mlirBlockGetTerminator(block.get()); if (mlirOperationIsNull(terminator)) - throw nb::value_error("Block has no terminator"); + throw py::value_error("Block has no terminator"); PyOperationRef terminatorOpRef = PyOperation::forOperation( block.getParentOperation()->getContext(), terminator); return PyInsertionPoint{block, std::move(terminatorOpRef)}; } -nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) { - return PyThreadContextEntry::pushInsertionPoint(insertPoint); +py::object PyInsertionPoint::contextEnter() { + return PyThreadContextEntry::pushInsertionPoint(*this); } -void PyInsertionPoint::contextExit(const nb::object &excType, - const nb::object &excVal, - const nb::object &excTb) { +void PyInsertionPoint::contextExit(const pybind11::object &excType, + const pybind11::object &excVal, + const pybind11::object &excTb) { PyThreadContextEntry::popInsertionPoint(*this); } @@ -1956,14 +1932,14 @@ bool PyAttribute::operator==(const PyAttribute &other) const { return mlirAttributeEqual(attr, other.attr); } -nb::object PyAttribute::getCapsule() { - return nb::steal(mlirPythonAttributeToCapsule(*this)); +py::object PyAttribute::getCapsule() { + return py::reinterpret_steal(mlirPythonAttributeToCapsule(*this)); } -PyAttribute PyAttribute::createFromCapsule(nb::object capsule) { +PyAttribute PyAttribute::createFromCapsule(py::object capsule) { MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); if 
(mlirAttributeIsNull(rawAttr)) - throw nb::python_error(); + throw py::error_already_set(); return PyAttribute( PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); } @@ -1988,14 +1964,14 @@ bool PyType::operator==(const PyType &other) const { return mlirTypeEqual(type, other.type); } -nb::object PyType::getCapsule() { - return nb::steal(mlirPythonTypeToCapsule(*this)); +py::object PyType::getCapsule() { + return py::reinterpret_steal(mlirPythonTypeToCapsule(*this)); } -PyType PyType::createFromCapsule(nb::object capsule) { +PyType PyType::createFromCapsule(py::object capsule) { MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); if (mlirTypeIsNull(rawType)) - throw nb::python_error(); + throw py::error_already_set(); return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), rawType); } @@ -2004,14 +1980,14 @@ PyType PyType::createFromCapsule(nb::object capsule) { // PyTypeID. //------------------------------------------------------------------------------ -nb::object PyTypeID::getCapsule() { - return nb::steal(mlirPythonTypeIDToCapsule(*this)); +py::object PyTypeID::getCapsule() { + return py::reinterpret_steal(mlirPythonTypeIDToCapsule(*this)); } -PyTypeID PyTypeID::createFromCapsule(nb::object capsule) { +PyTypeID PyTypeID::createFromCapsule(py::object capsule) { MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr()); if (mlirTypeIDIsNull(mlirTypeID)) - throw nb::python_error(); + throw py::error_already_set(); return PyTypeID(mlirTypeID); } bool PyTypeID::operator==(const PyTypeID &other) const { @@ -2022,36 +1998,36 @@ bool PyTypeID::operator==(const PyTypeID &other) const { // PyValue and subclasses. 
//------------------------------------------------------------------------------ -nb::object PyValue::getCapsule() { - return nb::steal(mlirPythonValueToCapsule(get())); +pybind11::object PyValue::getCapsule() { + return py::reinterpret_steal(mlirPythonValueToCapsule(get())); } -nb::object PyValue::maybeDownCast() { +pybind11::object PyValue::maybeDownCast() { MlirType type = mlirValueGetType(get()); MlirTypeID mlirTypeID = mlirTypeGetTypeID(type); assert(!mlirTypeIDIsNull(mlirTypeID) && "mlirTypeID was expected to be non-null."); - std::optional valueCaster = + std::optional valueCaster = PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type)); - // nb::rv_policy::move means use std::move to move the return value + // py::return_value_policy::move means use std::move to move the return value // contents into a new instance that will be owned by Python. - nb::object thisObj = nb::cast(this, nb::rv_policy::move); + py::object thisObj = py::cast(this, py::return_value_policy::move); if (!valueCaster) return thisObj; return valueCaster.value()(thisObj); } -PyValue PyValue::createFromCapsule(nb::object capsule) { +PyValue PyValue::createFromCapsule(pybind11::object capsule) { MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); if (mlirValueIsNull(value)) - throw nb::python_error(); + throw py::error_already_set(); MlirOperation owner; if (mlirValueIsAOpResult(value)) owner = mlirOpResultGetOwner(value); if (mlirValueIsABlockArgument(value)) owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); if (mlirOperationIsNull(owner)) - throw nb::python_error(); + throw py::error_already_set(); MlirContext ctx = mlirOperationGetContext(owner); PyOperationRef ownerRef = PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); @@ -2066,17 +2042,16 @@ PySymbolTable::PySymbolTable(PyOperationBase &operation) : operation(operation.getOperation().getRef()) { symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); if 
(mlirSymbolTableIsNull(symbolTable)) { - throw nb::type_error("Operation is not a Symbol Table."); + throw py::cast_error("Operation is not a Symbol Table."); } } -nb::object PySymbolTable::dunderGetItem(const std::string &name) { +py::object PySymbolTable::dunderGetItem(const std::string &name) { operation->checkValid(); MlirOperation symbol = mlirSymbolTableLookup( symbolTable, mlirStringRefCreate(name.data(), name.length())); if (mlirOperationIsNull(symbol)) - throw nb::key_error( - ("Symbol '" + name + "' not in the symbol table.").c_str()); + throw py::key_error("Symbol '" + name + "' not in the symbol table."); return PyOperation::forOperation(operation->getContext(), symbol, operation.getObject()) @@ -2094,8 +2069,8 @@ void PySymbolTable::erase(PyOperationBase &symbol) { } void PySymbolTable::dunderDel(const std::string &name) { - nb::object operation = dunderGetItem(name); - erase(nb::cast(operation)); + py::object operation = dunderGetItem(name); + erase(py::cast(operation)); } MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) { @@ -2104,7 +2079,7 @@ MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) { MlirAttribute symbolAttr = mlirOperationGetAttributeByName( symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); if (mlirAttributeIsNull(symbolAttr)) - throw nb::value_error("Expected operation to have a symbol name."); + throw py::value_error("Expected operation to have a symbol name."); return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()); } @@ -2116,7 +2091,7 @@ MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { MlirAttribute existingNameAttr = mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingNameAttr)) - throw nb::value_error("Expected operation to have a symbol name."); + throw py::value_error("Expected operation to have a symbol name."); return existingNameAttr; } @@ -2129,7 +2104,7 @@ void PySymbolTable::setSymbolName(PyOperationBase 
&symbol, MlirAttribute existingNameAttr = mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingNameAttr)) - throw nb::value_error("Expected operation to have a symbol name."); + throw py::value_error("Expected operation to have a symbol name."); MlirAttribute newNameAttr = mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name)); mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); @@ -2142,7 +2117,7 @@ MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { MlirAttribute existingVisAttr = mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingVisAttr)) - throw nb::value_error("Expected operation to have a symbol visibility."); + throw py::value_error("Expected operation to have a symbol visibility."); return existingVisAttr; } @@ -2150,7 +2125,7 @@ void PySymbolTable::setVisibility(PyOperationBase &symbol, const std::string &visibility) { if (visibility != "public" && visibility != "private" && visibility != "nested") - throw nb::value_error( + throw py::value_error( "Expected visibility to be 'public', 'private' or 'nested'"); PyOperation &operation = symbol.getOperation(); operation.checkValid(); @@ -2158,7 +2133,7 @@ void PySymbolTable::setVisibility(PyOperationBase &symbol, MlirAttribute existingVisAttr = mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingVisAttr)) - throw nb::value_error("Expected operation to have a symbol visibility."); + throw py::value_error("Expected operation to have a symbol visibility."); MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(visibility)); mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr); @@ -2173,20 +2148,20 @@ void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol, toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol), from.getOperation()))) - throw nb::value_error("Symbol rename 
failed"); + throw py::value_error("Symbol rename failed"); } void PySymbolTable::walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, - nb::object callback) { + py::object callback) { PyOperation &fromOperation = from.getOperation(); fromOperation.checkValid(); struct UserData { PyMlirContextRef context; - nb::object callback; + py::object callback; bool gotException; std::string exceptionWhat; - nb::object exceptionType; + py::object exceptionType; }; UserData userData{ fromOperation.getContext(), std::move(callback), false, {}, {}}; @@ -2200,10 +2175,10 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from, return; try { calleeUserData->callback(pyFoundOp.getObject(), isVisible); - } catch (nb::python_error &e) { + } catch (py::error_already_set &e) { calleeUserData->gotException = true; calleeUserData->exceptionWhat = e.what(); - calleeUserData->exceptionType = nb::borrow(e.type()); + calleeUserData->exceptionType = e.type(); } }, static_cast(&userData)); @@ -2225,7 +2200,7 @@ class PyConcreteValue : public PyValue { // IsAFunctionTy isaFunction // const char *pyClassName // and redefine bindDerived. - using ClassTy = nb::class_; + using ClassTy = py::class_; using IsAFunctionTy = bool (*)(MlirValue); PyConcreteValue() = default; @@ -2238,26 +2213,25 @@ class PyConcreteValue : public PyValue { /// type mismatches. static MlirValue castFrom(PyValue &orig) { if (!DerivedTy::isaFunction(orig.get())) { - auto origRepr = nb::cast(nb::repr(nb::cast(orig))); - throw nb::value_error((Twine("Cannot cast value to ") + + auto origRepr = py::repr(py::cast(orig)).cast(); + throw py::value_error((Twine("Cannot cast value to ") + DerivedTy::pyClassName + " (from " + origRepr + ")") - .str() - .c_str()); + .str()); } return orig.get(); } /// Binds the Python module objects to functions of this class. 
- static void bind(nb::module_ &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName); - cls.def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value")); + static void bind(py::module &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); + cls.def(py::init(), py::keep_alive<0, 1>(), py::arg("value")); cls.def_static( "isinstance", [](PyValue &otherValue) -> bool { return DerivedTy::isaFunction(otherValue); }, - nb::arg("other_value")); + py::arg("other_value")); cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](DerivedTy &self) { return self.maybeDownCast(); }); DerivedTy::bindDerived(cls); @@ -2275,11 +2249,11 @@ class PyBlockArgument : public PyConcreteValue { using PyConcreteValue::PyConcreteValue; static void bindDerived(ClassTy &c) { - c.def_prop_ro("owner", [](PyBlockArgument &self) { + c.def_property_readonly("owner", [](PyBlockArgument &self) { return PyBlock(self.getParentOperation(), mlirBlockArgumentGetOwner(self.get())); }); - c.def_prop_ro("arg_number", [](PyBlockArgument &self) { + c.def_property_readonly("arg_number", [](PyBlockArgument &self) { return mlirBlockArgumentGetArgNumber(self.get()); }); c.def( @@ -2287,7 +2261,7 @@ class PyBlockArgument : public PyConcreteValue { [](PyBlockArgument &self, PyType type) { return mlirBlockArgumentSetType(self.get(), type); }, - nb::arg("type")); + py::arg("type")); } }; @@ -2299,14 +2273,14 @@ class PyOpResult : public PyConcreteValue { using PyConcreteValue::PyConcreteValue; static void bindDerived(ClassTy &c) { - c.def_prop_ro("owner", [](PyOpResult &self) { + c.def_property_readonly("owner", [](PyOpResult &self) { assert( mlirOperationEqual(self.getParentOperation()->get(), mlirOpResultGetOwner(self.get())) && "expected the owner of the value in Python to match that in the IR"); return self.getParentOperation().getObject(); }); - c.def_prop_ro("result_number", [](PyOpResult &self) { + c.def_property_readonly("result_number", [](PyOpResult &self) { return 
mlirOpResultGetResultNumber(self.get()); }); } @@ -2343,7 +2317,7 @@ class PyBlockArgumentList operation(std::move(operation)), block(block) {} static void bindDerived(ClassTy &c) { - c.def_prop_ro("types", [](PyBlockArgumentList &self) { + c.def_property_readonly("types", [](PyBlockArgumentList &self) { return getValueTypes(self, self.operation->getContext()); }); } @@ -2448,10 +2422,10 @@ class PyOpResultList : public Sliceable { operation(std::move(operation)) {} static void bindDerived(ClassTy &c) { - c.def_prop_ro("types", [](PyOpResultList &self) { + c.def_property_readonly("types", [](PyOpResultList &self) { return getValueTypes(self, self.operation->getContext()); }); - c.def_prop_ro("owner", [](PyOpResultList &self) { + c.def_property_readonly("owner", [](PyOpResultList &self) { return self.operation->createOpView(); }); } @@ -2534,14 +2508,14 @@ class PyOpAttributeMap { MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name)); if (mlirAttributeIsNull(attr)) { - throw nb::key_error("attempt to access a non-existent attribute"); + throw py::key_error("attempt to access a non-existent attribute"); } return attr; } PyNamedAttribute dunderGetItemIndexed(intptr_t index) { if (index < 0 || index >= dunderLen()) { - throw nb::index_error("attempt to access out of bounds attribute"); + throw py::index_error("attempt to access out of bounds attribute"); } MlirNamedAttribute namedAttr = mlirOperationGetAttribute(operation->get(), index); @@ -2560,7 +2534,7 @@ class PyOpAttributeMap { int removed = mlirOperationRemoveAttributeByName(operation->get(), toMlirStringRef(name)); if (!removed) - throw nb::key_error("attempt to delete a non-existent attribute"); + throw py::key_error("attempt to delete a non-existent attribute"); } intptr_t dunderLen() { @@ -2572,8 +2546,8 @@ class PyOpAttributeMap { operation->get(), toMlirStringRef(name))); } - static void bind(nb::module_ &m) { - nb::class_(m, "OpAttributeMap") + static void 
bind(py::module &m) { + py::class_(m, "OpAttributeMap", py::module_local()) .def("__contains__", &PyOpAttributeMap::dunderContains) .def("__len__", &PyOpAttributeMap::dunderLen) .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) @@ -2592,21 +2566,21 @@ class PyOpAttributeMap { // Populates the core exports of the 'ir' submodule. //------------------------------------------------------------------------------ -void mlir::python::populateIRCore(nb::module_ &m) { +void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Enums. //---------------------------------------------------------------------------- - nb::enum_(m, "DiagnosticSeverity") + py::enum_(m, "DiagnosticSeverity", py::module_local()) .value("ERROR", MlirDiagnosticError) .value("WARNING", MlirDiagnosticWarning) .value("NOTE", MlirDiagnosticNote) .value("REMARK", MlirDiagnosticRemark); - nb::enum_(m, "WalkOrder") + py::enum_(m, "WalkOrder", py::module_local()) .value("PRE_ORDER", MlirWalkPreOrder) .value("POST_ORDER", MlirWalkPostOrder); - nb::enum_(m, "WalkResult") + py::enum_(m, "WalkResult", py::module_local()) .value("ADVANCE", MlirWalkResultAdvance) .value("INTERRUPT", MlirWalkResultInterrupt) .value("SKIP", MlirWalkResultSkip); @@ -2614,37 +2588,33 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of Diagnostics. 
//---------------------------------------------------------------------------- - nb::class_(m, "Diagnostic") - .def_prop_ro("severity", &PyDiagnostic::getSeverity) - .def_prop_ro("location", &PyDiagnostic::getLocation) - .def_prop_ro("message", &PyDiagnostic::getMessage) - .def_prop_ro("notes", &PyDiagnostic::getNotes) - .def("__str__", [](PyDiagnostic &self) -> nb::str { + py::class_(m, "Diagnostic", py::module_local()) + .def_property_readonly("severity", &PyDiagnostic::getSeverity) + .def_property_readonly("location", &PyDiagnostic::getLocation) + .def_property_readonly("message", &PyDiagnostic::getMessage) + .def_property_readonly("notes", &PyDiagnostic::getNotes) + .def("__str__", [](PyDiagnostic &self) -> py::str { if (!self.isValid()) - return nb::str(""); + return ""; return self.getMessage(); }); - nb::class_(m, "DiagnosticInfo") - .def("__init__", - [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) { - new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo()); - }) - .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity) - .def_ro("location", &PyDiagnostic::DiagnosticInfo::location) - .def_ro("message", &PyDiagnostic::DiagnosticInfo::message) - .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes) + py::class_(m, "DiagnosticInfo", + py::module_local()) + .def(py::init<>([](PyDiagnostic diag) { return diag.getInfo(); })) + .def_readonly("severity", &PyDiagnostic::DiagnosticInfo::severity) + .def_readonly("location", &PyDiagnostic::DiagnosticInfo::location) + .def_readonly("message", &PyDiagnostic::DiagnosticInfo::message) + .def_readonly("notes", &PyDiagnostic::DiagnosticInfo::notes) .def("__str__", [](PyDiagnostic::DiagnosticInfo &self) { return self.message; }); - nb::class_(m, "DiagnosticHandler") + py::class_(m, "DiagnosticHandler", py::module_local()) .def("detach", &PyDiagnosticHandler::detach) - .def_prop_ro("attached", &PyDiagnosticHandler::isAttached) - .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError) + 
.def_property_readonly("attached", &PyDiagnosticHandler::isAttached) + .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError) .def("__enter__", &PyDiagnosticHandler::contextEnter) - .def("__exit__", &PyDiagnosticHandler::contextExit, - nb::arg("exc_type").none(), nb::arg("exc_value").none(), - nb::arg("traceback").none()); + .def("__exit__", &PyDiagnosticHandler::contextExit); //---------------------------------------------------------------------------- // Mapping of MlirContext. @@ -2652,12 +2622,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { // __init__.py will subclass it with site-specific functionality and set a // "Context" attribute on this module. //---------------------------------------------------------------------------- - nb::class_(m, "_BaseContext") - .def("__init__", - [](PyMlirContext &self) { - MlirContext context = mlirContextCreateWithThreading(false); - new (&self) PyMlirContext(context); - }) + py::class_(m, "_BaseContext", py::module_local()) + .def(py::init<>(&PyMlirContext::createNewContextForInit)) .def_static("_get_live_count", &PyMlirContext::getLiveCount) .def("_get_context_again", [](PyMlirContext &self) { @@ -2669,28 +2635,28 @@ void mlir::python::populateIRCore(nb::module_ &m) { &PyMlirContext::getLiveOperationObjects) .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) .def("_clear_live_operations_inside", - nb::overload_cast( + py::overload_cast( &PyMlirContext::clearOperationsInside)) .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyMlirContext::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) .def("__enter__", &PyMlirContext::contextEnter) - .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(), - nb::arg("exc_value").none(), nb::arg("traceback").none()) - .def_prop_ro_static( + 
.def("__exit__", &PyMlirContext::contextExit) + .def_property_readonly_static( "current", - [](nb::object & /*class*/) { + [](py::object & /*class*/) { auto *context = PyThreadContextEntry::getDefaultContext(); if (!context) - return nb::none(); - return nb::cast(context); + return py::none().cast(); + return py::cast(context); }, "Gets the Context bound to the current thread or raises ValueError") - .def_prop_ro( + .def_property_readonly( "dialects", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, "Gets a container for accessing dialects by name") - .def_prop_ro( + .def_property_readonly( "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, "Alias for 'dialect'") .def( @@ -2699,14 +2665,14 @@ void mlir::python::populateIRCore(nb::module_ &m) { MlirDialect dialect = mlirContextGetOrLoadDialect( self.get(), {name.data(), name.size()}); if (mlirDialectIsNull(dialect)) { - throw nb::value_error( - (Twine("Dialect '") + name + "' not found").str().c_str()); + throw py::value_error( + (Twine("Dialect '") + name + "' not found").str()); } return PyDialectDescriptor(self.getRef(), dialect); }, - nb::arg("dialect_name"), + py::arg("dialect_name"), "Gets or loads a dialect by name, returning its descriptor object") - .def_prop_rw( + .def_property( "allow_unregistered_dialects", [](PyMlirContext &self) -> bool { return mlirContextGetAllowUnregisteredDialects(self.get()); @@ -2715,32 +2681,32 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirContextSetAllowUnregisteredDialects(self.get(), value); }) .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler, - nb::arg("callback"), + py::arg("callback"), "Attaches a diagnostic handler that will receive callbacks") .def( "enable_multithreading", [](PyMlirContext &self, bool enable) { mlirContextEnableMultithreading(self.get(), enable); }, - nb::arg("enable")) + py::arg("enable")) .def( "is_registered_operation", [](PyMlirContext &self, std::string &name) { return 
mlirContextIsRegisteredOperation( self.get(), MlirStringRef{name.data(), name.size()}); }, - nb::arg("operation_name")) + py::arg("operation_name")) .def( "append_dialect_registry", [](PyMlirContext &self, PyDialectRegistry ®istry) { mlirContextAppendDialectRegistry(self.get(), registry); }, - nb::arg("registry")) - .def_prop_rw("emit_error_diagnostics", nullptr, - &PyMlirContext::setEmitErrorDiagnostics, - "Emit error diagnostics to diagnostic handlers. By default " - "error diagnostics are captured and reported through " - "MLIRError exceptions.") + py::arg("registry")) + .def_property("emit_error_diagnostics", nullptr, + &PyMlirContext::setEmitErrorDiagnostics, + "Emit error diagnostics to diagnostic handlers. By default " + "error diagnostics are captured and reported through " + "MLIRError exceptions.") .def("load_all_available_dialects", [](PyMlirContext &self) { mlirContextLoadAllAvailableDialects(self.get()); }); @@ -2748,12 +2714,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of PyDialectDescriptor //---------------------------------------------------------------------------- - nb::class_(m, "DialectDescriptor") - .def_prop_ro("namespace", - [](PyDialectDescriptor &self) { - MlirStringRef ns = mlirDialectGetNamespace(self.get()); - return nb::str(ns.data, ns.length); - }) + py::class_(m, "DialectDescriptor", py::module_local()) + .def_property_readonly("namespace", + [](PyDialectDescriptor &self) { + MlirStringRef ns = + mlirDialectGetNamespace(self.get()); + return py::str(ns.data, ns.length); + }) .def("__repr__", [](PyDialectDescriptor &self) { MlirStringRef ns = mlirDialectGetNamespace(self.get()); std::string repr("(m, "Dialects") + py::class_(m, "Dialects", py::module_local()) .def("__getitem__", [=](PyDialects &self, std::string keyName) { MlirDialect dialect = self.getDialectForKey(keyName, /*attrError=*/false); - nb::object descriptor = - 
nb::cast(PyDialectDescriptor{self.getContext(), dialect}); + py::object descriptor = + py::cast(PyDialectDescriptor{self.getContext(), dialect}); return createCustomDialectWrapper(keyName, std::move(descriptor)); }) .def("__getattr__", [=](PyDialects &self, std::string attrName) { MlirDialect dialect = self.getDialectForKey(attrName, /*attrError=*/true); - nb::object descriptor = - nb::cast(PyDialectDescriptor{self.getContext(), dialect}); + py::object descriptor = + py::cast(PyDialectDescriptor{self.getContext(), dialect}); return createCustomDialectWrapper(attrName, std::move(descriptor)); }); //---------------------------------------------------------------------------- // Mapping of PyDialect //---------------------------------------------------------------------------- - nb::class_(m, "Dialect") - .def(nb::init(), nb::arg("descriptor")) - .def_prop_ro("descriptor", - [](PyDialect &self) { return self.getDescriptor(); }) - .def("__repr__", [](nb::object self) { + py::class_(m, "Dialect", py::module_local()) + .def(py::init(), py::arg("descriptor")) + .def_property_readonly( + "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) + .def("__repr__", [](py::object self) { auto clazz = self.attr("__class__"); - return nb::str(""); + return py::str(""); }); //---------------------------------------------------------------------------- // Mapping of PyDialectRegistry //---------------------------------------------------------------------------- - nb::class_(m, "DialectRegistry") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule) + py::class_(m, "DialectRegistry", py::module_local()) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyDialectRegistry::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule) - .def(nb::init<>()); + .def(py::init<>()); //---------------------------------------------------------------------------- // Mapping of Location 
//---------------------------------------------------------------------------- - nb::class_(m, "Location") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) + py::class_(m, "Location", py::module_local()) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) .def("__enter__", &PyLocation::contextEnter) - .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(), - nb::arg("exc_value").none(), nb::arg("traceback").none()) + .def("__exit__", &PyLocation::contextExit) .def("__eq__", [](PyLocation &self, PyLocation &other) -> bool { return mlirLocationEqual(self, other); }) - .def("__eq__", [](PyLocation &self, nb::object other) { return false; }) - .def_prop_ro_static( + .def("__eq__", [](PyLocation &self, py::object other) { return false; }) + .def_property_readonly_static( "current", - [](nb::object & /*class*/) { + [](py::object & /*class*/) { auto *loc = PyThreadContextEntry::getDefaultLocation(); if (!loc) - throw nb::value_error("No current Location"); + throw py::value_error("No current Location"); return loc; }, "Gets the Location bound to the current thread or raises ValueError") @@ -2834,14 +2801,14 @@ void mlir::python::populateIRCore(nb::module_ &m) { return PyLocation(context->getRef(), mlirLocationUnknownGet(context->get())); }, - nb::arg("context").none() = nb::none(), + py::arg("context") = py::none(), "Gets a Location representing an unknown location") .def_static( "callsite", [](PyLocation callee, const std::vector &frames, DefaultingPyMlirContext context) { if (frames.empty()) - throw nb::value_error("No caller frames provided"); + throw py::value_error("No caller frames provided"); MlirLocation caller = frames.back().get(); for (const PyLocation &frame : llvm::reverse(llvm::ArrayRef(frames).drop_back())) @@ -2849,8 +2816,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { return PyLocation(context->getRef(), 
mlirLocationCallSiteGet(callee.get(), caller)); }, - nb::arg("callee"), nb::arg("frames"), - nb::arg("context").none() = nb::none(), + py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(), kContextGetCallSiteLocationDocstring) .def_static( "file", @@ -2861,9 +2827,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirLocationFileLineColGet( context->get(), toMlirStringRef(filename), line, col)); }, - nb::arg("filename"), nb::arg("line"), nb::arg("col"), - nb::arg("context").none() = nb::none(), - kContextGetFileLocationDocstring) + py::arg("filename"), py::arg("line"), py::arg("col"), + py::arg("context") = py::none(), kContextGetFileLocationDocstring) .def_static( "fused", [](const std::vector &pyLocations, @@ -2878,9 +2843,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { metadata ? metadata->get() : MlirAttribute{0}); return PyLocation(context->getRef(), location); }, - nb::arg("locations"), nb::arg("metadata").none() = nb::none(), - nb::arg("context").none() = nb::none(), - kContextGetFusedLocationDocstring) + py::arg("locations"), py::arg("metadata") = py::none(), + py::arg("context") = py::none(), kContextGetFusedLocationDocstring) .def_static( "name", [](std::string name, std::optional childLoc, @@ -2892,22 +2856,21 @@ void mlir::python::populateIRCore(nb::module_ &m) { childLoc ? 
childLoc->get() : mlirLocationUnknownGet(context->get()))); }, - nb::arg("name"), nb::arg("childLoc").none() = nb::none(), - nb::arg("context").none() = nb::none(), - kContextGetNameLocationDocString) + py::arg("name"), py::arg("childLoc") = py::none(), + py::arg("context") = py::none(), kContextGetNameLocationDocString) .def_static( "from_attr", [](PyAttribute &attribute, DefaultingPyMlirContext context) { return PyLocation(context->getRef(), mlirLocationFromAttribute(attribute)); }, - nb::arg("attribute"), nb::arg("context").none() = nb::none(), + py::arg("attribute"), py::arg("context") = py::none(), "Gets a Location from a LocationAttr") - .def_prop_ro( + .def_property_readonly( "context", [](PyLocation &self) { return self.getContext().getObject(); }, "Context that owns the Location") - .def_prop_ro( + .def_property_readonly( "attr", [](PyLocation &self) { return mlirLocationGetAttribute(self); }, "Get the underlying LocationAttr") @@ -2916,7 +2879,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { [](PyLocation &self, std::string message) { mlirEmitError(self, message.c_str()); }, - nb::arg("message"), "Emits an error at this location") + py::arg("message"), "Emits an error at this location") .def("__repr__", [](PyLocation &self) { PyPrintAccumulator printAccum; mlirLocationPrint(self, printAccum.getCallback(), @@ -2927,8 +2890,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of Module //---------------------------------------------------------------------------- - nb::class_(m, "Module", nb::is_weak_referenceable()) - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) + py::class_(m, "Module", py::module_local()) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) .def_static( "parse", @@ -2940,19 +2903,7 @@ void mlir::python::populateIRCore(nb::module_ &m) 
{ throw MLIRError("Unable to parse module assembly", errors.take()); return PyModule::forModule(module).releaseObject(); }, - nb::arg("asm"), nb::arg("context").none() = nb::none(), - kModuleParseDocstring) - .def_static( - "parse", - [](nb::bytes moduleAsm, DefaultingPyMlirContext context) { - PyMlirContext::ErrorCapture errors(context->getRef()); - MlirModule module = mlirModuleCreateParse( - context->get(), toMlirStringRef(moduleAsm)); - if (mlirModuleIsNull(module)) - throw MLIRError("Unable to parse module assembly", errors.take()); - return PyModule::forModule(module).releaseObject(); - }, - nb::arg("asm"), nb::arg("context").none() = nb::none(), + py::arg("asm"), py::arg("context") = py::none(), kModuleParseDocstring) .def_static( "create", @@ -2960,12 +2911,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { MlirModule module = mlirModuleCreateEmpty(loc); return PyModule::forModule(module).releaseObject(); }, - nb::arg("loc").none() = nb::none(), "Creates an empty module") - .def_prop_ro( + py::arg("loc") = py::none(), "Creates an empty module") + .def_property_readonly( "context", [](PyModule &self) { return self.getContext().getObject(); }, "Context that created the Module") - .def_prop_ro( + .def_property_readonly( "operation", [](PyModule &self) { return PyOperation::forOperation(self.getContext(), @@ -2974,7 +2925,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { .releaseObject(); }, "Accesses the module as an operation") - .def_prop_ro( + .def_property_readonly( "body", [](PyModule &self) { PyOperationRef moduleOp = PyOperation::forOperation( @@ -2992,7 +2943,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { kDumpDocstring) .def( "__str__", - [](nb::object self) { + [](py::object self) { // Defer to the operation's __str__. 
return self.attr("operation").attr("__str__")(); }, @@ -3001,26 +2952,27 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of Operation. //---------------------------------------------------------------------------- - nb::class_(m, "_OperationBase") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, - [](PyOperationBase &self) { - return self.getOperation().getCapsule(); - }) + py::class_(m, "_OperationBase", py::module_local()) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + [](PyOperationBase &self) { + return self.getOperation().getCapsule(); + }) .def("__eq__", [](PyOperationBase &self, PyOperationBase &other) { return &self.getOperation() == &other.getOperation(); }) .def("__eq__", - [](PyOperationBase &self, nb::object other) { return false; }) + [](PyOperationBase &self, py::object other) { return false; }) .def("__hash__", [](PyOperationBase &self) { return static_cast(llvm::hash_value(&self.getOperation())); }) - .def_prop_ro("attributes", - [](PyOperationBase &self) { - return PyOpAttributeMap(self.getOperation().getRef()); - }) - .def_prop_ro( + .def_property_readonly("attributes", + [](PyOperationBase &self) { + return PyOpAttributeMap( + self.getOperation().getRef()); + }) + .def_property_readonly( "context", [](PyOperationBase &self) { PyOperation &concreteOperation = self.getOperation(); @@ -3028,44 +2980,46 @@ void mlir::python::populateIRCore(nb::module_ &m) { return concreteOperation.getContext().getObject(); }, "Context that owns the Operation") - .def_prop_ro("name", - [](PyOperationBase &self) { - auto &concreteOperation = self.getOperation(); - concreteOperation.checkValid(); - MlirOperation operation = concreteOperation.get(); - MlirStringRef name = - mlirIdentifierStr(mlirOperationGetName(operation)); - return nb::str(name.data, name.length); - }) - .def_prop_ro("operands", - [](PyOperationBase &self) { - return 
PyOpOperandList(self.getOperation().getRef()); - }) - .def_prop_ro("regions", - [](PyOperationBase &self) { - return PyRegionList(self.getOperation().getRef()); - }) - .def_prop_ro( + .def_property_readonly("name", + [](PyOperationBase &self) { + auto &concreteOperation = self.getOperation(); + concreteOperation.checkValid(); + MlirOperation operation = + concreteOperation.get(); + MlirStringRef name = mlirIdentifierStr( + mlirOperationGetName(operation)); + return py::str(name.data, name.length); + }) + .def_property_readonly("operands", + [](PyOperationBase &self) { + return PyOpOperandList( + self.getOperation().getRef()); + }) + .def_property_readonly("regions", + [](PyOperationBase &self) { + return PyRegionList( + self.getOperation().getRef()); + }) + .def_property_readonly( "results", [](PyOperationBase &self) { return PyOpResultList(self.getOperation().getRef()); }, "Returns the list of Operation results.") - .def_prop_ro( + .def_property_readonly( "result", [](PyOperationBase &self) { auto &operation = self.getOperation(); auto numResults = mlirOperationGetNumResults(operation); if (numResults != 1) { auto name = mlirIdentifierStr(mlirOperationGetName(operation)); - throw nb::value_error( + throw py::value_error( (Twine("Cannot call .result on operation ") + StringRef(name.data, name.length) + " which has " + Twine(numResults) + " results (it is only valid for operations with a " "single result)") - .str() - .c_str()); + .str()); } return PyOpResult(operation.getRef(), mlirOperationGetResult(operation, 0)) @@ -3073,7 +3027,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, "Shortcut to get an op result if it has only one (throws an error " "otherwise).") - .def_prop_ro( + .def_property_readonly( "location", [](PyOperationBase &self) { PyOperation &operation = self.getOperation(); @@ -3082,13 +3036,14 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, "Returns the source location the operation was defined or derived " "from.") - 
.def_prop_ro("parent", - [](PyOperationBase &self) -> nb::object { - auto parent = self.getOperation().getParentOperation(); - if (parent) - return parent->getObject(); - return nb::none(); - }) + .def_property_readonly("parent", + [](PyOperationBase &self) -> py::object { + auto parent = + self.getOperation().getParentOperation(); + if (parent) + return parent->getObject(); + return py::none(); + }) .def( "__str__", [](PyOperationBase &self) { @@ -3103,76 +3058,75 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, "Returns the assembly form of the operation.") .def("print", - nb::overload_cast( + py::overload_cast( &PyOperationBase::print), - nb::arg("state"), nb::arg("file").none() = nb::none(), - nb::arg("binary") = false, kOperationPrintStateDocstring) + py::arg("state"), py::arg("file") = py::none(), + py::arg("binary") = false, kOperationPrintStateDocstring) .def("print", - nb::overload_cast, bool, bool, bool, bool, - bool, nb::object, bool, bool>( + py::overload_cast, bool, bool, bool, bool, + bool, py::object, bool, bool>( &PyOperationBase::print), // Careful: Lots of arguments must match up with print method. 
- nb::arg("large_elements_limit").none() = nb::none(), - nb::arg("enable_debug_info") = false, - nb::arg("pretty_debug_info") = false, - nb::arg("print_generic_op_form") = false, - nb::arg("use_local_scope") = false, - nb::arg("assume_verified") = false, - nb::arg("file").none() = nb::none(), nb::arg("binary") = false, - nb::arg("skip_regions") = false, kOperationPrintDocstring) - .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"), - nb::arg("desired_version").none() = nb::none(), + py::arg("large_elements_limit") = py::none(), + py::arg("enable_debug_info") = false, + py::arg("pretty_debug_info") = false, + py::arg("print_generic_op_form") = false, + py::arg("use_local_scope") = false, + py::arg("assume_verified") = false, py::arg("file") = py::none(), + py::arg("binary") = false, py::arg("skip_regions") = false, + kOperationPrintDocstring) + .def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"), + py::arg("desired_version") = py::none(), kOperationPrintBytecodeDocstring) .def("get_asm", &PyOperationBase::getAsm, // Careful: Lots of arguments must match up with get_asm method. - nb::arg("binary") = false, - nb::arg("large_elements_limit").none() = nb::none(), - nb::arg("enable_debug_info") = false, - nb::arg("pretty_debug_info") = false, - nb::arg("print_generic_op_form") = false, - nb::arg("use_local_scope") = false, - nb::arg("assume_verified") = false, nb::arg("skip_regions") = false, + py::arg("binary") = false, + py::arg("large_elements_limit") = py::none(), + py::arg("enable_debug_info") = false, + py::arg("pretty_debug_info") = false, + py::arg("print_generic_op_form") = false, + py::arg("use_local_scope") = false, + py::arg("assume_verified") = false, py::arg("skip_regions") = false, kOperationGetAsmDocstring) .def("verify", &PyOperationBase::verify, "Verify the operation. 
Raises MLIRError if verification fails, and " "returns true otherwise.") - .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"), + .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), "Puts self immediately after the other operation in its parent " "block.") - .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"), + .def("move_before", &PyOperationBase::moveBefore, py::arg("other"), "Puts self immediately before the other operation in its parent " "block.") .def( "clone", - [](PyOperationBase &self, nb::object ip) { + [](PyOperationBase &self, py::object ip) { return self.getOperation().clone(ip); }, - nb::arg("ip").none() = nb::none()) + py::arg("ip") = py::none()) .def( "detach_from_parent", [](PyOperationBase &self) { PyOperation &operation = self.getOperation(); operation.checkValid(); if (!operation.isAttached()) - throw nb::value_error("Detached operation has no parent."); + throw py::value_error("Detached operation has no parent."); operation.detachFromParent(); return operation.createOpView(); }, "Detaches the operation from its parent block.") .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); }) - .def("walk", &PyOperationBase::walk, nb::arg("callback"), - nb::arg("walk_order") = MlirWalkPostOrder); - - nb::class_(m, "Operation") - .def_static("create", &PyOperation::create, nb::arg("name"), - nb::arg("results").none() = nb::none(), - nb::arg("operands").none() = nb::none(), - nb::arg("attributes").none() = nb::none(), - nb::arg("successors").none() = nb::none(), - nb::arg("regions") = 0, nb::arg("loc").none() = nb::none(), - nb::arg("ip").none() = nb::none(), - nb::arg("infer_type") = false, kOperationCreateDocstring) + .def("walk", &PyOperationBase::walk, py::arg("callback"), + py::arg("walk_order") = MlirWalkPostOrder); + + py::class_(m, "Operation", py::module_local()) + .def_static("create", &PyOperation::create, py::arg("name"), + py::arg("results") = py::none(), + py::arg("operands") = 
py::none(), + py::arg("attributes") = py::none(), + py::arg("successors") = py::none(), py::arg("regions") = 0, + py::arg("loc") = py::none(), py::arg("ip") = py::none(), + py::arg("infer_type") = false, kOperationCreateDocstring) .def_static( "parse", [](const std::string &sourceStr, const std::string &sourceName, @@ -3180,15 +3134,16 @@ void mlir::python::populateIRCore(nb::module_ &m) { return PyOperation::parse(context->getRef(), sourceStr, sourceName) ->createOpView(); }, - nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "", - nb::arg("context").none() = nb::none(), + py::arg("source"), py::kw_only(), py::arg("source_name") = "", + py::arg("context") = py::none(), "Parses an operation. Supports both text assembly format and binary " "bytecode format.") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyOperation::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) - .def_prop_ro("operation", [](nb::object self) { return self; }) - .def_prop_ro("opview", &PyOperation::createOpView) - .def_prop_ro( + .def_property_readonly("operation", [](py::object self) { return self; }) + .def_property_readonly("opview", &PyOperation::createOpView) + .def_property_readonly( "successors", [](PyOperationBase &self) { return PyOpSuccessors(self.getOperation().getRef()); @@ -3196,33 +3151,30 @@ void mlir::python::populateIRCore(nb::module_ &m) { "Returns the list of Operation successors."); auto opViewClass = - nb::class_(m, "OpView") - .def(nb::init(), nb::arg("operation")) - .def_prop_ro("operation", &PyOpView::getOperationObject) - .def_prop_ro("opview", [](nb::object self) { return self; }) + py::class_(m, "OpView", py::module_local()) + .def(py::init(), py::arg("operation")) + .def_property_readonly("operation", &PyOpView::getOperationObject) + .def_property_readonly("opview", [](py::object self) { return self; }) .def( "__str__", - [](PyOpView &self) { return 
nb::str(self.getOperationObject()); }) - .def_prop_ro( + [](PyOpView &self) { return py::str(self.getOperationObject()); }) + .def_property_readonly( "successors", [](PyOperationBase &self) { return PyOpSuccessors(self.getOperation().getRef()); }, "Returns the list of Operation successors."); - opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true); - opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none(); - opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none(); + opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); + opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); + opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); opViewClass.attr("build_generic") = classmethod( - &PyOpView::buildGeneric, nb::arg("cls"), - nb::arg("results").none() = nb::none(), - nb::arg("operands").none() = nb::none(), - nb::arg("attributes").none() = nb::none(), - nb::arg("successors").none() = nb::none(), - nb::arg("regions").none() = nb::none(), - nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(), + &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), + py::arg("operands") = py::none(), py::arg("attributes") = py::none(), + py::arg("successors") = py::none(), py::arg("regions") = py::none(), + py::arg("loc") = py::none(), py::arg("ip") = py::none(), "Builds a specific, generated OpView based on class level attributes."); opViewClass.attr("parse") = classmethod( - [](const nb::object &cls, const std::string &sourceStr, + [](const py::object &cls, const std::string &sourceStr, const std::string &sourceName, DefaultingPyMlirContext context) { PyOperationRef parsed = PyOperation::parse(context->getRef(), sourceStr, sourceName); @@ -3233,30 +3185,30 @@ void mlir::python::populateIRCore(nb::module_ &m) { // `OpView` subclasses, and is not intended to be used on `OpView` // directly. 
std::string clsOpName = - nb::cast(cls.attr("OPERATION_NAME")); + py::cast(cls.attr("OPERATION_NAME")); MlirStringRef identifier = mlirIdentifierStr(mlirOperationGetName(*parsed.get())); std::string_view parsedOpName(identifier.data, identifier.length); if (clsOpName != parsedOpName) throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" + parsedOpName + "'"); - return PyOpView::constructDerived(cls, parsed.getObject()); + return PyOpView::constructDerived(cls, *parsed.get()); }, - nb::arg("cls"), nb::arg("source"), nb::kw_only(), - nb::arg("source_name") = "", nb::arg("context").none() = nb::none(), + py::arg("cls"), py::arg("source"), py::kw_only(), + py::arg("source_name") = "", py::arg("context") = py::none(), "Parses a specific, generated OpView based on class level attributes"); //---------------------------------------------------------------------------- // Mapping of PyRegion. //---------------------------------------------------------------------------- - nb::class_(m, "Region") - .def_prop_ro( + py::class_(m, "Region", py::module_local()) + .def_property_readonly( "blocks", [](PyRegion &self) { return PyBlockList(self.getParentOperation(), self.get()); }, "Returns a forward-optimized sequence of blocks.") - .def_prop_ro( + .def_property_readonly( "owner", [](PyRegion &self) { return self.getParentOperation()->createOpView(); @@ -3274,27 +3226,27 @@ void mlir::python::populateIRCore(nb::module_ &m) { [](PyRegion &self, PyRegion &other) { return self.get().ptr == other.get().ptr; }) - .def("__eq__", [](PyRegion &self, nb::object &other) { return false; }); + .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); //---------------------------------------------------------------------------- // Mapping of PyBlock. 
//---------------------------------------------------------------------------- - nb::class_(m, "Block") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule) - .def_prop_ro( + py::class_(m, "Block", py::module_local()) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule) + .def_property_readonly( "owner", [](PyBlock &self) { return self.getParentOperation()->createOpView(); }, "Returns the owning operation of this block.") - .def_prop_ro( + .def_property_readonly( "region", [](PyBlock &self) { MlirRegion region = mlirBlockGetParentRegion(self.get()); return PyRegion(self.getParentOperation(), region); }, "Returns the owning region of this block.") - .def_prop_ro( + .def_property_readonly( "arguments", [](PyBlock &self) { return PyBlockArgumentList(self.getParentOperation(), self.get()); @@ -3313,7 +3265,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { return mlirBlockEraseArgument(self.get(), index); }, "Erase the argument at 'index' and remove it from the argument list.") - .def_prop_ro( + .def_property_readonly( "operations", [](PyBlock &self) { return PyOperationList(self.getParentOperation(), self.get()); @@ -3321,15 +3273,15 @@ void mlir::python::populateIRCore(nb::module_ &m) { "Returns a forward-optimized sequence of operations.") .def_static( "create_at_start", - [](PyRegion &parent, const nb::sequence &pyArgTypes, - const std::optional &pyArgLocs) { + [](PyRegion &parent, const py::list &pyArgTypes, + const std::optional &pyArgLocs) { parent.checkValid(); MlirBlock block = createBlock(pyArgTypes, pyArgLocs); mlirRegionInsertOwnedBlock(parent, 0, block); return PyBlock(parent.getParentOperation(), block); }, - nb::arg("parent"), nb::arg("arg_types") = nb::list(), - nb::arg("arg_locs") = std::nullopt, + py::arg("parent"), py::arg("arg_types") = py::list(), + py::arg("arg_locs") = std::nullopt, "Creates and returns a new Block at the beginning of the given " "region (with given argument types and locations).") .def( @@ 
-3343,32 +3295,28 @@ void mlir::python::populateIRCore(nb::module_ &m) { "Append this block to a region, transferring ownership if necessary") .def( "create_before", - [](PyBlock &self, const nb::args &pyArgTypes, - const std::optional &pyArgLocs) { + [](PyBlock &self, const py::args &pyArgTypes, + const std::optional &pyArgLocs) { self.checkValid(); - MlirBlock block = - createBlock(nb::cast(pyArgTypes), pyArgLocs); + MlirBlock block = createBlock(pyArgTypes, pyArgLocs); MlirRegion region = mlirBlockGetParentRegion(self.get()); mlirRegionInsertOwnedBlockBefore(region, self.get(), block); return PyBlock(self.getParentOperation(), block); }, - nb::arg("arg_types"), nb::kw_only(), - nb::arg("arg_locs") = std::nullopt, + py::arg("arg_locs") = std::nullopt, "Creates and returns a new Block before this block " "(with given argument types and locations).") .def( "create_after", - [](PyBlock &self, const nb::args &pyArgTypes, - const std::optional &pyArgLocs) { + [](PyBlock &self, const py::args &pyArgTypes, + const std::optional &pyArgLocs) { self.checkValid(); - MlirBlock block = - createBlock(nb::cast(pyArgTypes), pyArgLocs); + MlirBlock block = createBlock(pyArgTypes, pyArgLocs); MlirRegion region = mlirBlockGetParentRegion(self.get()); mlirRegionInsertOwnedBlockAfter(region, self.get(), block); return PyBlock(self.getParentOperation(), block); }, - nb::arg("arg_types"), nb::kw_only(), - nb::arg("arg_locs") = std::nullopt, + py::arg("arg_locs") = std::nullopt, "Creates and returns a new Block after this block " "(with given argument types and locations).") .def( @@ -3385,7 +3333,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { [](PyBlock &self, PyBlock &other) { return self.get().ptr == other.get().ptr; }) - .def("__eq__", [](PyBlock &self, nb::object &other) { return false; }) + .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) .def("__hash__", [](PyBlock &self) { return static_cast(llvm::hash_value(self.get().ptr)); @@ -3411,7 +3359,7 @@ 
void mlir::python::populateIRCore(nb::module_ &m) { operation.getOperation().setAttached( self.getParentOperation().getObject()); }, - nb::arg("operation"), + py::arg("operation"), "Appends an operation to this block. If the operation is currently " "in another block, it will be moved."); @@ -3419,41 +3367,39 @@ void mlir::python::populateIRCore(nb::module_ &m) { // Mapping of PyInsertionPoint. //---------------------------------------------------------------------------- - nb::class_(m, "InsertionPoint") - .def(nb::init(), nb::arg("block"), + py::class_(m, "InsertionPoint", py::module_local()) + .def(py::init(), py::arg("block"), "Inserts after the last operation but still inside the block.") .def("__enter__", &PyInsertionPoint::contextEnter) - .def("__exit__", &PyInsertionPoint::contextExit, - nb::arg("exc_type").none(), nb::arg("exc_value").none(), - nb::arg("traceback").none()) - .def_prop_ro_static( + .def("__exit__", &PyInsertionPoint::contextExit) + .def_property_readonly_static( "current", - [](nb::object & /*class*/) { + [](py::object & /*class*/) { auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); if (!ip) - throw nb::value_error("No current InsertionPoint"); + throw py::value_error("No current InsertionPoint"); return ip; }, "Gets the InsertionPoint bound to the current thread or raises " "ValueError if none has been set") - .def(nb::init(), nb::arg("beforeOperation"), + .def(py::init(), py::arg("beforeOperation"), "Inserts before a referenced operation.") .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, - nb::arg("block"), "Inserts at the beginning of the block.") + py::arg("block"), "Inserts at the beginning of the block.") .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, - nb::arg("block"), "Inserts before the block terminator.") - .def("insert", &PyInsertionPoint::insert, nb::arg("operation"), + py::arg("block"), "Inserts before the block terminator.") + .def("insert", &PyInsertionPoint::insert, 
py::arg("operation"), "Inserts an operation.") - .def_prop_ro( + .def_property_readonly( "block", [](PyInsertionPoint &self) { return self.getBlock(); }, "Returns the block that this InsertionPoint points to.") - .def_prop_ro( + .def_property_readonly( "ref_operation", - [](PyInsertionPoint &self) -> nb::object { + [](PyInsertionPoint &self) -> py::object { auto refOperation = self.getRefOperation(); if (refOperation) return refOperation->getObject(); - return nb::none(); + return py::none(); }, "The reference operation before which new operations are " "inserted, or None if the insertion point is at the end of " @@ -3462,12 +3408,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of PyAttribute. //---------------------------------------------------------------------------- - nb::class_(m, "Attribute") + py::class_(m, "Attribute", py::module_local()) // Delegate to the PyAttribute copy constructor, which will also lifetime // extend the backing context which owns the MlirAttribute. - .def(nb::init(), nb::arg("cast_from_type"), + .def(py::init(), py::arg("cast_from_type"), "Casts the passed attribute to the generic Attribute") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyAttribute::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) .def_static( "parse", @@ -3479,24 +3426,24 @@ void mlir::python::populateIRCore(nb::module_ &m) { throw MLIRError("Unable to parse attribute", errors.take()); return attr; }, - nb::arg("asm"), nb::arg("context").none() = nb::none(), + py::arg("asm"), py::arg("context") = py::none(), "Parses an attribute from an assembly form. 
Raises an MLIRError on " "failure.") - .def_prop_ro( + .def_property_readonly( "context", [](PyAttribute &self) { return self.getContext().getObject(); }, "Context that owns the Attribute") - .def_prop_ro("type", - [](PyAttribute &self) { return mlirAttributeGetType(self); }) + .def_property_readonly( + "type", [](PyAttribute &self) { return mlirAttributeGetType(self); }) .def( "get_named", [](PyAttribute &self, std::string name) { return PyNamedAttribute(self, std::move(name)); }, - nb::keep_alive<0, 1>(), "Binds a name to the attribute") + py::keep_alive<0, 1>(), "Binds a name to the attribute") .def("__eq__", [](PyAttribute &self, PyAttribute &other) { return self == other; }) - .def("__eq__", [](PyAttribute &self, nb::object &other) { return false; }) + .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) .def("__hash__", [](PyAttribute &self) { return static_cast(llvm::hash_value(self.get().ptr)); @@ -3527,35 +3474,36 @@ void mlir::python::populateIRCore(nb::module_ &m) { printAccum.parts.append(")"); return printAccum.join(); }) - .def_prop_ro("typeid", - [](PyAttribute &self) -> MlirTypeID { - MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); - assert(!mlirTypeIDIsNull(mlirTypeID) && - "mlirTypeID was expected to be non-null."); - return mlirTypeID; - }) + .def_property_readonly( + "typeid", + [](PyAttribute &self) -> MlirTypeID { + MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + return mlirTypeID; + }) .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) { MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); assert(!mlirTypeIDIsNull(mlirTypeID) && "mlirTypeID was expected to be non-null."); - std::optional typeCaster = + std::optional typeCaster = PyGlobals::get().lookupTypeCaster(mlirTypeID, mlirAttributeGetDialect(self)); if (!typeCaster) - return nb::cast(self); + return py::cast(self); return typeCaster.value()(self); }); 
//---------------------------------------------------------------------------- // Mapping of PyNamedAttribute //---------------------------------------------------------------------------- - nb::class_(m, "NamedAttribute") + py::class_(m, "NamedAttribute", py::module_local()) .def("__repr__", [](PyNamedAttribute &self) { PyPrintAccumulator printAccum; printAccum.parts.append("NamedAttribute("); printAccum.parts.append( - nb::str(mlirIdentifierStr(self.namedAttr.name).data, + py::str(mlirIdentifierStr(self.namedAttr.name).data, mlirIdentifierStr(self.namedAttr.name).length)); printAccum.parts.append("="); mlirAttributePrint(self.namedAttr.attribute, @@ -3564,28 +3512,28 @@ void mlir::python::populateIRCore(nb::module_ &m) { printAccum.parts.append(")"); return printAccum.join(); }) - .def_prop_ro( + .def_property_readonly( "name", [](PyNamedAttribute &self) { - return nb::str(mlirIdentifierStr(self.namedAttr.name).data, + return py::str(mlirIdentifierStr(self.namedAttr.name).data, mlirIdentifierStr(self.namedAttr.name).length); }, "The name of the NamedAttribute binding") - .def_prop_ro( + .def_property_readonly( "attr", [](PyNamedAttribute &self) { return self.namedAttr.attribute; }, - nb::keep_alive<0, 1>(), + py::keep_alive<0, 1>(), "The underlying generic attribute of the NamedAttribute binding"); //---------------------------------------------------------------------------- // Mapping of PyType. //---------------------------------------------------------------------------- - nb::class_(m, "Type") + py::class_(m, "Type", py::module_local()) // Delegate to the PyType copy constructor, which will also lifetime // extend the backing context which owns the MlirType. 
- .def(nb::init(), nb::arg("cast_from_type"), + .def(py::init(), py::arg("cast_from_type"), "Casts the passed type to the generic Type") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) .def_static( "parse", @@ -3597,15 +3545,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { throw MLIRError("Unable to parse type", errors.take()); return type; }, - nb::arg("asm"), nb::arg("context").none() = nb::none(), + py::arg("asm"), py::arg("context") = py::none(), kContextParseTypeDocstring) - .def_prop_ro( + .def_property_readonly( "context", [](PyType &self) { return self.getContext().getObject(); }, "Context that owns the Type") .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) - .def( - "__eq__", [](PyType &self, nb::object &other) { return false; }, - nb::arg("other").none()) + .def("__eq__", [](PyType &self, py::object &other) { return false; }) .def("__hash__", [](PyType &self) { return static_cast(llvm::hash_value(self.get().ptr)); @@ -3639,27 +3585,28 @@ void mlir::python::populateIRCore(nb::module_ &m) { MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); assert(!mlirTypeIDIsNull(mlirTypeID) && "mlirTypeID was expected to be non-null."); - std::optional typeCaster = + std::optional typeCaster = PyGlobals::get().lookupTypeCaster(mlirTypeID, mlirTypeGetDialect(self)); if (!typeCaster) - return nb::cast(self); + return py::cast(self); return typeCaster.value()(self); }) - .def_prop_ro("typeid", [](PyType &self) -> MlirTypeID { + .def_property_readonly("typeid", [](PyType &self) -> MlirTypeID { MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); if (!mlirTypeIDIsNull(mlirTypeID)) return mlirTypeID; - auto origRepr = nb::cast(nb::repr(nb::cast(self))); - throw nb::value_error( - (origRepr + llvm::Twine(" has no typeid.")).str().c_str()); + auto origRepr = + 
pybind11::repr(pybind11::cast(self)).cast(); + throw py::value_error( + (origRepr + llvm::Twine(" has no typeid.")).str()); }); //---------------------------------------------------------------------------- // Mapping of PyTypeID. //---------------------------------------------------------------------------- - nb::class_(m, "TypeID") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule) + py::class_(m, "TypeID", py::module_local()) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule) // Note, this tests whether the underlying TypeIDs are the same, // not whether the wrapper MlirTypeIDs are the same, nor whether @@ -3667,7 +3614,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { .def("__eq__", [](PyTypeID &self, PyTypeID &other) { return self == other; }) .def("__eq__", - [](PyTypeID &self, const nb::object &other) { return false; }) + [](PyTypeID &self, const py::object &other) { return false; }) // Note, this gives the hash value of the underlying TypeID, not the // hash value of the Python object, nor the hash value of the // MlirTypeID wrapper. @@ -3678,20 +3625,20 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of Value. 
//---------------------------------------------------------------------------- - nb::class_(m, "Value") - .def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value")) - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) + py::class_(m, "Value", py::module_local()) + .def(py::init(), py::keep_alive<0, 1>(), py::arg("value")) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) - .def_prop_ro( + .def_property_readonly( "context", [](PyValue &self) { return self.getParentOperation()->getContext(); }, "Context in which the value lives.") .def( "dump", [](PyValue &self) { mlirValueDump(self.get()); }, kDumpDocstring) - .def_prop_ro( + .def_property_readonly( "owner", - [](PyValue &self) -> nb::object { + [](PyValue &self) -> py::object { MlirValue v = self.get(); if (mlirValueIsAOpResult(v)) { assert( @@ -3704,22 +3651,22 @@ void mlir::python::populateIRCore(nb::module_ &m) { if (mlirValueIsABlockArgument(v)) { MlirBlock block = mlirBlockArgumentGetOwner(self.get()); - return nb::cast(PyBlock(self.getParentOperation(), block)); + return py::cast(PyBlock(self.getParentOperation(), block)); } assert(false && "Value must be a block argument or an op result"); - return nb::none(); + return py::none(); }) - .def_prop_ro("uses", - [](PyValue &self) { - return PyOpOperandIterator( - mlirValueGetFirstUse(self.get())); - }) + .def_property_readonly("uses", + [](PyValue &self) { + return PyOpOperandIterator( + mlirValueGetFirstUse(self.get())); + }) .def("__eq__", [](PyValue &self, PyValue &other) { return self.get().ptr == other.get().ptr; }) - .def("__eq__", [](PyValue &self, nb::object other) { return false; }) + .def("__eq__", [](PyValue &self, py::object other) { return false; }) .def("__hash__", [](PyValue &self) { return static_cast(llvm::hash_value(self.get().ptr)); @@ -3751,26 +3698,26 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirAsmStateDestroy(valueState); 
return printAccum.join(); }, - nb::arg("use_local_scope") = false) + py::arg("use_local_scope") = false) .def( "get_name", - [](PyValue &self, PyAsmState &state) { + [](PyValue &self, std::reference_wrapper state) { PyPrintAccumulator printAccum; - MlirAsmState valueState = state.get(); + MlirAsmState valueState = state.get().get(); mlirValuePrintAsOperand(self.get(), valueState, printAccum.getCallback(), printAccum.getUserData()); return printAccum.join(); }, - nb::arg("state"), kGetNameAsOperand) - .def_prop_ro("type", - [](PyValue &self) { return mlirValueGetType(self.get()); }) + py::arg("state"), kGetNameAsOperand) + .def_property_readonly( + "type", [](PyValue &self) { return mlirValueGetType(self.get()); }) .def( "set_type", [](PyValue &self, const PyType &type) { return mlirValueSetType(self.get(), type); }, - nb::arg("type")) + py::arg("type")) .def( "replace_all_uses_with", [](PyValue &self, PyValue &with) { @@ -3783,22 +3730,22 @@ void mlir::python::populateIRCore(nb::module_ &m) { MlirOperation exceptedUser = exception.get(); mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser); }, - nb::arg("with"), nb::arg("exceptions"), + py::arg("with"), py::arg("exceptions"), kValueReplaceAllUsesExceptDocstring) .def( "replace_all_uses_except", - [](MlirValue self, MlirValue with, nb::list exceptions) { + [](MlirValue self, MlirValue with, py::list exceptions) { // Convert Python list to a SmallVector of MlirOperations llvm::SmallVector exceptionOps; - for (nb::handle exception : exceptions) { - exceptionOps.push_back(nb::cast(exception).get()); + for (py::handle exception : exceptions) { + exceptionOps.push_back(exception.cast().get()); } mlirValueReplaceAllUsesExcept( self, with, static_cast(exceptionOps.size()), exceptionOps.data()); }, - nb::arg("with"), nb::arg("exceptions"), + py::arg("with"), py::arg("exceptions"), kValueReplaceAllUsesExceptDocstring) .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyValue &self) { return self.maybeDownCast(); }); @@ 
-3806,20 +3753,20 @@ void mlir::python::populateIRCore(nb::module_ &m) { PyOpResult::bind(m); PyOpOperand::bind(m); - nb::class_(m, "AsmState") - .def(nb::init(), nb::arg("value"), - nb::arg("use_local_scope") = false) - .def(nb::init(), nb::arg("op"), - nb::arg("use_local_scope") = false); + py::class_(m, "AsmState", py::module_local()) + .def(py::init(), py::arg("value"), + py::arg("use_local_scope") = false) + .def(py::init(), py::arg("op"), + py::arg("use_local_scope") = false); //---------------------------------------------------------------------------- // Mapping of SymbolTable. //---------------------------------------------------------------------------- - nb::class_(m, "SymbolTable") - .def(nb::init()) + py::class_(m, "SymbolTable", py::module_local()) + .def(py::init()) .def("__getitem__", &PySymbolTable::dunderGetItem) - .def("insert", &PySymbolTable::insert, nb::arg("operation")) - .def("erase", &PySymbolTable::erase, nb::arg("operation")) + .def("insert", &PySymbolTable::insert, py::arg("operation")) + .def("erase", &PySymbolTable::erase, py::arg("operation")) .def("__delitem__", &PySymbolTable::dunderDel) .def("__contains__", [](PySymbolTable &table, const std::string &name) { @@ -3828,19 +3775,19 @@ void mlir::python::populateIRCore(nb::module_ &m) { }) // Static helpers. 
.def_static("set_symbol_name", &PySymbolTable::setSymbolName, - nb::arg("symbol"), nb::arg("name")) + py::arg("symbol"), py::arg("name")) .def_static("get_symbol_name", &PySymbolTable::getSymbolName, - nb::arg("symbol")) + py::arg("symbol")) .def_static("get_visibility", &PySymbolTable::getVisibility, - nb::arg("symbol")) + py::arg("symbol")) .def_static("set_visibility", &PySymbolTable::setVisibility, - nb::arg("symbol"), nb::arg("visibility")) + py::arg("symbol"), py::arg("visibility")) .def_static("replace_all_symbol_uses", - &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"), - nb::arg("new_symbol"), nb::arg("from_op")) + &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"), + py::arg("new_symbol"), py::arg("from_op")) .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, - nb::arg("from_op"), nb::arg("all_sym_uses_visible"), - nb::arg("callback")); + py::arg("from_op"), py::arg("all_sym_uses_visible"), + py::arg("callback")); // Container bindings. PyBlockArgumentList::bind(m); @@ -3862,15 +3809,14 @@ void mlir::python::populateIRCore(nb::module_ &m) { // Attribute builder getter. PyAttrBuilderMap::bind(m); - nb::register_exception_translator([](const std::exception_ptr &p, - void *payload) { + py::register_local_exception_translator([](std::exception_ptr p) { // We can't define exceptions with custom fields through pybind, so instead // the exception class is defined in python and imported here. 
try { if (p) std::rethrow_exception(p); } catch (const MLIRError &e) { - nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) + py::object obj = py::module_::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr("MLIRError")(e.message, e.errorDiagnostics); PyErr_SetObject(PyExc_Exception, obj.ptr()); } diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index c339a93e3..54cfa5606 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -6,12 +6,12 @@ // //===----------------------------------------------------------------------===// -#include -#include -#include - #include #include +#include +#include +#include +#include #include #include #include @@ -24,7 +24,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -namespace nb = nanobind; +namespace py = pybind11; namespace mlir { namespace python { @@ -53,10 +53,10 @@ namespace { /// Takes in an optional ist of operands and converts them into a SmallVector /// of MlirVlaues. Returns an empty SmallVector if the list is empty. -llvm::SmallVector wrapOperands(std::optional operandList) { +llvm::SmallVector wrapOperands(std::optional operandList) { llvm::SmallVector mlirOperands; - if (!operandList || operandList->size() == 0) { + if (!operandList || operandList->empty()) { return mlirOperands; } @@ -68,42 +68,40 @@ llvm::SmallVector wrapOperands(std::optional operandList) { PyValue *val; try { - val = nb::cast(it.value()); + val = py::cast(it.value()); if (!val) - throw nb::cast_error(); + throw py::cast_error(); mlirOperands.push_back(val->get()); continue; - } catch (nb::cast_error &err) { + } catch (py::cast_error &err) { // Intentionally unhandled to try sequence below first. 
(void)err; } try { - auto vals = nb::cast(it.value()); - for (nb::handle v : vals) { + auto vals = py::cast(it.value()); + for (py::object v : vals) { try { - val = nb::cast(v); + val = py::cast(v); if (!val) - throw nb::cast_error(); + throw py::cast_error(); mlirOperands.push_back(val->get()); - } catch (nb::cast_error &err) { - throw nb::value_error( + } catch (py::cast_error &err) { + throw py::value_error( (llvm::Twine("Operand ") + llvm::Twine(it.index()) + " must be a Value or Sequence of Values (" + err.what() + ")") - .str() - .c_str()); + .str()); } } continue; - } catch (nb::cast_error &err) { - throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + + } catch (py::cast_error &err) { + throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + " must be a Value or Sequence of Values (" + err.what() + ")") - .str() - .c_str()); + .str()); } - throw nb::cast_error(); + throw py::cast_error(); } return mlirOperands; @@ -146,24 +144,24 @@ wrapRegions(std::optional> regions) { template class PyConcreteOpInterface { protected: - using ClassTy = nb::class_; + using ClassTy = py::class_; using GetTypeIDFunctionTy = MlirTypeID (*)(); public: /// Constructs an interface instance from an object that is either an /// operation or a subclass of OpView. In the latter case, only the static /// methods of the interface are accessible to the caller. - PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context) + PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context) : obj(std::move(object)) { try { - operation = &nb::cast(obj); - } catch (nb::cast_error &) { + operation = &py::cast(obj); + } catch (py::cast_error &) { // Do nothing. } try { - operation = &nb::cast(obj).getOperation(); - } catch (nb::cast_error &) { + operation = &py::cast(obj).getOperation(); + } catch (py::cast_error &) { // Do nothing. 
} @@ -171,7 +169,7 @@ class PyConcreteOpInterface { if (!mlirOperationImplementsInterface(*operation, ConcreteIface::getInterfaceID())) { std::string msg = "the operation does not implement "; - throw nb::value_error((msg + ConcreteIface::pyClassName).c_str()); + throw py::value_error(msg + ConcreteIface::pyClassName); } MlirIdentifier identifier = mlirOperationGetName(*operation); @@ -179,9 +177,9 @@ class PyConcreteOpInterface { opName = std::string(stringRef.data, stringRef.length); } else { try { - opName = nb::cast(obj.attr("OPERATION_NAME")); - } catch (nb::cast_error &) { - throw nb::type_error( + opName = obj.attr("OPERATION_NAME").template cast(); + } catch (py::cast_error &) { + throw py::type_error( "Op interface does not refer to an operation or OpView class"); } @@ -189,19 +187,22 @@ class PyConcreteOpInterface { mlirStringRefCreate(opName.data(), opName.length()), context.resolve().get(), ConcreteIface::getInterfaceID())) { std::string msg = "the operation does not implement "; - throw nb::value_error((msg + ConcreteIface::pyClassName).c_str()); + throw py::value_error(msg + ConcreteIface::pyClassName); } } } /// Creates the Python bindings for this class in the given module. 
- static void bind(nb::module_ &m) { - nb::class_ cls(m, ConcreteIface::pyClassName); - cls.def(nb::init(), nb::arg("object"), - nb::arg("context").none() = nb::none(), constructorDoc) - .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject, - operationDoc) - .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc); + static void bind(py::module &m) { + py::class_ cls(m, ConcreteIface::pyClassName, + py::module_local()); + cls.def(py::init(), py::arg("object"), + py::arg("context") = py::none(), constructorDoc) + .def_property_readonly("operation", + &PyConcreteOpInterface::getOperationObject, + operationDoc) + .def_property_readonly("opview", &PyConcreteOpInterface::getOpView, + opviewDoc); ConcreteIface::bindDerived(cls); } @@ -215,9 +216,9 @@ class PyConcreteOpInterface { /// Returns the operation instance from which this object was constructed. /// Throws a type error if this object was constructed from a subclass of /// OpView. - nb::object getOperationObject() { + py::object getOperationObject() { if (operation == nullptr) { - throw nb::type_error("Cannot get an operation from a static interface"); + throw py::type_error("Cannot get an operation from a static interface"); } return operation->getRef().releaseObject(); @@ -226,9 +227,9 @@ class PyConcreteOpInterface { /// Returns the opview of the operation instance from which this object was /// constructed. Throws a type error if this object was constructed form a /// subclass of OpView. - nb::object getOpView() { + py::object getOpView() { if (operation == nullptr) { - throw nb::type_error("Cannot get an opview from a static interface"); + throw py::type_error("Cannot get an opview from a static interface"); } return operation->createOpView(); @@ -241,7 +242,7 @@ class PyConcreteOpInterface { private: PyOperation *operation = nullptr; std::string opName; - nb::object obj; + py::object obj; }; /// Python wrapper for InferTypeOpInterface. 
This interface has only static @@ -275,7 +276,7 @@ class PyInferTypeOpInterface /// Given the arguments required to build an operation, attempts to infer its /// return types. Throws value_error on failure. std::vector - inferReturnTypes(std::optional operandList, + inferReturnTypes(std::optional operandList, std::optional attributes, void *properties, std::optional> regions, DefaultingPyMlirContext context, @@ -298,7 +299,7 @@ class PyInferTypeOpInterface mlirRegions.data(), &appendResultsCallback, &data); if (mlirLogicalResultIsFailure(result)) { - throw nb::value_error("Failed to infer result types"); + throw py::value_error("Failed to infer result types"); } return inferredTypes; @@ -306,12 +307,11 @@ class PyInferTypeOpInterface static void bindDerived(ClassTy &cls) { cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, - nb::arg("operands").none() = nb::none(), - nb::arg("attributes").none() = nb::none(), - nb::arg("properties").none() = nb::none(), - nb::arg("regions").none() = nb::none(), - nb::arg("context").none() = nb::none(), - nb::arg("loc").none() = nb::none(), inferReturnTypesDoc); + py::arg("operands") = py::none(), + py::arg("attributes") = py::none(), + py::arg("properties") = py::none(), py::arg("regions") = py::none(), + py::arg("context") = py::none(), py::arg("loc") = py::none(), + inferReturnTypesDoc); } }; @@ -319,9 +319,9 @@ class PyInferTypeOpInterface class PyShapedTypeComponents { public: PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {} - PyShapedTypeComponents(nb::list shape, MlirType elementType) + PyShapedTypeComponents(py::list shape, MlirType elementType) : shape(std::move(shape)), elementType(elementType), ranked(true) {} - PyShapedTypeComponents(nb::list shape, MlirType elementType, + PyShapedTypeComponents(py::list shape, MlirType elementType, MlirAttribute attribute) : shape(std::move(shape)), elementType(elementType), attribute(attribute), ranked(true) {} @@ -330,9 +330,10 @@ class 
PyShapedTypeComponents { : shape(other.shape), elementType(other.elementType), attribute(other.attribute), ranked(other.ranked) {} - static void bind(nb::module_ &m) { - nb::class_(m, "ShapedTypeComponents") - .def_prop_ro( + static void bind(py::module &m) { + py::class_(m, "ShapedTypeComponents", + py::module_local()) + .def_property_readonly( "element_type", [](PyShapedTypeComponents &self) { return self.elementType; }, "Returns the element type of the shaped type components.") @@ -341,57 +342,57 @@ class PyShapedTypeComponents { [](PyType &elementType) { return PyShapedTypeComponents(elementType); }, - nb::arg("element_type"), + py::arg("element_type"), "Create an shaped type components object with only the element " "type.") .def_static( "get", - [](nb::list shape, PyType &elementType) { + [](py::list shape, PyType &elementType) { return PyShapedTypeComponents(std::move(shape), elementType); }, - nb::arg("shape"), nb::arg("element_type"), + py::arg("shape"), py::arg("element_type"), "Create a ranked shaped type components object.") .def_static( "get", - [](nb::list shape, PyType &elementType, PyAttribute &attribute) { + [](py::list shape, PyType &elementType, PyAttribute &attribute) { return PyShapedTypeComponents(std::move(shape), elementType, attribute); }, - nb::arg("shape"), nb::arg("element_type"), nb::arg("attribute"), + py::arg("shape"), py::arg("element_type"), py::arg("attribute"), "Create a ranked shaped type components object with attribute.") - .def_prop_ro( + .def_property_readonly( "has_rank", [](PyShapedTypeComponents &self) -> bool { return self.ranked; }, "Returns whether the given shaped type component is ranked.") - .def_prop_ro( + .def_property_readonly( "rank", - [](PyShapedTypeComponents &self) -> nb::object { + [](PyShapedTypeComponents &self) -> py::object { if (!self.ranked) { - return nb::none(); + return py::none(); } - return nb::int_(self.shape.size()); + return py::int_(self.shape.size()); }, "Returns the rank of the given ranked 
shaped type components. If " "the shaped type components does not have a rank, None is " "returned.") - .def_prop_ro( + .def_property_readonly( "shape", - [](PyShapedTypeComponents &self) -> nb::object { + [](PyShapedTypeComponents &self) -> py::object { if (!self.ranked) { - return nb::none(); + return py::none(); } - return nb::list(self.shape); + return py::list(self.shape); }, "Returns the shape of the ranked shaped type components as a list " "of integers. Returns none if the shaped type component does not " "have a rank."); } - nb::object getCapsule(); - static PyShapedTypeComponents createFromCapsule(nb::object capsule); + pybind11::object getCapsule(); + static PyShapedTypeComponents createFromCapsule(pybind11::object capsule); private: - nb::list shape; + py::list shape; MlirType elementType; MlirAttribute attribute; bool ranked{false}; @@ -423,7 +424,7 @@ class PyInferShapedTypeOpInterface if (!hasRank) { data->inferredShapedTypeComponents.emplace_back(elementType); } else { - nb::list shapeList; + py::list shapeList; for (intptr_t i = 0; i < rank; ++i) { shapeList.append(shape[i]); } @@ -435,7 +436,7 @@ class PyInferShapedTypeOpInterface /// Given the arguments required to build an operation, attempts to infer the /// shaped type components. Throws value_error on failure. 
std::vector inferReturnTypeComponents( - std::optional operandList, + std::optional operandList, std::optional attributes, void *properties, std::optional> regions, DefaultingPyMlirContext context, DefaultingPyLocation location) { @@ -457,7 +458,7 @@ class PyInferShapedTypeOpInterface mlirRegions.data(), &appendResultsCallback, &data); if (mlirLogicalResultIsFailure(result)) { - throw nb::value_error("Failed to infer result shape type components"); + throw py::value_error("Failed to infer result shape type components"); } return inferredShapedTypeComponents; @@ -466,16 +467,14 @@ class PyInferShapedTypeOpInterface static void bindDerived(ClassTy &cls) { cls.def("inferReturnTypeComponents", &PyInferShapedTypeOpInterface::inferReturnTypeComponents, - nb::arg("operands").none() = nb::none(), - nb::arg("attributes").none() = nb::none(), - nb::arg("regions").none() = nb::none(), - nb::arg("properties").none() = nb::none(), - nb::arg("context").none() = nb::none(), - nb::arg("loc").none() = nb::none(), inferReturnTypeComponentsDoc); + py::arg("operands") = py::none(), + py::arg("attributes") = py::none(), py::arg("regions") = py::none(), + py::arg("properties") = py::none(), py::arg("context") = py::none(), + py::arg("loc") = py::none(), inferReturnTypeComponentsDoc); } }; -void populateIRInterfaces(nb::module_ &m) { +void populateIRInterfaces(py::module &m) { PyInferTypeOpInterface::bind(m); PyShapedTypeComponents::bind(m); PyInferShapedTypeOpInterface::bind(m); diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index 416a14218..6727860c0 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -7,19 +7,16 @@ //===----------------------------------------------------------------------===// #include "IRModule.h" +#include "Globals.h" +#include "PybindUtils.h" -#include -#include +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/Support.h" #include #include -#include "Globals.h" 
-#include "NanobindUtils.h" -#include "mlir-c/Bindings/Python/Interop.h" -#include "mlir-c/Support.h" - -namespace nb = nanobind; +namespace py = pybind11; using namespace mlir; using namespace mlir::python; @@ -44,14 +41,14 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { return true; // Since re-entrancy is possible, make a copy of the search prefixes. std::vector localSearchPrefixes = dialectSearchPrefixes; - nb::object loaded = nb::none(); + py::object loaded = py::none(); for (std::string moduleName : localSearchPrefixes) { moduleName.push_back('.'); moduleName.append(dialectNamespace.data(), dialectNamespace.size()); try { - loaded = nb::module_::import_(moduleName.c_str()); - } catch (nb::python_error &e) { + loaded = py::module::import(moduleName.c_str()); + } catch (py::error_already_set &e) { if (e.matches(PyExc_ModuleNotFoundError)) { continue; } @@ -69,39 +66,41 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { } void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, - nb::callable pyFunc, bool replace) { - nb::object &found = attributeBuilderMap[attributeKind]; + py::function pyFunc, bool replace) { + py::object &found = attributeBuilderMap[attributeKind]; if (found && !replace) { throw std::runtime_error((llvm::Twine("Attribute builder for '") + attributeKind + "' is already registered with func: " + - nb::cast(nb::str(found))) + py::str(found).operator std::string()) .str()); } found = std::move(pyFunc); } void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, - nb::callable typeCaster, bool replace) { - nb::object &found = typeCasterMap[mlirTypeID]; + pybind11::function typeCaster, + bool replace) { + pybind11::object &found = typeCasterMap[mlirTypeID]; if (found && !replace) throw std::runtime_error("Type caster is already registered with caster: " + - nb::cast(nb::str(found))); + py::str(found).operator std::string()); found = std::move(typeCaster); } void 
PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, - nb::callable valueCaster, bool replace) { - nb::object &found = valueCasterMap[mlirTypeID]; + pybind11::function valueCaster, + bool replace) { + pybind11::object &found = valueCasterMap[mlirTypeID]; if (found && !replace) throw std::runtime_error("Value caster is already registered: " + - nb::cast(nb::repr(found))); + py::repr(found).cast()); found = std::move(valueCaster); } void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, - nb::object pyClass) { - nb::object &found = dialectClassMap[dialectNamespace]; + py::object pyClass) { + py::object &found = dialectClassMap[dialectNamespace]; if (found) { throw std::runtime_error((llvm::Twine("Dialect namespace '") + dialectNamespace + "' is already registered.") @@ -111,8 +110,8 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, } void PyGlobals::registerOperationImpl(const std::string &operationName, - nb::object pyClass, bool replace) { - nb::object &found = operationClassMap[operationName]; + py::object pyClass, bool replace) { + py::object &found = operationClassMap[operationName]; if (found && !replace) { throw std::runtime_error((llvm::Twine("Operation '") + operationName + "' is already registered.") @@ -121,7 +120,7 @@ void PyGlobals::registerOperationImpl(const std::string &operationName, found = std::move(pyClass); } -std::optional +std::optional PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { const auto foundIt = attributeBuilderMap.find(attributeKind); if (foundIt != attributeBuilderMap.end()) { @@ -131,7 +130,7 @@ PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { return std::nullopt; } -std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, +std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { // Try to load dialect module. 
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); @@ -143,7 +142,7 @@ std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, return std::nullopt; } -std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, +std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { // Try to load dialect module. (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); @@ -155,7 +154,7 @@ std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, return std::nullopt; } -std::optional +std::optional PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { // Make sure dialect module is loaded. if (!loadDialectModule(dialectNamespace)) @@ -169,7 +168,7 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { return std::nullopt; } -std::optional +std::optional PyGlobals::lookupOperationClass(llvm::StringRef operationName) { // Make sure dialect module is loaded. auto split = operationName.split('.'); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index a242ff26b..172898cfd 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -10,22 +10,20 @@ #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H #define MLIR_BINDINGS_PYTHON_IRMODULES_H -#include -#include - #include #include #include #include "Globals.h" -#include "NanobindUtils.h" +#include "PybindUtils.h" + #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" #include "mlir-c/Transforms.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/PybindAdaptors.h" #include "llvm/ADT/DenseMap.h" namespace mlir { @@ -51,7 +49,7 @@ class PyValue; template class PyObjectRef { public: - PyObjectRef(T *referrent, nanobind::object object) + PyObjectRef(T *referrent, pybind11::object object) : referrent(referrent), object(std::move(object)) { 
assert(this->referrent && "cannot construct PyObjectRef with null referrent"); @@ -69,13 +67,13 @@ class PyObjectRef { int getRefCount() { if (!object) return 0; - return Py_REFCNT(object.ptr()); + return object.ref_count(); } /// Releases the object held by this instance, returning it. /// This is the proper thing to return from a function that wants to return /// the reference. Note that this does not work from initializers. - nanobind::object releaseObject() { + pybind11::object releaseObject() { assert(referrent && object); referrent = nullptr; auto stolen = std::move(object); @@ -87,7 +85,7 @@ class PyObjectRef { assert(referrent && object); return referrent; } - nanobind::object getObject() { + pybind11::object getObject() { assert(referrent && object); return object; } @@ -95,7 +93,7 @@ class PyObjectRef { private: T *referrent; - nanobind::object object; + pybind11::object object; }; /// Tracks an entry in the thread context stack. New entries are pushed onto @@ -114,9 +112,9 @@ class PyThreadContextEntry { Location, }; - PyThreadContextEntry(FrameKind frameKind, nanobind::object context, - nanobind::object insertionPoint, - nanobind::object location) + PyThreadContextEntry(FrameKind frameKind, pybind11::object context, + pybind11::object insertionPoint, + pybind11::object location) : context(std::move(context)), insertionPoint(std::move(insertionPoint)), location(std::move(location)), frameKind(frameKind) {} @@ -136,26 +134,26 @@ class PyThreadContextEntry { /// Stack management. 
static PyThreadContextEntry *getTopOfStack(); - static nanobind::object pushContext(nanobind::object context); + static pybind11::object pushContext(PyMlirContext &context); static void popContext(PyMlirContext &context); - static nanobind::object pushInsertionPoint(nanobind::object insertionPoint); + static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint); static void popInsertionPoint(PyInsertionPoint &insertionPoint); - static nanobind::object pushLocation(nanobind::object location); + static pybind11::object pushLocation(PyLocation &location); static void popLocation(PyLocation &location); /// Gets the thread local stack. static std::vector &getStack(); private: - static void push(FrameKind frameKind, nanobind::object context, - nanobind::object insertionPoint, nanobind::object location); + static void push(FrameKind frameKind, pybind11::object context, + pybind11::object insertionPoint, pybind11::object location); /// An object reference to the PyContext. - nanobind::object context; + pybind11::object context; /// An object reference to the current insertion point. - nanobind::object insertionPoint; + pybind11::object insertionPoint; /// An object reference to the current location. - nanobind::object location; + pybind11::object location; // The kind of push that was performed. FrameKind frameKind; }; @@ -165,15 +163,14 @@ using PyMlirContextRef = PyObjectRef; class PyMlirContext { public: PyMlirContext() = delete; - PyMlirContext(MlirContext context); PyMlirContext(const PyMlirContext &) = delete; PyMlirContext(PyMlirContext &&) = delete; - /// For the case of a python __init__ (nanobind::init) method, pybind11 is - /// quite strict about needing to return a pointer that is not yet associated - /// to an nanobind::object. Since the forContext() method acts like a pool, - /// possibly returning a recycled context, it does not satisfy this need. 
The - /// usual way in python to accomplish such a thing is to override __new__, but + /// For the case of a python __init__ (py::init) method, pybind11 is quite + /// strict about needing to return a pointer that is not yet associated to + /// an py::object. Since the forContext() method acts like a pool, possibly + /// returning a recycled context, it does not satisfy this need. The usual + /// way in python to accomplish such a thing is to override __new__, but /// that is also not supported by pybind11. Instead, we use this entry /// point which always constructs a fresh context (which cannot alias an /// existing one because it is fresh). @@ -190,17 +187,17 @@ class PyMlirContext { /// Gets a strong reference to this context, which will ensure it is kept /// alive for the life of the reference. PyMlirContextRef getRef() { - return PyMlirContextRef(this, nanobind::cast(this)); + return PyMlirContextRef(this, pybind11::cast(this)); } /// Gets a capsule wrapping the void* within the MlirContext. - nanobind::object getCapsule(); + pybind11::object getCapsule(); /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. /// Note that PyMlirContext instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirContext /// is taken by calling this function. - static nanobind::object createFromCapsule(nanobind::object capsule); + static pybind11::object createFromCapsule(pybind11::object capsule); /// Gets the count of live context objects. Used for testing. static size_t getLiveCount(); @@ -240,14 +237,14 @@ class PyMlirContext { size_t getLiveModuleCount(); /// Enter and exit the context manager. 
- static nanobind::object contextEnter(nanobind::object context); - void contextExit(const nanobind::object &excType, - const nanobind::object &excVal, - const nanobind::object &excTb); + pybind11::object contextEnter(); + void contextExit(const pybind11::object &excType, + const pybind11::object &excVal, + const pybind11::object &excTb); /// Attaches a Python callback as a diagnostic handler, returning a /// registration object (internally a PyDiagnosticHandler). - nanobind::object attachDiagnosticHandler(nanobind::object callback); + pybind11::object attachDiagnosticHandler(pybind11::object callback); /// Controls whether error diagnostics should be propagated to diagnostic /// handlers, instead of being captured by `ErrorCapture`. @@ -255,6 +252,8 @@ class PyMlirContext { struct ErrorCapture; private: + PyMlirContext(MlirContext context); + // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, // preserving the relationship that an MlirContext maps to a single // PyMlirContext wrapper. This could be replaced in the future with an @@ -269,7 +268,7 @@ class PyMlirContext { // from this map, and while it still exists as an instance, any // attempt to access it will raise an error. using LiveModuleMap = - llvm::DenseMap>; + llvm::DenseMap>; LiveModuleMap liveModules; // Interns all live operations associated with this context. Operations @@ -277,7 +276,7 @@ class PyMlirContext { // removed from this map, and while it still exists as an instance, any // attempt to access it will raise an error. using LiveOperationMap = - llvm::DenseMap>; + llvm::DenseMap>; LiveOperationMap liveOperations; bool emitErrorDiagnostics = false; @@ -325,19 +324,19 @@ class PyLocation : public BaseContextObject { MlirLocation get() const { return loc; } /// Enter and exit the context manager. 
- static nanobind::object contextEnter(nanobind::object location); - void contextExit(const nanobind::object &excType, - const nanobind::object &excVal, - const nanobind::object &excTb); + pybind11::object contextEnter(); + void contextExit(const pybind11::object &excType, + const pybind11::object &excVal, + const pybind11::object &excTb); /// Gets a capsule wrapping the void* within the MlirLocation. - nanobind::object getCapsule(); + pybind11::object getCapsule(); /// Creates a PyLocation from the MlirLocation wrapped by a capsule. /// Note that PyLocation instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirLocation /// is taken by calling this function. - static PyLocation createFromCapsule(nanobind::object capsule); + static PyLocation createFromCapsule(pybind11::object capsule); private: MlirLocation loc; @@ -354,8 +353,8 @@ class PyDiagnostic { bool isValid() { return valid; } MlirDiagnosticSeverity getSeverity(); PyLocation getLocation(); - nanobind::str getMessage(); - nanobind::tuple getNotes(); + pybind11::str getMessage(); + pybind11::tuple getNotes(); /// Materialized diagnostic information. This is safe to access outside the /// diagnostic callback. @@ -374,7 +373,7 @@ class PyDiagnostic { /// If notes have been materialized from the diagnostic, then this will /// be populated with the corresponding objects (all castable to /// PyDiagnostic). - std::optional materializedNotes; + std::optional materializedNotes; bool valid = true; }; @@ -399,7 +398,7 @@ class PyDiagnostic { /// is no way to attach an existing handler object). class PyDiagnosticHandler { public: - PyDiagnosticHandler(MlirContext context, nanobind::object callback); + PyDiagnosticHandler(MlirContext context, pybind11::object callback); ~PyDiagnosticHandler(); bool isAttached() { return registeredID.has_value(); } @@ -408,16 +407,16 @@ class PyDiagnosticHandler { /// Detaches the handler. Does nothing if not attached. 
void detach(); - nanobind::object contextEnter() { return nanobind::cast(this); } - void contextExit(const nanobind::object &excType, - const nanobind::object &excVal, - const nanobind::object &excTb) { + pybind11::object contextEnter() { return pybind11::cast(this); } + void contextExit(const pybind11::object &excType, + const pybind11::object &excVal, + const pybind11::object &excTb) { detach(); } private: MlirContext context; - nanobind::object callback; + pybind11::object callback; std::optional registeredID; bool hadError = false; friend class PyMlirContext; @@ -478,12 +477,12 @@ class PyDialects : public BaseContextObject { /// objects of this type will be returned directly. class PyDialect { public: - PyDialect(nanobind::object descriptor) : descriptor(std::move(descriptor)) {} + PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {} - nanobind::object getDescriptor() { return descriptor; } + pybind11::object getDescriptor() { return descriptor; } private: - nanobind::object descriptor; + pybind11::object descriptor; }; /// Wrapper around an MlirDialectRegistry. @@ -506,8 +505,8 @@ class PyDialectRegistry { operator MlirDialectRegistry() const { return registry; } MlirDialectRegistry get() const { return registry; } - nanobind::object getCapsule(); - static PyDialectRegistry createFromCapsule(nanobind::object capsule); + pybind11::object getCapsule(); + static PyDialectRegistry createFromCapsule(pybind11::object capsule); private: MlirDialectRegistry registry; @@ -543,25 +542,26 @@ class PyModule : public BaseContextObject { /// Gets a strong reference to this module. PyModuleRef getRef() { - return PyModuleRef(this, nanobind::borrow(handle)); + return PyModuleRef(this, + pybind11::reinterpret_borrow(handle)); } /// Gets a capsule wrapping the void* within the MlirModule. 
/// Note that the module does not (yet) provide a corresponding factory for /// constructing from a capsule as that would require uniquing PyModule /// instances, which is not currently done. - nanobind::object getCapsule(); + pybind11::object getCapsule(); /// Creates a PyModule from the MlirModule wrapped by a capsule. /// Note that PyModule instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirModule /// is taken by calling this function. - static nanobind::object createFromCapsule(nanobind::object capsule); + static pybind11::object createFromCapsule(pybind11::object capsule); private: PyModule(PyMlirContextRef contextRef, MlirModule module); MlirModule module; - nanobind::handle handle; + pybind11::handle handle; }; class PyAsmState; @@ -574,18 +574,18 @@ class PyOperationBase { /// Implements the bound 'print' method and helps with others. void print(std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified, nanobind::object fileObject, bool binary, + bool assumeVerified, py::object fileObject, bool binary, bool skipRegions); - void print(PyAsmState &state, nanobind::object fileObject, bool binary); + void print(PyAsmState &state, py::object fileObject, bool binary); - nanobind::object getAsm(bool binary, + pybind11::object getAsm(bool binary, std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified, bool skipRegions); // Implement the bound 'writeBytecode' method. - void writeBytecode(const nanobind::object &fileObject, + void writeBytecode(const pybind11::object &fileObject, std::optional bytecodeVersion); // Implement the walk method. @@ -621,13 +621,13 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// it with a parentKeepAlive. 
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, - nanobind::object parentKeepAlive = nanobind::object()); + pybind11::object parentKeepAlive = pybind11::object()); /// Creates a detached operation. The operation must not be associated with /// any existing live operation. static PyOperationRef createDetached(PyMlirContextRef contextRef, MlirOperation operation, - nanobind::object parentKeepAlive = nanobind::object()); + pybind11::object parentKeepAlive = pybind11::object()); /// Parses a source string (either text assembly or bytecode), creating a /// detached operation. @@ -640,7 +640,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject { void detachFromParent() { mlirOperationRemoveFromParent(getOperation()); setDetached(); - parentKeepAlive = nanobind::object(); + parentKeepAlive = pybind11::object(); } /// Gets the backing operation. @@ -651,11 +651,12 @@ class PyOperation : public PyOperationBase, public BaseContextObject { } PyOperationRef getRef() { - return PyOperationRef(this, nanobind::borrow(handle)); + return PyOperationRef( + this, pybind11::reinterpret_borrow(handle)); } bool isAttached() { return attached; } - void setAttached(const nanobind::object &parent = nanobind::object()) { + void setAttached(const pybind11::object &parent = pybind11::object()) { assert(!attached && "operation already attached"); attached = true; } @@ -674,24 +675,24 @@ class PyOperation : public PyOperationBase, public BaseContextObject { std::optional getParentOperation(); /// Gets a capsule wrapping the void* within the MlirOperation. - nanobind::object getCapsule(); + pybind11::object getCapsule(); /// Creates a PyOperation from the MlirOperation wrapped by a capsule. /// Ownership of the underlying MlirOperation is taken by calling this /// function. 
- static nanobind::object createFromCapsule(nanobind::object capsule); + static pybind11::object createFromCapsule(pybind11::object capsule); /// Creates an operation. See corresponding python docstring. - static nanobind::object + static pybind11::object create(const std::string &name, std::optional> results, std::optional> operands, - std::optional attributes, + std::optional attributes, std::optional> successors, int regions, - DefaultingPyLocation location, const nanobind::object &ip, + DefaultingPyLocation location, const pybind11::object &ip, bool inferType); /// Creates an OpView suitable for this operation. - nanobind::object createOpView(); + pybind11::object createOpView(); /// Erases the underlying MlirOperation, removes its pointer from the /// parent context's live operations map, and sets the valid bit false. @@ -701,23 +702,23 @@ class PyOperation : public PyOperationBase, public BaseContextObject { void setInvalid() { valid = false; } /// Clones this operation. - nanobind::object clone(const nanobind::object &ip); + pybind11::object clone(const pybind11::object &ip); private: PyOperation(PyMlirContextRef contextRef, MlirOperation operation); static PyOperationRef createInstance(PyMlirContextRef contextRef, MlirOperation operation, - nanobind::object parentKeepAlive); + pybind11::object parentKeepAlive); MlirOperation operation; - nanobind::handle handle; + pybind11::handle handle; // Keeps the parent alive, regardless of whether it is an Operation or // Module. // TODO: As implemented, this facility is only sufficient for modeling the // trivial module parent back-reference. Generalize this to also account for // transitions from detached to attached and address TODOs in the // ir_operation.py regarding testing corresponding lifetime guarantees. 
- nanobind::object parentKeepAlive; + pybind11::object parentKeepAlive; bool attached = true; bool valid = true; @@ -732,17 +733,17 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// python types. class PyOpView : public PyOperationBase { public: - PyOpView(const nanobind::object &operationObject); + PyOpView(const pybind11::object &operationObject); PyOperation &getOperation() override { return operation; } - nanobind::object getOperationObject() { return operationObject; } + pybind11::object getOperationObject() { return operationObject; } - static nanobind::object buildGeneric( - const nanobind::object &cls, std::optional resultTypeList, - nanobind::list operandList, std::optional attributes, + static pybind11::object buildGeneric( + const pybind11::object &cls, std::optional resultTypeList, + pybind11::list operandList, std::optional attributes, std::optional> successors, std::optional regions, DefaultingPyLocation location, - const nanobind::object &maybeIp); + const pybind11::object &maybeIp); /// Construct an instance of a class deriving from OpView, bypassing its /// `__init__` method. The derived class will typically define a constructor @@ -751,12 +752,12 @@ class PyOpView : public PyOperationBase { /// /// The caller is responsible for verifying that `operation` is a valid /// operation to construct `cls` with. - static nanobind::object constructDerived(const nanobind::object &cls, - const nanobind::object &operation); + static pybind11::object constructDerived(const pybind11::object &cls, + const PyOperation &operation); private: PyOperation &operation; // For efficient, cast-free access from C++ - nanobind::object operationObject; // Holds the reference. + pybind11::object operationObject; // Holds the reference. }; /// Wrapper around an MlirRegion. @@ -829,7 +830,7 @@ class PyBlock { void checkValid() { return parentOperation->checkValid(); } /// Gets a capsule wrapping the void* within the MlirBlock. 
- nanobind::object getCapsule(); + pybind11::object getCapsule(); private: PyOperationRef parentOperation; @@ -857,10 +858,10 @@ class PyInsertionPoint { void insert(PyOperationBase &operationBase); /// Enter and exit the context manager. - static nanobind::object contextEnter(nanobind::object insertionPoint); - void contextExit(const nanobind::object &excType, - const nanobind::object &excVal, - const nanobind::object &excTb); + pybind11::object contextEnter(); + void contextExit(const pybind11::object &excType, + const pybind11::object &excVal, + const pybind11::object &excTb); PyBlock &getBlock() { return block; } std::optional &getRefOperation() { return refOperation; } @@ -885,13 +886,13 @@ class PyType : public BaseContextObject { MlirType get() const { return type; } /// Gets a capsule wrapping the void* within the MlirType. - nanobind::object getCapsule(); + pybind11::object getCapsule(); /// Creates a PyType from the MlirType wrapped by a capsule. /// Note that PyType instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirType /// is taken by calling this function. - static PyType createFromCapsule(nanobind::object capsule); + static PyType createFromCapsule(pybind11::object capsule); private: MlirType type; @@ -911,10 +912,10 @@ class PyTypeID { MlirTypeID get() { return typeID; } /// Gets a capsule wrapping the void* within the MlirTypeID. - nanobind::object getCapsule(); + pybind11::object getCapsule(); /// Creates a PyTypeID from the MlirTypeID wrapped by a capsule. 
- static PyTypeID createFromCapsule(nanobind::object capsule); + static PyTypeID createFromCapsule(pybind11::object capsule); private: MlirTypeID typeID; @@ -931,7 +932,7 @@ class PyConcreteType : public BaseTy { // Derived classes must define statics for: // IsAFunctionTy isaFunction // const char *pyClassName - using ClassTy = nanobind::class_; + using ClassTy = pybind11::class_; using IsAFunctionTy = bool (*)(MlirType); using GetTypeIDFunctionTy = MlirTypeID (*)(); static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; @@ -944,38 +945,34 @@ class PyConcreteType : public BaseTy { static MlirType castFrom(PyType &orig) { if (!DerivedTy::isaFunction(orig)) { - auto origRepr = - nanobind::cast(nanobind::repr(nanobind::cast(orig))); - throw nanobind::value_error((llvm::Twine("Cannot cast type to ") + - DerivedTy::pyClassName + " (from " + - origRepr + ")") - .str() - .c_str()); + auto origRepr = pybind11::repr(pybind11::cast(orig)).cast(); + throw py::value_error((llvm::Twine("Cannot cast type to ") + + DerivedTy::pyClassName + " (from " + origRepr + + ")") + .str()); } return orig; } - static void bind(nanobind::module_ &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName); - cls.def(nanobind::init(), nanobind::keep_alive<0, 1>(), - nanobind::arg("cast_from_type")); + static void bind(pybind11::module &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local()); + cls.def(pybind11::init(), pybind11::keep_alive<0, 1>(), + pybind11::arg("cast_from_type")); cls.def_static( "isinstance", [](PyType &otherType) -> bool { return DerivedTy::isaFunction(otherType); }, - nanobind::arg("other")); - cls.def_prop_ro_static( - "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID { + pybind11::arg("other")); + cls.def_property_readonly_static( + "static_typeid", [](py::object & /*class*/) -> MlirTypeID { if (DerivedTy::getTypeIdFunction) return DerivedTy::getTypeIdFunction(); - throw nanobind::attribute_error( - (DerivedTy::pyClassName + 
llvm::Twine(" has no typeid.")) - .str() - .c_str()); + throw py::attribute_error( + (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str()); }); - cls.def_prop_ro("typeid", [](PyType &self) { - return nanobind::cast(nanobind::cast(self).attr("typeid")); + cls.def_property_readonly("typeid", [](PyType &self) { + return py::cast(self).attr("typeid").cast(); }); cls.def("__repr__", [](DerivedTy &self) { PyPrintAccumulator printAccum; @@ -989,8 +986,8 @@ class PyConcreteType : public BaseTy { if (DerivedTy::getTypeIdFunction) { PyGlobals::get().registerTypeCaster( DerivedTy::getTypeIdFunction(), - nanobind::cast(nanobind::cpp_function( - [](PyType pyType) -> DerivedTy { return pyType; }))); + pybind11::cpp_function( + [](PyType pyType) -> DerivedTy { return pyType; })); } DerivedTy::bindDerived(cls); @@ -1011,13 +1008,13 @@ class PyAttribute : public BaseContextObject { MlirAttribute get() const { return attr; } /// Gets a capsule wrapping the void* within the MlirAttribute. - nanobind::object getCapsule(); + pybind11::object getCapsule(); /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule. /// Note that PyAttribute instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirAttribute /// is taken by calling this function. 
- static PyAttribute createFromCapsule(nanobind::object capsule); + static PyAttribute createFromCapsule(pybind11::object capsule); private: MlirAttribute attr; @@ -1057,7 +1054,7 @@ class PyConcreteAttribute : public BaseTy { // Derived classes must define statics for: // IsAFunctionTy isaFunction // const char *pyClassName - using ClassTy = nanobind::class_; + using ClassTy = pybind11::class_; using IsAFunctionTy = bool (*)(MlirAttribute); using GetTypeIDFunctionTy = MlirTypeID (*)(); static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; @@ -1070,45 +1067,37 @@ class PyConcreteAttribute : public BaseTy { static MlirAttribute castFrom(PyAttribute &orig) { if (!DerivedTy::isaFunction(orig)) { - auto origRepr = - nanobind::cast(nanobind::repr(nanobind::cast(orig))); - throw nanobind::value_error((llvm::Twine("Cannot cast attribute to ") + - DerivedTy::pyClassName + " (from " + - origRepr + ")") - .str() - .c_str()); + auto origRepr = pybind11::repr(pybind11::cast(orig)).cast(); + throw py::value_error((llvm::Twine("Cannot cast attribute to ") + + DerivedTy::pyClassName + " (from " + origRepr + + ")") + .str()); } return orig; } - static void bind(nanobind::module_ &m, PyType_Slot *slots = nullptr) { - ClassTy cls; - if (slots) { - cls = ClassTy(m, DerivedTy::pyClassName, nanobind::type_slots(slots)); - } else { - cls = ClassTy(m, DerivedTy::pyClassName); - } - cls.def(nanobind::init(), nanobind::keep_alive<0, 1>(), - nanobind::arg("cast_from_attr")); + static void bind(pybind11::module &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(), + pybind11::module_local()); + cls.def(pybind11::init(), pybind11::keep_alive<0, 1>(), + pybind11::arg("cast_from_attr")); cls.def_static( "isinstance", [](PyAttribute &otherAttr) -> bool { return DerivedTy::isaFunction(otherAttr); }, - nanobind::arg("other")); - cls.def_prop_ro( + pybind11::arg("other")); + cls.def_property_readonly( "type", [](PyAttribute &attr) { return 
mlirAttributeGetType(attr); }); - cls.def_prop_ro_static( - "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID { + cls.def_property_readonly_static( + "static_typeid", [](py::object & /*class*/) -> MlirTypeID { if (DerivedTy::getTypeIdFunction) return DerivedTy::getTypeIdFunction(); - throw nanobind::attribute_error( - (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")) - .str() - .c_str()); + throw py::attribute_error( + (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str()); }); - cls.def_prop_ro("typeid", [](PyAttribute &self) { - return nanobind::cast(nanobind::cast(self).attr("typeid")); + cls.def_property_readonly("typeid", [](PyAttribute &self) { + return py::cast(self).attr("typeid").cast(); }); cls.def("__repr__", [](DerivedTy &self) { PyPrintAccumulator printAccum; @@ -1123,10 +1112,9 @@ class PyConcreteAttribute : public BaseTy { if (DerivedTy::getTypeIdFunction) { PyGlobals::get().registerTypeCaster( DerivedTy::getTypeIdFunction(), - nanobind::cast( - nanobind::cpp_function([](PyAttribute pyAttribute) -> DerivedTy { - return pyAttribute; - }))); + pybind11::cpp_function([](PyAttribute pyAttribute) -> DerivedTy { + return pyAttribute; + })); } DerivedTy::bindDerived(cls); @@ -1158,13 +1146,13 @@ class PyValue { void checkValid() { return parentOperation->checkValid(); } /// Gets a capsule wrapping the void* within the MlirValue. - nanobind::object getCapsule(); + pybind11::object getCapsule(); - nanobind::object maybeDownCast(); + pybind11::object maybeDownCast(); /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of /// the underlying MlirValue is still tied to the owning operation. 
- static PyValue createFromCapsule(nanobind::object capsule); + static PyValue createFromCapsule(pybind11::object capsule); private: PyOperationRef parentOperation; @@ -1181,13 +1169,13 @@ class PyAffineExpr : public BaseContextObject { MlirAffineExpr get() const { return affineExpr; } /// Gets a capsule wrapping the void* within the MlirAffineExpr. - nanobind::object getCapsule(); + pybind11::object getCapsule(); /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule. /// Note that PyAffineExpr instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr /// is taken by calling this function. - static PyAffineExpr createFromCapsule(nanobind::object capsule); + static PyAffineExpr createFromCapsule(pybind11::object capsule); PyAffineExpr add(const PyAffineExpr &other) const; PyAffineExpr mul(const PyAffineExpr &other) const; @@ -1208,13 +1196,13 @@ class PyAffineMap : public BaseContextObject { MlirAffineMap get() const { return affineMap; } /// Gets a capsule wrapping the void* within the MlirAffineMap. - nanobind::object getCapsule(); + pybind11::object getCapsule(); /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule. /// Note that PyAffineMap instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirAffineMap /// is taken by calling this function. - static PyAffineMap createFromCapsule(nanobind::object capsule); + static PyAffineMap createFromCapsule(pybind11::object capsule); private: MlirAffineMap affineMap; @@ -1229,12 +1217,12 @@ class PyIntegerSet : public BaseContextObject { MlirIntegerSet get() const { return integerSet; } /// Gets a capsule wrapping the void* within the MlirIntegerSet. - nanobind::object getCapsule(); + pybind11::object getCapsule(); /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. 
/// Note that PyIntegerSet instances may be uniqued, so the returned object /// may be a pre-existing object. Integer sets are owned by the context. - static PyIntegerSet createFromCapsule(nanobind::object capsule); + static PyIntegerSet createFromCapsule(pybind11::object capsule); private: MlirIntegerSet integerSet; @@ -1251,7 +1239,7 @@ class PySymbolTable { /// Returns the symbol (opview) with the given name, throws if there is no /// such symbol in the table. - nanobind::object dunderGetItem(const std::string &name); + pybind11::object dunderGetItem(const std::string &name); /// Removes the given operation from the symbol table and erases it. void erase(PyOperationBase &symbol); @@ -1281,7 +1269,7 @@ class PySymbolTable { /// Walks all symbol tables under and including 'from'. static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, - nanobind::object callback); + pybind11::object callback); /// Casts the bindings class into the C API structure. operator MlirSymbolTable() { return symbolTable; } @@ -1301,16 +1289,16 @@ struct MLIRError { std::vector errorDiagnostics; }; -void populateIRAffine(nanobind::module_ &m); -void populateIRAttributes(nanobind::module_ &m); -void populateIRCore(nanobind::module_ &m); -void populateIRInterfaces(nanobind::module_ &m); -void populateIRTypes(nanobind::module_ &m); +void populateIRAffine(pybind11::module &m); +void populateIRAttributes(pybind11::module &m); +void populateIRCore(pybind11::module &m); +void populateIRInterfaces(pybind11::module &m); +void populateIRTypes(pybind11::module &m); } // namespace python } // namespace mlir -namespace nanobind { +namespace pybind11 { namespace detail { template <> @@ -1321,6 +1309,6 @@ struct type_caster : MlirDefaultingCaster {}; } // namespace detail -} // namespace nanobind +} // namespace pybind11 #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 5cfa51142..6f192bc4b 100644 
--- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -6,26 +6,19 @@ // //===----------------------------------------------------------------------===// -// clang-format off #include "IRModule.h" -#include "mlir/Bindings/Python/IRTypes.h" -// clang-format on -#include -#include -#include -#include -#include +#include "PybindUtils.h" -#include +#include "mlir/Bindings/Python/IRTypes.h" -#include "IRModule.h" -#include "NanobindUtils.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Support.h" -namespace nb = nanobind; +#include + +namespace py = pybind11; using namespace mlir; using namespace mlir::python; @@ -55,7 +48,7 @@ class PyIntegerType : public PyConcreteType { MlirType t = mlirIntegerTypeGet(context->get(), width); return PyIntegerType(context->getRef(), t); }, - nb::arg("width"), nb::arg("context").none() = nb::none(), + py::arg("width"), py::arg("context") = py::none(), "Create a signless integer type"); c.def_static( "get_signed", @@ -63,7 +56,7 @@ class PyIntegerType : public PyConcreteType { MlirType t = mlirIntegerTypeSignedGet(context->get(), width); return PyIntegerType(context->getRef(), t); }, - nb::arg("width"), nb::arg("context").none() = nb::none(), + py::arg("width"), py::arg("context") = py::none(), "Create a signed integer type"); c.def_static( "get_unsigned", @@ -71,25 +64,25 @@ class PyIntegerType : public PyConcreteType { MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); return PyIntegerType(context->getRef(), t); }, - nb::arg("width"), nb::arg("context").none() = nb::none(), + py::arg("width"), py::arg("context") = py::none(), "Create an unsigned integer type"); - c.def_prop_ro( + c.def_property_readonly( "width", [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, "Returns the width of the integer type"); - c.def_prop_ro( + c.def_property_readonly( "is_signless", [](PyIntegerType &self) -> bool { return 
mlirIntegerTypeIsSignless(self); }, "Returns whether this is a signless integer"); - c.def_prop_ro( + c.def_property_readonly( "is_signed", [](PyIntegerType &self) -> bool { return mlirIntegerTypeIsSigned(self); }, "Returns whether this is a signed integer"); - c.def_prop_ro( + c.def_property_readonly( "is_unsigned", [](PyIntegerType &self) -> bool { return mlirIntegerTypeIsUnsigned(self); @@ -114,7 +107,7 @@ class PyIndexType : public PyConcreteType { MlirType t = mlirIndexTypeGet(context->get()); return PyIndexType(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a index type."); + py::arg("context") = py::none(), "Create a index type."); } }; @@ -125,7 +118,7 @@ class PyFloatType : public PyConcreteType { using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { - c.def_prop_ro( + c.def_property_readonly( "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); }, "Returns the width of the floating-point type"); } @@ -148,7 +141,7 @@ class PyFloat4E2M1FNType MlirType t = mlirFloat4E2M1FNTypeGet(context->get()); return PyFloat4E2M1FNType(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a float4_e2m1fn type."); + py::arg("context") = py::none(), "Create a float4_e2m1fn type."); } }; @@ -169,7 +162,7 @@ class PyFloat6E2M3FNType MlirType t = mlirFloat6E2M3FNTypeGet(context->get()); return PyFloat6E2M3FNType(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a float6_e2m3fn type."); + py::arg("context") = py::none(), "Create a float6_e2m3fn type."); } }; @@ -190,7 +183,7 @@ class PyFloat6E3M2FNType MlirType t = mlirFloat6E3M2FNTypeGet(context->get()); return PyFloat6E3M2FNType(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a float6_e3m2fn type."); + py::arg("context") = py::none(), "Create a float6_e3m2fn type."); } }; @@ -211,7 +204,7 @@ class PyFloat8E4M3FNType MlirType t = mlirFloat8E4M3FNTypeGet(context->get()); return 
PyFloat8E4M3FNType(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a float8_e4m3fn type."); + py::arg("context") = py::none(), "Create a float8_e4m3fn type."); } }; @@ -231,7 +224,7 @@ class PyFloat8E5M2Type : public PyConcreteType { MlirType t = mlirFloat8E5M2TypeGet(context->get()); return PyFloat8E5M2Type(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a float8_e5m2 type."); + py::arg("context") = py::none(), "Create a float8_e5m2 type."); } }; @@ -251,7 +244,7 @@ class PyFloat8E4M3Type : public PyConcreteType { MlirType t = mlirFloat8E4M3TypeGet(context->get()); return PyFloat8E4M3Type(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a float8_e4m3 type."); + py::arg("context") = py::none(), "Create a float8_e4m3 type."); } }; @@ -272,8 +265,7 @@ class PyFloat8E4M3FNUZType MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); return PyFloat8E4M3FNUZType(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), - "Create a float8_e4m3fnuz type."); + py::arg("context") = py::none(), "Create a float8_e4m3fnuz type."); } }; @@ -294,8 +286,7 @@ class PyFloat8E4M3B11FNUZType MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); return PyFloat8E4M3B11FNUZType(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), - "Create a float8_e4m3b11fnuz type."); + py::arg("context") = py::none(), "Create a float8_e4m3b11fnuz type."); } }; @@ -316,8 +307,7 @@ class PyFloat8E5M2FNUZType MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); return PyFloat8E5M2FNUZType(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), - "Create a float8_e5m2fnuz type."); + py::arg("context") = py::none(), "Create a float8_e5m2fnuz type."); } }; @@ -337,7 +327,7 @@ class PyFloat8E3M4Type : public PyConcreteType { MlirType t = mlirFloat8E3M4TypeGet(context->get()); return PyFloat8E3M4Type(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a 
float8_e3m4 type."); + py::arg("context") = py::none(), "Create a float8_e3m4 type."); } }; @@ -358,8 +348,7 @@ class PyFloat8E8M0FNUType MlirType t = mlirFloat8E8M0FNUTypeGet(context->get()); return PyFloat8E8M0FNUType(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), - "Create a float8_e8m0fnu type."); + py::arg("context") = py::none(), "Create a float8_e8m0fnu type."); } }; @@ -379,7 +368,7 @@ class PyBF16Type : public PyConcreteType { MlirType t = mlirBF16TypeGet(context->get()); return PyBF16Type(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a bf16 type."); + py::arg("context") = py::none(), "Create a bf16 type."); } }; @@ -399,7 +388,7 @@ class PyF16Type : public PyConcreteType { MlirType t = mlirF16TypeGet(context->get()); return PyF16Type(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a f16 type."); + py::arg("context") = py::none(), "Create a f16 type."); } }; @@ -419,7 +408,7 @@ class PyTF32Type : public PyConcreteType { MlirType t = mlirTF32TypeGet(context->get()); return PyTF32Type(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a tf32 type."); + py::arg("context") = py::none(), "Create a tf32 type."); } }; @@ -439,7 +428,7 @@ class PyF32Type : public PyConcreteType { MlirType t = mlirF32TypeGet(context->get()); return PyF32Type(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a f32 type."); + py::arg("context") = py::none(), "Create a f32 type."); } }; @@ -459,7 +448,7 @@ class PyF64Type : public PyConcreteType { MlirType t = mlirF64TypeGet(context->get()); return PyF64Type(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), "Create a f64 type."); + py::arg("context") = py::none(), "Create a f64 type."); } }; @@ -479,7 +468,7 @@ class PyNoneType : public PyConcreteType { MlirType t = mlirNoneTypeGet(context->get()); return PyNoneType(context->getRef(), t); }, - nb::arg("context").none() = nb::none(), 
"Create a none type."); + py::arg("context") = py::none(), "Create a none type."); } }; @@ -501,15 +490,14 @@ class PyComplexType : public PyConcreteType { MlirType t = mlirComplexTypeGet(elementType); return PyComplexType(elementType.getContext(), t); } - throw nb::value_error( + throw py::value_error( (Twine("invalid '") + - nb::cast(nb::repr(nb::cast(elementType))) + + py::repr(py::cast(elementType)).cast() + "' and expected floating point or integer type.") - .str() - .c_str()); + .str()); }, "Create a complex type"); - c.def_prop_ro( + c.def_property_readonly( "element_type", [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); }, "Returns element type."); @@ -520,22 +508,22 @@ class PyComplexType : public PyConcreteType { // Shaped Type Interface - ShapedType void mlir::PyShapedType::bindDerived(ClassTy &c) { - c.def_prop_ro( + c.def_property_readonly( "element_type", [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); }, "Returns the element type of the shaped type."); - c.def_prop_ro( + c.def_property_readonly( "has_rank", [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, "Returns whether the given shaped type is ranked."); - c.def_prop_ro( + c.def_property_readonly( "rank", [](PyShapedType &self) { self.requireHasRank(); return mlirShapedTypeGetRank(self); }, "Returns the rank of the given ranked shaped type."); - c.def_prop_ro( + c.def_property_readonly( "has_static_shape", [](PyShapedType &self) -> bool { return mlirShapedTypeHasStaticShape(self); @@ -547,7 +535,7 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { self.requireHasRank(); return mlirShapedTypeIsDynamicDim(self, dim); }, - nb::arg("dim"), + py::arg("dim"), "Returns whether the dim-th dimension of the given shaped type is " "dynamic."); c.def( @@ -556,12 +544,12 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { self.requireHasRank(); return mlirShapedTypeGetDimSize(self, dim); }, - nb::arg("dim"), + py::arg("dim"), "Returns the 
dim-th dimension of the given ranked shaped type."); c.def_static( "is_dynamic_size", [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, - nb::arg("dim_size"), + py::arg("dim_size"), "Returns whether the given dimension size indicates a dynamic " "dimension."); c.def( @@ -570,10 +558,10 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { self.requireHasRank(); return mlirShapedTypeIsDynamicStrideOrOffset(val); }, - nb::arg("dim_size"), + py::arg("dim_size"), "Returns whether the given value is used as a placeholder for dynamic " "strides and offsets in shaped types."); - c.def_prop_ro( + c.def_property_readonly( "shape", [](PyShapedType &self) { self.requireHasRank(); @@ -599,7 +587,7 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { void mlir::PyShapedType::requireHasRank() { if (!mlirShapedTypeHasRank(*this)) { - throw nb::value_error( + throw py::value_error( "calling this method requires that the type has a rank."); } } @@ -619,15 +607,15 @@ class PyVectorType : public PyConcreteType { using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { - c.def_static("get", &PyVectorType::get, nb::arg("shape"), - nb::arg("element_type"), nb::kw_only(), - nb::arg("scalable").none() = nb::none(), - nb::arg("scalable_dims").none() = nb::none(), - nb::arg("loc").none() = nb::none(), "Create a vector type") - .def_prop_ro( + c.def_static("get", &PyVectorType::get, py::arg("shape"), + py::arg("element_type"), py::kw_only(), + py::arg("scalable") = py::none(), + py::arg("scalable_dims") = py::none(), + py::arg("loc") = py::none(), "Create a vector type") + .def_property_readonly( "scalable", [](MlirType self) { return mlirVectorTypeIsScalable(self); }) - .def_prop_ro("scalable_dims", [](MlirType self) { + .def_property_readonly("scalable_dims", [](MlirType self) { std::vector scalableDims; size_t rank = static_cast(mlirShapedTypeGetRank(self)); scalableDims.reserve(rank); @@ -639,11 +627,11 @@ class PyVectorType : public 
PyConcreteType { private: static PyVectorType get(std::vector shape, PyType &elementType, - std::optional scalable, + std::optional scalable, std::optional> scalableDims, DefaultingPyLocation loc) { if (scalable && scalableDims) { - throw nb::value_error("'scalable' and 'scalable_dims' kwargs " + throw py::value_error("'scalable' and 'scalable_dims' kwargs " "are mutually exclusive."); } @@ -651,10 +639,10 @@ class PyVectorType : public PyConcreteType { MlirType type; if (scalable) { if (scalable->size() != shape.size()) - throw nb::value_error("Expected len(scalable) == len(shape)."); + throw py::value_error("Expected len(scalable) == len(shape)."); SmallVector scalableDimFlags = llvm::to_vector(llvm::map_range( - *scalable, [](const nb::handle &h) { return nb::cast(h); })); + *scalable, [](const py::handle &h) { return h.cast(); })); type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), scalableDimFlags.data(), elementType); @@ -662,7 +650,7 @@ class PyVectorType : public PyConcreteType { SmallVector scalableDimFlags(shape.size(), false); for (int64_t dim : *scalableDims) { if (static_cast(dim) >= scalableDimFlags.size() || dim < 0) - throw nb::value_error("Scalable dimension index out of bounds."); + throw py::value_error("Scalable dimension index out of bounds."); scalableDimFlags[dim] = true; } type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), @@ -701,17 +689,17 @@ class PyRankedTensorType throw MLIRError("Invalid type", errors.take()); return PyRankedTensorType(elementType.getContext(), t); }, - nb::arg("shape"), nb::arg("element_type"), - nb::arg("encoding").none() = nb::none(), - nb::arg("loc").none() = nb::none(), "Create a ranked tensor type"); - c.def_prop_ro("encoding", - [](PyRankedTensorType &self) -> std::optional { - MlirAttribute encoding = - mlirRankedTensorTypeGetEncoding(self.get()); - if (mlirAttributeIsNull(encoding)) - return std::nullopt; - return encoding; - }); + py::arg("shape"), 
py::arg("element_type"), + py::arg("encoding") = py::none(), py::arg("loc") = py::none(), + "Create a ranked tensor type"); + c.def_property_readonly( + "encoding", + [](PyRankedTensorType &self) -> std::optional { + MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get()); + if (mlirAttributeIsNull(encoding)) + return std::nullopt; + return encoding; + }); } }; @@ -735,7 +723,7 @@ class PyUnrankedTensorType throw MLIRError("Invalid type", errors.take()); return PyUnrankedTensorType(elementType.getContext(), t); }, - nb::arg("element_type"), nb::arg("loc").none() = nb::none(), + py::arg("element_type"), py::arg("loc") = py::none(), "Create a unranked tensor type"); } }; @@ -766,11 +754,10 @@ class PyMemRefType : public PyConcreteType { throw MLIRError("Invalid type", errors.take()); return PyMemRefType(elementType.getContext(), t); }, - nb::arg("shape"), nb::arg("element_type"), - nb::arg("layout").none() = nb::none(), - nb::arg("memory_space").none() = nb::none(), - nb::arg("loc").none() = nb::none(), "Create a memref type") - .def_prop_ro( + py::arg("shape"), py::arg("element_type"), + py::arg("layout") = py::none(), py::arg("memory_space") = py::none(), + py::arg("loc") = py::none(), "Create a memref type") + .def_property_readonly( "layout", [](PyMemRefType &self) -> MlirAttribute { return mlirMemRefTypeGetLayout(self); @@ -788,14 +775,14 @@ class PyMemRefType : public PyConcreteType { return {strides, offset}; }, "The strides and offset of the MemRef type.") - .def_prop_ro( + .def_property_readonly( "affine_map", [](PyMemRefType &self) -> PyAffineMap { MlirAffineMap map = mlirMemRefTypeGetAffineMap(self); return PyAffineMap(self.getContext(), map); }, "The layout of the MemRef type as an affine map.") - .def_prop_ro( + .def_property_readonly( "memory_space", [](PyMemRefType &self) -> std::optional { MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); @@ -833,9 +820,9 @@ class PyUnrankedMemRefType throw MLIRError("Invalid type", errors.take()); 
return PyUnrankedMemRefType(elementType.getContext(), t); }, - nb::arg("element_type"), nb::arg("memory_space").none(), - nb::arg("loc").none() = nb::none(), "Create a unranked memref type") - .def_prop_ro( + py::arg("element_type"), py::arg("memory_space"), + py::arg("loc") = py::none(), "Create a unranked memref type") + .def_property_readonly( "memory_space", [](PyUnrankedMemRefType &self) -> std::optional { MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self); @@ -864,15 +851,15 @@ class PyTupleType : public PyConcreteType { elements.data()); return PyTupleType(context->getRef(), t); }, - nb::arg("elements"), nb::arg("context").none() = nb::none(), + py::arg("elements"), py::arg("context") = py::none(), "Create a tuple type"); c.def( "get_type", [](PyTupleType &self, intptr_t pos) { return mlirTupleTypeGetType(self, pos); }, - nb::arg("pos"), "Returns the pos-th type in the tuple type."); - c.def_prop_ro( + py::arg("pos"), "Returns the pos-th type in the tuple type."); + c.def_property_readonly( "num_types", [](PyTupleType &self) -> intptr_t { return mlirTupleTypeGetNumTypes(self); @@ -900,14 +887,13 @@ class PyFunctionType : public PyConcreteType { results.size(), results.data()); return PyFunctionType(context->getRef(), t); }, - nb::arg("inputs"), nb::arg("results"), - nb::arg("context").none() = nb::none(), + py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(), "Gets a FunctionType from a list of input and result types"); - c.def_prop_ro( + c.def_property_readonly( "inputs", [](PyFunctionType &self) { MlirType t = self; - nb::list types; + py::list types; for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; ++i) { types.append(mlirFunctionTypeGetInput(t, i)); @@ -915,10 +901,10 @@ class PyFunctionType : public PyConcreteType { return types; }, "Returns the list of input types in the FunctionType."); - c.def_prop_ro( + c.def_property_readonly( "results", [](PyFunctionType &self) { - nb::list types; + py::list types; for 
(intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; ++i) { types.append(mlirFunctionTypeGetResult(self, i)); @@ -952,21 +938,21 @@ class PyOpaqueType : public PyConcreteType { toMlirStringRef(typeData)); return PyOpaqueType(context->getRef(), type); }, - nb::arg("dialect_namespace"), nb::arg("buffer"), - nb::arg("context").none() = nb::none(), + py::arg("dialect_namespace"), py::arg("buffer"), + py::arg("context") = py::none(), "Create an unregistered (opaque) dialect type."); - c.def_prop_ro( + c.def_property_readonly( "dialect_namespace", [](PyOpaqueType &self) { MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self); - return nb::str(stringRef.data, stringRef.length); + return py::str(stringRef.data, stringRef.length); }, "Returns the dialect namespace for the Opaque type as a string."); - c.def_prop_ro( + c.def_property_readonly( "data", [](PyOpaqueType &self) { MlirStringRef stringRef = mlirOpaqueTypeGetData(self); - return nb::str(stringRef.data, stringRef.length); + return py::str(stringRef.data, stringRef.length); }, "Returns the data for the Opaque type as a string."); } @@ -974,7 +960,7 @@ class PyOpaqueType : public PyConcreteType { } // namespace -void mlir::python::populateIRTypes(nb::module_ &m) { +void mlir::python::populateIRTypes(py::module &m) { PyIntegerType::bind(m); PyFloatType::bind(m); PyIndexType::bind(m); diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index e5e64a921..7c2702190 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -6,31 +6,29 @@ // //===----------------------------------------------------------------------===// -#include -#include +#include "PybindUtils.h" #include "Globals.h" #include "IRModule.h" -#include "NanobindUtils.h" #include "Pass.h" #include "Rewrite.h" -namespace nb = nanobind; +namespace py = pybind11; using namespace mlir; -using namespace nb::literals; +using namespace py::literals; using 
namespace mlir::python; // ----------------------------------------------------------------------------- // Module initialization. // ----------------------------------------------------------------------------- -NB_MODULE(_mlir, m) { +PYBIND11_MODULE(_mlir, m) { m.doc() = "MLIR Python Native Extension"; - nb::class_(m, "_Globals") - .def_prop_rw("dialect_search_modules", - &PyGlobals::getDialectSearchPrefixes, - &PyGlobals::setDialectSearchPrefixes) + py::class_(m, "_Globals", py::module_local()) + .def_property("dialect_search_modules", + &PyGlobals::getDialectSearchPrefixes, + &PyGlobals::setDialectSearchPrefixes) .def( "append_dialect_search_prefix", [](PyGlobals &self, std::string moduleName) { @@ -47,21 +45,22 @@ NB_MODULE(_mlir, m) { "dialect_namespace"_a, "dialect_class"_a, "Testing hook for directly registering a dialect") .def("_register_operation_impl", &PyGlobals::registerOperationImpl, - "operation_name"_a, "operation_class"_a, nb::kw_only(), + "operation_name"_a, "operation_class"_a, py::kw_only(), "replace"_a = false, "Testing hook for directly registering an operation"); // Aside from making the globals accessible to python, having python manage // it is necessary to make sure it is destroyed (and releases its python // resources) properly. - m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership); + m.attr("globals") = + py::cast(new PyGlobals, py::return_value_policy::take_ownership); // Registration decorators. 
m.def( "register_dialect", - [](nb::type_object pyClass) { + [](py::type pyClass) { std::string dialectNamespace = - nanobind::cast(pyClass.attr("DIALECT_NAMESPACE")); + pyClass.attr("DIALECT_NAMESPACE").cast(); PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); return pyClass; }, @@ -69,46 +68,45 @@ NB_MODULE(_mlir, m) { "Class decorator for registering a custom Dialect wrapper"); m.def( "register_operation", - [](const nb::type_object &dialectClass, bool replace) -> nb::object { - return nb::cpp_function( - [dialectClass, - replace](nb::type_object opClass) -> nb::type_object { + [](const py::type &dialectClass, bool replace) -> py::cpp_function { + return py::cpp_function( + [dialectClass, replace](py::type opClass) -> py::type { std::string operationName = - nanobind::cast(opClass.attr("OPERATION_NAME")); + opClass.attr("OPERATION_NAME").cast(); PyGlobals::get().registerOperationImpl(operationName, opClass, replace); // Dict-stuff the new opClass by name onto the dialect class. 
- nb::object opClassName = opClass.attr("__name__"); + py::object opClassName = opClass.attr("__name__"); dialectClass.attr(opClassName) = opClass; return opClass; }); }, - "dialect_class"_a, nb::kw_only(), "replace"_a = false, + "dialect_class"_a, py::kw_only(), "replace"_a = false, "Produce a class decorator for registering an Operation class as part of " "a dialect"); m.def( MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR, - [](MlirTypeID mlirTypeID, bool replace) -> nb::object { - return nb::cpp_function([mlirTypeID, replace]( - nb::callable typeCaster) -> nb::object { + [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function { + return py::cpp_function([mlirTypeID, + replace](py::object typeCaster) -> py::object { PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace); return typeCaster; }); }, - "typeid"_a, nb::kw_only(), "replace"_a = false, + "typeid"_a, py::kw_only(), "replace"_a = false, "Register a type caster for casting MLIR types to custom user types."); m.def( MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR, - [](MlirTypeID mlirTypeID, bool replace) -> nb::object { - return nb::cpp_function( - [mlirTypeID, replace](nb::callable valueCaster) -> nb::object { + [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function { + return py::cpp_function( + [mlirTypeID, replace](py::object valueCaster) -> py::object { PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster, replace); return valueCaster; }); }, - "typeid"_a, nb::kw_only(), "replace"_a = false, + "typeid"_a, py::kw_only(), "replace"_a = false, "Register a value caster for casting MLIR values to custom user values."); // Define and populate IR submodule. 
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index b5dce4fe4..e991deaae 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -8,16 +8,12 @@ #include "Pass.h" -#include -#include -#include - #include "IRModule.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Pass.h" -namespace nb = nanobind; -using namespace nb::literals; +namespace py = pybind11; +using namespace py::literals; using namespace mlir; using namespace mlir::python; @@ -38,15 +34,16 @@ class PyPassManager { MlirPassManager get() { return passManager; } void release() { passManager.ptr = nullptr; } - nb::object getCapsule() { - return nb::steal(mlirPythonPassManagerToCapsule(get())); + pybind11::object getCapsule() { + return py::reinterpret_steal( + mlirPythonPassManagerToCapsule(get())); } - static nb::object createFromCapsule(nb::object capsule) { + static pybind11::object createFromCapsule(pybind11::object capsule) { MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr()); if (mlirPassManagerIsNull(rawPm)) - throw nb::python_error(); - return nb::cast(PyPassManager(rawPm), nb::rv_policy::move); + throw py::error_already_set(); + return py::cast(PyPassManager(rawPm), py::return_value_policy::move); } private: @@ -56,23 +53,22 @@ class PyPassManager { } // namespace /// Create the `mlir.passmanager` here. 
-void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { +void mlir::python::populatePassManagerSubmodule(py::module &m) { //---------------------------------------------------------------------------- // Mapping of the top-level PassManager //---------------------------------------------------------------------------- - nb::class_(m, "PassManager") - .def( - "__init__", - [](PyPassManager &self, const std::string &anchorOp, - DefaultingPyMlirContext context) { - MlirPassManager passManager = mlirPassManagerCreateOnOperation( - context->get(), - mlirStringRefCreate(anchorOp.data(), anchorOp.size())); - new (&self) PyPassManager(passManager); - }, - "anchor_op"_a = nb::str("any"), "context"_a.none() = nb::none(), - "Create a new PassManager for the current (or provided) Context.") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule) + py::class_(m, "PassManager", py::module_local()) + .def(py::init<>([](const std::string &anchorOp, + DefaultingPyMlirContext context) { + MlirPassManager passManager = mlirPassManagerCreateOnOperation( + context->get(), + mlirStringRefCreate(anchorOp.data(), anchorOp.size())); + return new PyPassManager(passManager); + }), + "anchor_op"_a = py::str("any"), "context"_a = py::none(), + "Create a new PassManager for the current (or provided) Context.") + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyPassManager::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule) .def("_testing_release", &PyPassManager::release, "Releases (leaks) the backing pass manager (testing)") @@ -105,9 +101,9 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { "print_before_all"_a = false, "print_after_all"_a = true, "print_module_scope"_a = false, "print_after_change"_a = false, "print_after_failure"_a = false, - "large_elements_limit"_a.none() = nb::none(), - "enable_debug_info"_a = false, "print_generic_op_form"_a = false, - "tree_printing_dir_path"_a.none() = nb::none(), + 
"large_elements_limit"_a = py::none(), "enable_debug_info"_a = false, + "print_generic_op_form"_a = false, + "tree_printing_dir_path"_a = py::none(), "Enable IR printing, default as mlir-print-ir-after-all.") .def( "enable_verifier", @@ -125,10 +121,10 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { mlirStringRefCreate(pipeline.data(), pipeline.size()), errorMsg.getCallback(), errorMsg.getUserData()); if (mlirLogicalResultIsFailure(status)) - throw nb::value_error(errorMsg.join().c_str()); + throw py::value_error(std::string(errorMsg.join())); return new PyPassManager(passManager); }, - "pipeline"_a, "context"_a.none() = nb::none(), + "pipeline"_a, "context"_a = py::none(), "Parse a textual pass-pipeline and return a top-level PassManager " "that can be applied on a Module. Throw a ValueError if the pipeline " "can't be parsed") @@ -141,7 +137,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { mlirStringRefCreate(pipeline.data(), pipeline.size()), errorMsg.getCallback(), errorMsg.getUserData()); if (mlirLogicalResultIsFailure(status)) - throw nb::value_error(errorMsg.join().c_str()); + throw py::value_error(std::string(errorMsg.join())); }, "pipeline"_a, "Add textual pipeline elements to the pass manager. 
Throws a " diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h index bc4094352..3a500d5e8 100644 --- a/mlir/lib/Bindings/Python/Pass.h +++ b/mlir/lib/Bindings/Python/Pass.h @@ -9,12 +9,12 @@ #ifndef MLIR_BINDINGS_PYTHON_PASS_H #define MLIR_BINDINGS_PYTHON_PASS_H -#include "NanobindUtils.h" +#include "PybindUtils.h" namespace mlir { namespace python { -void populatePassManagerSubmodule(nanobind::module_ &m); +void populatePassManagerSubmodule(pybind11::module &m); } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h similarity index 85% rename from mlir/lib/Bindings/Python/NanobindUtils.h rename to mlir/lib/Bindings/Python/PybindUtils.h index 3b0f7f698..38462ac8b 100644 --- a/mlir/lib/Bindings/Python/NanobindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -1,5 +1,4 @@ -//===- NanobindUtils.h - Utilities for interop with nanobind ------*- C++ -//-*-===// +//===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
@@ -10,21 +9,13 @@ #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H -#include - #include "mlir-c/Support.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/DataTypes.h" -template <> -struct std::iterator_traits { - using value_type = nanobind::handle; - using reference = const value_type; - using pointer = void; - using difference_type = std::ptrdiff_t; - using iterator_category = std::forward_iterator_tag; -}; +#include +#include namespace mlir { namespace python { @@ -63,14 +54,14 @@ class Defaulting { } // namespace python } // namespace mlir -namespace nanobind { +namespace pybind11 { namespace detail { template struct MlirDefaultingCaster { - NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription)); + PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription)); - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + bool load(pybind11::handle src, bool) { if (src.is_none()) { // Note that we do want an exception to propagate from here as it will be // the most informative. @@ -85,20 +76,20 @@ struct MlirDefaultingCaster { // code to produce nice error messages (other than "Cannot cast..."). try { value = DefaultingTy{ - nanobind::cast(src)}; + pybind11::cast(src)}; return true; } catch (std::exception &) { return false; } } - static handle from_cpp(DefaultingTy src, rv_policy policy, - cleanup_list *cleanup) noexcept { - return nanobind::cast(src, policy); + static handle cast(DefaultingTy src, return_value_policy policy, + handle parent) { + return pybind11::cast(src, policy); } }; } // namespace detail -} // namespace nanobind +} // namespace pybind11 //------------------------------------------------------------------------------ // Conversion utilities. @@ -109,7 +100,7 @@ namespace mlir { /// Accumulates into a python string from a method that accepts an /// MlirStringCallback. 
struct PyPrintAccumulator { - nanobind::list parts; + pybind11::list parts; void *getUserData() { return this; } @@ -117,15 +108,15 @@ struct PyPrintAccumulator { return [](MlirStringRef part, void *userData) { PyPrintAccumulator *printAccum = static_cast(userData); - nanobind::str pyPart(part.data, + pybind11::str pyPart(part.data, part.length); // Decodes as UTF-8 by default. printAccum->parts.append(std::move(pyPart)); }; } - nanobind::str join() { - nanobind::str delim("", 0); - return nanobind::cast(delim.attr("join")(parts)); + pybind11::str join() { + pybind11::str delim("", 0); + return delim.attr("join")(parts); } }; @@ -133,21 +124,21 @@ struct PyPrintAccumulator { /// or binary. class PyFileAccumulator { public: - PyFileAccumulator(const nanobind::object &fileObject, bool binary) + PyFileAccumulator(const pybind11::object &fileObject, bool binary) : pyWriteFunction(fileObject.attr("write")), binary(binary) {} void *getUserData() { return this; } MlirStringCallback getCallback() { return [](MlirStringRef part, void *userData) { - nanobind::gil_scoped_acquire acquire; + pybind11::gil_scoped_acquire acquire; PyFileAccumulator *accum = static_cast(userData); if (accum->binary) { // Note: Still has to copy and not avoidable with this API. - nanobind::bytes pyBytes(part.data, part.length); + pybind11::bytes pyBytes(part.data, part.length); accum->pyWriteFunction(pyBytes); } else { - nanobind::str pyStr(part.data, + pybind11::str pyStr(part.data, part.length); // Decodes as UTF-8 by default. 
accum->pyWriteFunction(pyStr); } @@ -155,7 +146,7 @@ class PyFileAccumulator { } private: - nanobind::object pyWriteFunction; + pybind11::object pyWriteFunction; bool binary; }; @@ -172,17 +163,17 @@ struct PySinglePartStringAccumulator { assert(!accum->invoked && "PySinglePartStringAccumulator called back multiple times"); accum->invoked = true; - accum->value = nanobind::str(part.data, part.length); + accum->value = pybind11::str(part.data, part.length); }; } - nanobind::str takeValue() { + pybind11::str takeValue() { assert(invoked && "PySinglePartStringAccumulator not called back"); return std::move(value); } private: - nanobind::str value; + pybind11::str value; bool invoked = false; }; @@ -217,7 +208,7 @@ struct PySinglePartStringAccumulator { template class Sliceable { protected: - using ClassTy = nanobind::class_; + using ClassTy = pybind11::class_; /// Transforms `index` into a legal value to access the underlying sequence. /// Returns <0 on failure. @@ -246,7 +237,7 @@ class Sliceable { /// Returns the element at the given slice index. Supports negative indices /// by taking elements in inverse order. Returns a nullptr object if out /// of bounds. - nanobind::object getItem(intptr_t index) { + pybind11::object getItem(intptr_t index) { // Negative indices mean we count from the end. index = wrapIndex(index); if (index < 0) { @@ -259,20 +250,20 @@ class Sliceable { ->getRawElement(linearizeIndex(index)) .maybeDownCast(); else - return nanobind::cast( + return pybind11::cast( static_cast(this)->getRawElement(linearizeIndex(index))); } /// Returns a new instance of the pseudo-container restricted to the given /// slice. Returns a nullptr object on failure. 
- nanobind::object getItemSlice(PyObject *slice) { + pybind11::object getItemSlice(PyObject *slice) { ssize_t start, stop, extraStep, sliceLength; if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep, &sliceLength) != 0) { PyErr_SetString(PyExc_IndexError, "index out of range"); return {}; } - return nanobind::cast(static_cast(this)->slice( + return pybind11::cast(static_cast(this)->slice( startIndex + start * step, sliceLength, step * extraStep)); } @@ -288,7 +279,7 @@ class Sliceable { // Negative indices mean we count from the end. index = wrapIndex(index); if (index < 0) { - throw nanobind::index_error("index out of range"); + throw pybind11::index_error("index out of range"); } return static_cast(this)->getRawElement(linearizeIndex(index)); @@ -313,38 +304,39 @@ class Sliceable { } /// Binds the indexing and length methods in the Python class. - static void bind(nanobind::module_ &m) { - auto clazz = nanobind::class_(m, Derived::pyClassName) + static void bind(pybind11::module &m) { + auto clazz = pybind11::class_(m, Derived::pyClassName, + pybind11::module_local()) .def("__add__", &Sliceable::dunderAdd); Derived::bindDerived(clazz); // Manually implement the sequence protocol via the C API. We do this - // because it is approx 4x faster than via nanobind, largely because that + // because it is approx 4x faster than via pybind11, largely because that // formulation requires a C++ exception to be thrown to detect end of // sequence. // Since we are in a C-context, any C++ exception that happens here // will terminate the program. There is nothing in this implementation // that should throw in a non-terminal way, so we forgo further // exception marshalling. 
- // See: https://github.com/pybind/nanobind/issues/2842 + // See: https://github.com/pybind/pybind11/issues/2842 auto heap_type = reinterpret_cast(clazz.ptr()); assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE && "must be heap type"); heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t { - auto self = nanobind::cast(nanobind::handle(rawSelf)); + auto self = pybind11::cast(rawSelf); return self->length; }; // sq_item is called as part of the sequence protocol for iteration, // list construction, etc. heap_type->as_sequence.sq_item = +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * { - auto self = nanobind::cast(nanobind::handle(rawSelf)); + auto self = pybind11::cast(rawSelf); return self->getItem(index).release().ptr(); }; // mp_subscript is used for both slices and integer lookups. heap_type->as_mapping.mp_subscript = +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * { - auto self = nanobind::cast(nanobind::handle(rawSelf)); + auto self = pybind11::cast(rawSelf); Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError); if (!PyErr_Occurred()) { // Integer indexing. 
diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index b2c1de4be..1d8128be9 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -8,16 +8,14 @@ #include "Rewrite.h" -#include - #include "IRModule.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Rewrite.h" #include "mlir/Config/mlir-config.h" -namespace nb = nanobind; +namespace py = pybind11; using namespace mlir; -using namespace nb::literals; +using namespace py::literals; using namespace mlir::python; namespace { @@ -56,17 +54,18 @@ class PyFrozenRewritePatternSet { } MlirFrozenRewritePatternSet get() { return set; } - nb::object getCapsule() { - return nb::steal( + pybind11::object getCapsule() { + return py::reinterpret_steal( mlirPythonFrozenRewritePatternSetToCapsule(get())); } - static nb::object createFromCapsule(nb::object capsule) { + static pybind11::object createFromCapsule(pybind11::object capsule) { MlirFrozenRewritePatternSet rawPm = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); if (rawPm.ptr == nullptr) - throw nb::python_error(); - return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move); + throw py::error_already_set(); + return py::cast(PyFrozenRewritePatternSet(rawPm), + py::return_value_policy::move); } private: @@ -76,27 +75,25 @@ class PyFrozenRewritePatternSet { } // namespace /// Create the `mlir.rewrite` here. 
-void mlir::python::populateRewriteSubmodule(nb::module_ &m) { +void mlir::python::populateRewriteSubmodule(py::module &m) { //---------------------------------------------------------------------------- // Mapping of the top-level PassManager //---------------------------------------------------------------------------- #if MLIR_ENABLE_PDL_IN_PATTERNMATCH - nb::class_(m, "PDLModule") - .def( - "__init__", - [](PyPDLPatternModule &self, MlirModule module) { - new (&self) - PyPDLPatternModule(mlirPDLPatternModuleFromModule(module)); - }, - "module"_a, "Create a PDL module from the given module.") + py::class_(m, "PDLModule", py::module_local()) + .def(py::init<>([](MlirModule module) { + return mlirPDLPatternModuleFromModule(module); + }), + "module"_a, "Create a PDL module from the given module.") .def("freeze", [](PyPDLPatternModule &self) { return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( mlirRewritePatternSetFromPDLPatternModule(self.get()))); }); -#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH - nb::class_(m, "FrozenRewritePatternSet") - .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyFrozenRewritePatternSet::getCapsule) +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH + py::class_(m, "FrozenRewritePatternSet", + py::module_local()) + .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyFrozenRewritePatternSet::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyFrozenRewritePatternSet::createFromCapsule); m.def( @@ -105,7 +102,7 @@ void mlir::python::populateRewriteSubmodule(nb::module_ &m) { auto status = mlirApplyPatternsAndFoldGreedily(module, set, {}); if (mlirLogicalResultIsFailure(status)) // FIXME: Not sure this is the right error to throw here. 
- throw nb::value_error("pattern application failed to converge"); + throw py::value_error("pattern application failed to converge"); }, "module"_a, "set"_a, "Applys the given patterns to the given module greedily while folding " diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h index ae89e2b95..997b80add 100644 --- a/mlir/lib/Bindings/Python/Rewrite.h +++ b/mlir/lib/Bindings/Python/Rewrite.h @@ -9,12 +9,12 @@ #ifndef MLIR_BINDINGS_PYTHON_REWRITE_H #define MLIR_BINDINGS_PYTHON_REWRITE_H -#include "NanobindUtils.h" +#include "PybindUtils.h" namespace mlir { namespace python { -void populateRewriteSubmodule(nanobind::module_ &m); +void populateRewriteSubmodule(pybind11::module &m); } // namespace python } // namespace mlir diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index b865c9032..10866c11b 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -448,7 +448,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Core MODULE_NAME _mlir ADD_TO_PARENT MLIRPythonSources.Core ROOT_DIR "${PYTHON_SOURCE_DIR}" - PYTHON_BINDINGS_LIBRARY nanobind SOURCES MainModule.cpp IRAffine.cpp @@ -464,7 +463,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core Globals.h IRModule.h Pass.h - NanobindUtils.h + PybindUtils.h Rewrite.h PRIVATE_LINK_LIBS LLVMSupport diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index f240d6ef9..ab8a91229 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,4 +1,4 @@ -nanobind>=2.4, <3.0 +nanobind>=2.0, <3.0 numpy>=1.19.5, <=2.1.2 pybind11>=2.10.0, <=2.13.6 PyYAML>=5.4.0, <=6.0.1 From 4bcfad0b482858f2c65db23df5bf22ddc78609a5 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 18 Dec 2024 21:55:42 -0500 Subject: [PATCH 810/915] [mlir python] Port Python core code to nanobind. (#120473) Relands #118583, with a fix for Python 3.8 compatibility. 
It was not possible to set the buffer protocol accessors via slots in Python 3.8. Why? https://nanobind.readthedocs.io/en/latest/why.html says it better than I can, but my primary motivation for this change is to improve MLIR IR construction time from JAX. For a complicated Google-internal LLM model in JAX, this change improves the MLIR lowering time by around 5s (out of around 30s), which is a significant speedup for simply switching binding frameworks. To a large extent, this is a mechanical change, for instance changing `pybind11::` to `nanobind::`. Notes: * this PR needs Nanobind 2.4.0, because it needs a bug fix (https://github.com/wjakob/nanobind/pull/806) that landed in that release. * this PR does not port the in-tree dialect extension modules. They can be ported in a future PR. * I removed the py::sibling() annotations from def_static and def_class in `PybindAdapters.h`. These ask pybind11 to try to form an overload with an existing method, but it's not possible to form mixed pybind11/nanobind overloads this way and the parent class is now defined in nanobind. Better solutions may be possible here. * nanobind does not contain an exact equivalent of pybind11's buffer protocol support. It was not hard to add a nanobind implementation of a similar API. * nanobind is pickier about casting to std::vector, expecting that the input is a sequence of bool types, not truthy values. In a couple of places I added code to support truthy values during casting. * nanobind distinguishes bytes (`nb::bytes`) from strings (e.g., `std::string`). This required nb::bytes overloads in a few places. 
--- mlir/include/mlir/Bindings/Python/IRTypes.h | 2 +- .../mlir/Bindings/Python/NanobindAdaptors.h | 26 +- .../mlir/Bindings/Python/PybindAdaptors.h | 10 +- mlir/lib/Bindings/Python/Globals.h | 39 +- mlir/lib/Bindings/Python/IRAffine.cpp | 265 ++-- mlir/lib/Bindings/Python/IRAttributes.cpp | 681 +++++--- mlir/lib/Bindings/Python/IRCore.cpp | 1412 +++++++++-------- mlir/lib/Bindings/Python/IRInterfaces.cpp | 171 +- mlir/lib/Bindings/Python/IRModule.cpp | 57 +- mlir/lib/Bindings/Python/IRModule.h | 332 ++-- mlir/lib/Bindings/Python/IRTypes.cpp | 200 +-- mlir/lib/Bindings/Python/MainModule.cpp | 56 +- .../Python/{PybindUtils.h => NanobindUtils.h} | 84 +- mlir/lib/Bindings/Python/Pass.cpp | 58 +- mlir/lib/Bindings/Python/Pass.h | 4 +- mlir/lib/Bindings/Python/Rewrite.cpp | 43 +- mlir/lib/Bindings/Python/Rewrite.h | 4 +- mlir/python/CMakeLists.txt | 9 +- mlir/python/requirements.txt | 2 +- 19 files changed, 1881 insertions(+), 1574 deletions(-) rename mlir/lib/Bindings/Python/{PybindUtils.h => NanobindUtils.h} (85%) diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h index 9afad4c23..ba9642cf2 100644 --- a/mlir/include/mlir/Bindings/Python/IRTypes.h +++ b/mlir/include/mlir/Bindings/Python/IRTypes.h @@ -9,7 +9,7 @@ #ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H #define MLIR_BINDINGS_PYTHON_IRTYPES_H -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" namespace mlir { diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h index 5e01cebcb..943981b1f 100644 --- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h @@ -64,7 +64,7 @@ static nanobind::object mlirApiObjectToCapsule(nanobind::handle apiObject) { /// Casts object <-> MlirAffineMap. 
template <> struct type_caster { - NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap")); + NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { nanobind::object capsule = mlirApiObjectToCapsule(src); value = mlirPythonCapsuleToAffineMap(capsule.ptr()); @@ -87,7 +87,7 @@ struct type_caster { /// Casts object <-> MlirAttribute. template <> struct type_caster { - NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute")); + NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { nanobind::object capsule = mlirApiObjectToCapsule(src); value = mlirPythonCapsuleToAttribute(capsule.ptr()); @@ -108,7 +108,7 @@ struct type_caster { /// Casts object -> MlirBlock. template <> struct type_caster { - NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock")); + NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { nanobind::object capsule = mlirApiObjectToCapsule(src); value = mlirPythonCapsuleToBlock(capsule.ptr()); @@ -119,7 +119,7 @@ struct type_caster { /// Casts object -> MlirContext. template <> struct type_caster { - NB_TYPE_CASTER(MlirContext, const_name("MlirContext")); + NB_TYPE_CASTER(MlirContext, const_name("MlirContext")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { if (src.is_none()) { // Gets the current thread-bound context. @@ -139,7 +139,7 @@ struct type_caster { /// Casts object <-> MlirDialectRegistry. 
template <> struct type_caster { - NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry")); + NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { nanobind::object capsule = mlirApiObjectToCapsule(src); value = mlirPythonCapsuleToDialectRegistry(capsule.ptr()); @@ -159,7 +159,7 @@ struct type_caster { /// Casts object <-> MlirLocation. template <> struct type_caster { - NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation")); + NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { if (src.is_none()) { // Gets the current thread-bound context. @@ -185,7 +185,7 @@ struct type_caster { /// Casts object <-> MlirModule. template <> struct type_caster { - NB_TYPE_CASTER(MlirModule, const_name("MlirModule")); + NB_TYPE_CASTER(MlirModule, const_name("MlirModule")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { nanobind::object capsule = mlirApiObjectToCapsule(src); value = mlirPythonCapsuleToModule(capsule.ptr()); @@ -206,7 +206,7 @@ struct type_caster { template <> struct type_caster { NB_TYPE_CASTER(MlirFrozenRewritePatternSet, - const_name("MlirFrozenRewritePatternSet")); + const_name("MlirFrozenRewritePatternSet")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { nanobind::object capsule = mlirApiObjectToCapsule(src); value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); @@ -225,7 +225,7 @@ struct type_caster { /// Casts object <-> MlirOperation. 
template <> struct type_caster { - NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation")); + NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { nanobind::object capsule = mlirApiObjectToCapsule(src); value = mlirPythonCapsuleToOperation(capsule.ptr()); @@ -247,7 +247,7 @@ struct type_caster { /// Casts object <-> MlirValue. template <> struct type_caster { - NB_TYPE_CASTER(MlirValue, const_name("MlirValue")); + NB_TYPE_CASTER(MlirValue, const_name("MlirValue")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { nanobind::object capsule = mlirApiObjectToCapsule(src); value = mlirPythonCapsuleToValue(capsule.ptr()); @@ -270,7 +270,7 @@ struct type_caster { /// Casts object -> MlirPassManager. template <> struct type_caster { - NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager")); + NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { nanobind::object capsule = mlirApiObjectToCapsule(src); value = mlirPythonCapsuleToPassManager(capsule.ptr()); @@ -281,7 +281,7 @@ struct type_caster { /// Casts object <-> MlirTypeID. template <> struct type_caster { - NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID")); + NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { nanobind::object capsule = mlirApiObjectToCapsule(src); value = mlirPythonCapsuleToTypeID(capsule.ptr()); @@ -303,7 +303,7 @@ struct type_caster { /// Casts object <-> MlirType. 
template <> struct type_caster { - NB_TYPE_CASTER(MlirType, const_name("MlirType")); + NB_TYPE_CASTER(MlirType, const_name("MlirType")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { nanobind::object capsule = mlirApiObjectToCapsule(src); value = mlirPythonCapsuleToType(capsule.ptr()); diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index c8233355d..edc69774b 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -374,9 +374,8 @@ class pure_subclass { static_assert(!std::is_member_function_pointer::value, "def_staticmethod(...) called with a non-static member " "function pointer"); - py::cpp_function cf( - std::forward(f), py::name(name), py::scope(thisClass), - py::sibling(py::getattr(thisClass, name, py::none())), extra...); + py::cpp_function cf(std::forward(f), py::name(name), + py::scope(thisClass), extra...); thisClass.attr(cf.name()) = py::staticmethod(cf); return *this; } @@ -387,9 +386,8 @@ class pure_subclass { static_assert(!std::is_member_function_pointer::value, "def_classmethod(...) 
called with a non-static member " "function pointer"); - py::cpp_function cf( - std::forward(f), py::name(name), py::scope(thisClass), - py::sibling(py::getattr(thisClass, name, py::none())), extra...); + py::cpp_function cf(std::forward(f), py::name(name), + py::scope(thisClass), extra...); thisClass.attr(cf.name()) = py::reinterpret_borrow(PyClassMethod_New(cf.ptr())); return *this; diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index a022067f5..0ec522d14 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -9,18 +9,17 @@ #ifndef MLIR_BINDINGS_PYTHON_GLOBALS_H #define MLIR_BINDINGS_PYTHON_GLOBALS_H -#include "PybindUtils.h" +#include +#include +#include +#include "NanobindUtils.h" #include "mlir-c/IR.h" #include "mlir/CAPI/Support.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" -#include -#include -#include - namespace mlir { namespace python { @@ -57,55 +56,55 @@ class PyGlobals { /// Raises an exception if the mapping already exists and replace == false. /// This is intended to be called by implementation code. void registerAttributeBuilder(const std::string &attributeKind, - pybind11::function pyFunc, + nanobind::callable pyFunc, bool replace = false); /// Adds a user-friendly type caster. Raises an exception if the mapping /// already exists and replace == false. This is intended to be called by /// implementation code. - void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster, + void registerTypeCaster(MlirTypeID mlirTypeID, nanobind::callable typeCaster, bool replace = false); /// Adds a user-friendly value caster. Raises an exception if the mapping /// already exists and replace == false. This is intended to be called by /// implementation code. 
void registerValueCaster(MlirTypeID mlirTypeID, - pybind11::function valueCaster, + nanobind::callable valueCaster, bool replace = false); /// Adds a concrete implementation dialect class. /// Raises an exception if the mapping already exists. /// This is intended to be called by implementation code. void registerDialectImpl(const std::string &dialectNamespace, - pybind11::object pyClass); + nanobind::object pyClass); /// Adds a concrete implementation operation class. /// Raises an exception if the mapping already exists and replace == false. /// This is intended to be called by implementation code. void registerOperationImpl(const std::string &operationName, - pybind11::object pyClass, bool replace = false); + nanobind::object pyClass, bool replace = false); /// Returns the custom Attribute builder for Attribute kind. - std::optional + std::optional lookupAttributeBuilder(const std::string &attributeKind); /// Returns the custom type caster for MlirTypeID mlirTypeID. - std::optional lookupTypeCaster(MlirTypeID mlirTypeID, + std::optional lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect); /// Returns the custom value caster for MlirTypeID mlirTypeID. - std::optional lookupValueCaster(MlirTypeID mlirTypeID, + std::optional lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect); /// Looks up a registered dialect class by namespace. Note that this may /// trigger loading of the defining module and can arbitrarily re-enter. - std::optional + std::optional lookupDialectClass(const std::string &dialectNamespace); /// Looks up a registered operation class (deriving from OpView) by operation /// name. Note that this may trigger a load of the dialect, which can /// arbitrarily re-enter. - std::optional + std::optional lookupOperationClass(llvm::StringRef operationName); private: @@ -113,15 +112,15 @@ class PyGlobals { /// Module name prefixes to search under for dialect implementation modules. 
std::vector dialectSearchPrefixes; /// Map of dialect namespace to external dialect class object. - llvm::StringMap dialectClassMap; + llvm::StringMap dialectClassMap; /// Map of full operation name to external operation class object. - llvm::StringMap operationClassMap; + llvm::StringMap operationClassMap; /// Map of attribute ODS name to custom builder. - llvm::StringMap attributeBuilderMap; + llvm::StringMap attributeBuilderMap; /// Map of MlirTypeID to custom type caster. - llvm::DenseMap typeCasterMap; + llvm::DenseMap typeCasterMap; /// Map of MlirTypeID to custom value caster. - llvm::DenseMap valueCasterMap; + llvm::DenseMap valueCasterMap; /// Set of dialect namespaces that we have attempted to import implementation /// modules for. llvm::StringSet<> loadedDialectModules; diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index b138e131e..2db690309 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -6,20 +6,19 @@ // //===----------------------------------------------------------------------===// +#include +#include +#include + #include #include -#include -#include -#include -#include +#include #include #include #include #include "IRModule.h" - -#include "PybindUtils.h" - +#include "NanobindUtils.h" #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" #include "mlir-c/Bindings/Python/Interop.h" @@ -30,7 +29,7 @@ #include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" -namespace py = pybind11; +namespace nb = nanobind; using namespace mlir; using namespace mlir::python; @@ -46,23 +45,23 @@ static const char kDumpDocstring[] = /// Throws errors in case of failure, using "action" to describe what the caller /// was attempting to do. 
template -static void pyListToVector(const py::list &list, +static void pyListToVector(const nb::list &list, llvm::SmallVectorImpl &result, StringRef action) { - result.reserve(py::len(list)); - for (py::handle item : list) { + result.reserve(nb::len(list)); + for (nb::handle item : list) { try { - result.push_back(item.cast()); - } catch (py::cast_error &err) { + result.push_back(nb::cast(item)); + } catch (nb::cast_error &err) { std::string msg = (llvm::Twine("Invalid expression when ") + action + " (" + err.what() + ")") .str(); - throw py::cast_error(msg); - } catch (py::reference_cast_error &err) { + throw std::runtime_error(msg.c_str()); + } catch (std::runtime_error &err) { std::string msg = (llvm::Twine("Invalid expression (None?) when ") + action + " (" + err.what() + ")") .str(); - throw py::cast_error(msg); + throw std::runtime_error(msg.c_str()); } } } @@ -94,7 +93,7 @@ class PyConcreteAffineExpr : public BaseTy { // IsAFunctionTy isaFunction // const char *pyClassName // and redefine bindDerived. 
- using ClassTy = py::class_; + using ClassTy = nb::class_; using IsAFunctionTy = bool (*)(MlirAffineExpr); PyConcreteAffineExpr() = default; @@ -105,24 +104,25 @@ class PyConcreteAffineExpr : public BaseTy { static MlirAffineExpr castFrom(PyAffineExpr &orig) { if (!DerivedTy::isaFunction(orig)) { - auto origRepr = py::repr(py::cast(orig)).cast(); - throw py::value_error((Twine("Cannot cast affine expression to ") + + auto origRepr = nb::cast(nb::repr(nb::cast(orig))); + throw nb::value_error((Twine("Cannot cast affine expression to ") + DerivedTy::pyClassName + " (from " + origRepr + ")") - .str()); + .str() + .c_str()); } return orig; } - static void bind(py::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); - cls.def(py::init(), py::arg("expr")); + static void bind(nb::module_ &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName); + cls.def(nb::init(), nb::arg("expr")); cls.def_static( "isinstance", [](PyAffineExpr &otherAffineExpr) -> bool { return DerivedTy::isaFunction(otherAffineExpr); }, - py::arg("other")); + nb::arg("other")); DerivedTy::bindDerived(cls); } @@ -144,9 +144,9 @@ class PyAffineConstantExpr : public PyConcreteAffineExpr { } static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineConstantExpr::get, py::arg("value"), - py::arg("context") = py::none()); - c.def_property_readonly("value", [](PyAffineConstantExpr &self) { + c.def_static("get", &PyAffineConstantExpr::get, nb::arg("value"), + nb::arg("context").none() = nb::none()); + c.def_prop_ro("value", [](PyAffineConstantExpr &self) { return mlirAffineConstantExprGetValue(self); }); } @@ -164,9 +164,9 @@ class PyAffineDimExpr : public PyConcreteAffineExpr { } static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineDimExpr::get, py::arg("position"), - py::arg("context") = py::none()); - c.def_property_readonly("position", [](PyAffineDimExpr &self) { + c.def_static("get", &PyAffineDimExpr::get, nb::arg("position"), + 
nb::arg("context").none() = nb::none()); + c.def_prop_ro("position", [](PyAffineDimExpr &self) { return mlirAffineDimExprGetPosition(self); }); } @@ -184,9 +184,9 @@ class PyAffineSymbolExpr : public PyConcreteAffineExpr { } static void bindDerived(ClassTy &c) { - c.def_static("get", &PyAffineSymbolExpr::get, py::arg("position"), - py::arg("context") = py::none()); - c.def_property_readonly("position", [](PyAffineSymbolExpr &self) { + c.def_static("get", &PyAffineSymbolExpr::get, nb::arg("position"), + nb::arg("context").none() = nb::none()); + c.def_prop_ro("position", [](PyAffineSymbolExpr &self) { return mlirAffineSymbolExprGetPosition(self); }); } @@ -209,8 +209,8 @@ class PyAffineBinaryExpr : public PyConcreteAffineExpr { } static void bindDerived(ClassTy &c) { - c.def_property_readonly("lhs", &PyAffineBinaryExpr::lhs); - c.def_property_readonly("rhs", &PyAffineBinaryExpr::rhs); + c.def_prop_ro("lhs", &PyAffineBinaryExpr::lhs); + c.def_prop_ro("rhs", &PyAffineBinaryExpr::rhs); } }; @@ -365,15 +365,14 @@ bool PyAffineExpr::operator==(const PyAffineExpr &other) const { return mlirAffineExprEqual(affineExpr, other.affineExpr); } -py::object PyAffineExpr::getCapsule() { - return py::reinterpret_steal( - mlirPythonAffineExprToCapsule(*this)); +nb::object PyAffineExpr::getCapsule() { + return nb::steal(mlirPythonAffineExprToCapsule(*this)); } -PyAffineExpr PyAffineExpr::createFromCapsule(py::object capsule) { +PyAffineExpr PyAffineExpr::createFromCapsule(nb::object capsule) { MlirAffineExpr rawAffineExpr = mlirPythonCapsuleToAffineExpr(capsule.ptr()); if (mlirAffineExprIsNull(rawAffineExpr)) - throw py::error_already_set(); + throw nb::python_error(); return PyAffineExpr( PyMlirContext::forContext(mlirAffineExprGetContext(rawAffineExpr)), rawAffineExpr); @@ -424,14 +423,14 @@ bool PyAffineMap::operator==(const PyAffineMap &other) const { return mlirAffineMapEqual(affineMap, other.affineMap); } -py::object PyAffineMap::getCapsule() { - return 
py::reinterpret_steal(mlirPythonAffineMapToCapsule(*this)); +nb::object PyAffineMap::getCapsule() { + return nb::steal(mlirPythonAffineMapToCapsule(*this)); } -PyAffineMap PyAffineMap::createFromCapsule(py::object capsule) { +PyAffineMap PyAffineMap::createFromCapsule(nb::object capsule) { MlirAffineMap rawAffineMap = mlirPythonCapsuleToAffineMap(capsule.ptr()); if (mlirAffineMapIsNull(rawAffineMap)) - throw py::error_already_set(); + throw nb::python_error(); return PyAffineMap( PyMlirContext::forContext(mlirAffineMapGetContext(rawAffineMap)), rawAffineMap); @@ -454,11 +453,10 @@ class PyIntegerSetConstraint { bool isEq() { return mlirIntegerSetIsConstraintEq(set, pos); } - static void bind(py::module &m) { - py::class_(m, "IntegerSetConstraint", - py::module_local()) - .def_property_readonly("expr", &PyIntegerSetConstraint::getExpr) - .def_property_readonly("is_eq", &PyIntegerSetConstraint::isEq); + static void bind(nb::module_ &m) { + nb::class_(m, "IntegerSetConstraint") + .def_prop_ro("expr", &PyIntegerSetConstraint::getExpr) + .def_prop_ro("is_eq", &PyIntegerSetConstraint::isEq); } private: @@ -501,27 +499,25 @@ bool PyIntegerSet::operator==(const PyIntegerSet &other) const { return mlirIntegerSetEqual(integerSet, other.integerSet); } -py::object PyIntegerSet::getCapsule() { - return py::reinterpret_steal( - mlirPythonIntegerSetToCapsule(*this)); +nb::object PyIntegerSet::getCapsule() { + return nb::steal(mlirPythonIntegerSetToCapsule(*this)); } -PyIntegerSet PyIntegerSet::createFromCapsule(py::object capsule) { +PyIntegerSet PyIntegerSet::createFromCapsule(nb::object capsule) { MlirIntegerSet rawIntegerSet = mlirPythonCapsuleToIntegerSet(capsule.ptr()); if (mlirIntegerSetIsNull(rawIntegerSet)) - throw py::error_already_set(); + throw nb::python_error(); return PyIntegerSet( PyMlirContext::forContext(mlirIntegerSetGetContext(rawIntegerSet)), rawIntegerSet); } -void mlir::python::populateIRAffine(py::module &m) { +void 
mlir::python::populateIRAffine(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of PyAffineExpr and derived classes. //---------------------------------------------------------------------------- - py::class_(m, "AffineExpr", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyAffineExpr::getCapsule) + nb::class_(m, "AffineExpr") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineExpr::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineExpr::createFromCapsule) .def("__add__", &PyAffineAddExpr::get) .def("__add__", &PyAffineAddExpr::getRHSConstant) @@ -558,7 +554,7 @@ void mlir::python::populateIRAffine(py::module &m) { .def("__eq__", [](PyAffineExpr &self, PyAffineExpr &other) { return self == other; }) .def("__eq__", - [](PyAffineExpr &self, py::object &other) { return false; }) + [](PyAffineExpr &self, nb::object &other) { return false; }) .def("__str__", [](PyAffineExpr &self) { PyPrintAccumulator printAccum; @@ -579,7 +575,7 @@ void mlir::python::populateIRAffine(py::module &m) { [](PyAffineExpr &self) { return static_cast(llvm::hash_value(self.get().ptr)); }) - .def_property_readonly( + .def_prop_ro( "context", [](PyAffineExpr &self) { return self.getContext().getObject(); }) .def("compose", @@ -632,16 +628,16 @@ void mlir::python::populateIRAffine(py::module &m) { .def_static("get_ceil_div", &PyAffineCeilDivExpr::getRHSConstant, "Gets an affine expression containing the rounded-up result " "of dividing an expression by a constant.") - .def_static("get_constant", &PyAffineConstantExpr::get, py::arg("value"), - py::arg("context") = py::none(), + .def_static("get_constant", &PyAffineConstantExpr::get, nb::arg("value"), + nb::arg("context").none() = nb::none(), "Gets a constant affine expression with the given value.") .def_static( - "get_dim", &PyAffineDimExpr::get, py::arg("position"), - py::arg("context") = py::none(), + "get_dim", &PyAffineDimExpr::get, 
nb::arg("position"), + nb::arg("context").none() = nb::none(), "Gets an affine expression of a dimension at the given position.") .def_static( - "get_symbol", &PyAffineSymbolExpr::get, py::arg("position"), - py::arg("context") = py::none(), + "get_symbol", &PyAffineSymbolExpr::get, nb::arg("position"), + nb::arg("context").none() = nb::none(), "Gets an affine expression of a symbol at the given position.") .def( "dump", [](PyAffineExpr &self) { mlirAffineExprDump(self); }, @@ -659,13 +655,12 @@ void mlir::python::populateIRAffine(py::module &m) { //---------------------------------------------------------------------------- // Mapping of PyAffineMap. //---------------------------------------------------------------------------- - py::class_(m, "AffineMap", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyAffineMap::getCapsule) + nb::class_(m, "AffineMap") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAffineMap::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAffineMap::createFromCapsule) .def("__eq__", [](PyAffineMap &self, PyAffineMap &other) { return self == other; }) - .def("__eq__", [](PyAffineMap &self, py::object &other) { return false; }) + .def("__eq__", [](PyAffineMap &self, nb::object &other) { return false; }) .def("__str__", [](PyAffineMap &self) { PyPrintAccumulator printAccum; @@ -687,7 +682,7 @@ void mlir::python::populateIRAffine(py::module &m) { return static_cast(llvm::hash_value(self.get().ptr)); }) .def_static("compress_unused_symbols", - [](py::list affineMaps, DefaultingPyMlirContext context) { + [](nb::list affineMaps, DefaultingPyMlirContext context) { SmallVector maps; pyListToVector( affineMaps, maps, "attempting to create an AffineMap"); @@ -704,7 +699,7 @@ void mlir::python::populateIRAffine(py::module &m) { res.emplace_back(context->getRef(), m); return res; }) - .def_property_readonly( + .def_prop_ro( "context", [](PyAffineMap &self) { return self.getContext().getObject(); }, "Context that owns the Affine 
Map") @@ -713,7 +708,7 @@ void mlir::python::populateIRAffine(py::module &m) { kDumpDocstring) .def_static( "get", - [](intptr_t dimCount, intptr_t symbolCount, py::list exprs, + [](intptr_t dimCount, intptr_t symbolCount, nb::list exprs, DefaultingPyMlirContext context) { SmallVector affineExprs; pyListToVector( @@ -723,8 +718,8 @@ void mlir::python::populateIRAffine(py::module &m) { affineExprs.size(), affineExprs.data()); return PyAffineMap(context->getRef(), map); }, - py::arg("dim_count"), py::arg("symbol_count"), py::arg("exprs"), - py::arg("context") = py::none(), + nb::arg("dim_count"), nb::arg("symbol_count"), nb::arg("exprs"), + nb::arg("context").none() = nb::none(), "Gets a map with the given expressions as results.") .def_static( "get_constant", @@ -733,7 +728,7 @@ void mlir::python::populateIRAffine(py::module &m) { mlirAffineMapConstantGet(context->get(), value); return PyAffineMap(context->getRef(), affineMap); }, - py::arg("value"), py::arg("context") = py::none(), + nb::arg("value"), nb::arg("context").none() = nb::none(), "Gets an affine map with a single constant result") .def_static( "get_empty", @@ -741,7 +736,7 @@ void mlir::python::populateIRAffine(py::module &m) { MlirAffineMap affineMap = mlirAffineMapEmptyGet(context->get()); return PyAffineMap(context->getRef(), affineMap); }, - py::arg("context") = py::none(), "Gets an empty affine map.") + nb::arg("context").none() = nb::none(), "Gets an empty affine map.") .def_static( "get_identity", [](intptr_t nDims, DefaultingPyMlirContext context) { @@ -749,7 +744,7 @@ void mlir::python::populateIRAffine(py::module &m) { mlirAffineMapMultiDimIdentityGet(context->get(), nDims); return PyAffineMap(context->getRef(), affineMap); }, - py::arg("n_dims"), py::arg("context") = py::none(), + nb::arg("n_dims"), nb::arg("context").none() = nb::none(), "Gets an identity map with the given number of dimensions.") .def_static( "get_minor_identity", @@ -759,8 +754,8 @@ void 
mlir::python::populateIRAffine(py::module &m) { mlirAffineMapMinorIdentityGet(context->get(), nDims, nResults); return PyAffineMap(context->getRef(), affineMap); }, - py::arg("n_dims"), py::arg("n_results"), - py::arg("context") = py::none(), + nb::arg("n_dims"), nb::arg("n_results"), + nb::arg("context").none() = nb::none(), "Gets a minor identity map with the given number of dimensions and " "results.") .def_static( @@ -768,13 +763,13 @@ void mlir::python::populateIRAffine(py::module &m) { [](std::vector permutation, DefaultingPyMlirContext context) { if (!isPermutation(permutation)) - throw py::cast_error("Invalid permutation when attempting to " - "create an AffineMap"); + throw std::runtime_error("Invalid permutation when attempting to " + "create an AffineMap"); MlirAffineMap affineMap = mlirAffineMapPermutationGet( context->get(), permutation.size(), permutation.data()); return PyAffineMap(context->getRef(), affineMap); }, - py::arg("permutation"), py::arg("context") = py::none(), + nb::arg("permutation"), nb::arg("context").none() = nb::none(), "Gets an affine map that permutes its inputs.") .def( "get_submap", @@ -782,33 +777,33 @@ void mlir::python::populateIRAffine(py::module &m) { intptr_t numResults = mlirAffineMapGetNumResults(self); for (intptr_t pos : resultPos) { if (pos < 0 || pos >= numResults) - throw py::value_error("result position out of bounds"); + throw nb::value_error("result position out of bounds"); } MlirAffineMap affineMap = mlirAffineMapGetSubMap( self, resultPos.size(), resultPos.data()); return PyAffineMap(self.getContext(), affineMap); }, - py::arg("result_positions")) + nb::arg("result_positions")) .def( "get_major_submap", [](PyAffineMap &self, intptr_t nResults) { if (nResults >= mlirAffineMapGetNumResults(self)) - throw py::value_error("number of results out of bounds"); + throw nb::value_error("number of results out of bounds"); MlirAffineMap affineMap = mlirAffineMapGetMajorSubMap(self, nResults); return 
PyAffineMap(self.getContext(), affineMap); }, - py::arg("n_results")) + nb::arg("n_results")) .def( "get_minor_submap", [](PyAffineMap &self, intptr_t nResults) { if (nResults >= mlirAffineMapGetNumResults(self)) - throw py::value_error("number of results out of bounds"); + throw nb::value_error("number of results out of bounds"); MlirAffineMap affineMap = mlirAffineMapGetMinorSubMap(self, nResults); return PyAffineMap(self.getContext(), affineMap); }, - py::arg("n_results")) + nb::arg("n_results")) .def( "replace", [](PyAffineMap &self, PyAffineExpr &expression, @@ -818,39 +813,37 @@ void mlir::python::populateIRAffine(py::module &m) { self, expression, replacement, numResultDims, numResultSyms); return PyAffineMap(self.getContext(), affineMap); }, - py::arg("expr"), py::arg("replacement"), py::arg("n_result_dims"), - py::arg("n_result_syms")) - .def_property_readonly( + nb::arg("expr"), nb::arg("replacement"), nb::arg("n_result_dims"), + nb::arg("n_result_syms")) + .def_prop_ro( "is_permutation", [](PyAffineMap &self) { return mlirAffineMapIsPermutation(self); }) - .def_property_readonly("is_projected_permutation", - [](PyAffineMap &self) { - return mlirAffineMapIsProjectedPermutation(self); - }) - .def_property_readonly( + .def_prop_ro("is_projected_permutation", + [](PyAffineMap &self) { + return mlirAffineMapIsProjectedPermutation(self); + }) + .def_prop_ro( "n_dims", [](PyAffineMap &self) { return mlirAffineMapGetNumDims(self); }) - .def_property_readonly( + .def_prop_ro( "n_inputs", [](PyAffineMap &self) { return mlirAffineMapGetNumInputs(self); }) - .def_property_readonly( + .def_prop_ro( "n_symbols", [](PyAffineMap &self) { return mlirAffineMapGetNumSymbols(self); }) - .def_property_readonly("results", [](PyAffineMap &self) { - return PyAffineMapExprList(self); - }); + .def_prop_ro("results", + [](PyAffineMap &self) { return PyAffineMapExprList(self); }); PyAffineMapExprList::bind(m); 
//---------------------------------------------------------------------------- // Mapping of PyIntegerSet. //---------------------------------------------------------------------------- - py::class_(m, "IntegerSet", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyIntegerSet::getCapsule) + nb::class_(m, "IntegerSet") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyIntegerSet::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyIntegerSet::createFromCapsule) .def("__eq__", [](PyIntegerSet &self, PyIntegerSet &other) { return self == other; }) - .def("__eq__", [](PyIntegerSet &self, py::object other) { return false; }) + .def("__eq__", [](PyIntegerSet &self, nb::object other) { return false; }) .def("__str__", [](PyIntegerSet &self) { PyPrintAccumulator printAccum; @@ -871,7 +864,7 @@ void mlir::python::populateIRAffine(py::module &m) { [](PyIntegerSet &self) { return static_cast(llvm::hash_value(self.get().ptr)); }) - .def_property_readonly( + .def_prop_ro( "context", [](PyIntegerSet &self) { return self.getContext().getObject(); }) .def( @@ -879,14 +872,14 @@ void mlir::python::populateIRAffine(py::module &m) { kDumpDocstring) .def_static( "get", - [](intptr_t numDims, intptr_t numSymbols, py::list exprs, + [](intptr_t numDims, intptr_t numSymbols, nb::list exprs, std::vector eqFlags, DefaultingPyMlirContext context) { if (exprs.size() != eqFlags.size()) - throw py::value_error( + throw nb::value_error( "Expected the number of constraints to match " "that of equality flags"); - if (exprs.empty()) - throw py::value_error("Expected non-empty list of constraints"); + if (exprs.size() == 0) + throw nb::value_error("Expected non-empty list of constraints"); // Copy over to a SmallVector because std::vector has a // specialization for booleans that packs data and does not @@ -901,8 +894,8 @@ void mlir::python::populateIRAffine(py::module &m) { affineExprs.data(), flags.data()); return PyIntegerSet(context->getRef(), set); }, - py::arg("num_dims"), 
py::arg("num_symbols"), py::arg("exprs"), - py::arg("eq_flags"), py::arg("context") = py::none()) + nb::arg("num_dims"), nb::arg("num_symbols"), nb::arg("exprs"), + nb::arg("eq_flags"), nb::arg("context").none() = nb::none()) .def_static( "get_empty", [](intptr_t numDims, intptr_t numSymbols, @@ -911,20 +904,20 @@ void mlir::python::populateIRAffine(py::module &m) { mlirIntegerSetEmptyGet(context->get(), numDims, numSymbols); return PyIntegerSet(context->getRef(), set); }, - py::arg("num_dims"), py::arg("num_symbols"), - py::arg("context") = py::none()) + nb::arg("num_dims"), nb::arg("num_symbols"), + nb::arg("context").none() = nb::none()) .def( "get_replaced", - [](PyIntegerSet &self, py::list dimExprs, py::list symbolExprs, + [](PyIntegerSet &self, nb::list dimExprs, nb::list symbolExprs, intptr_t numResultDims, intptr_t numResultSymbols) { if (static_cast(dimExprs.size()) != mlirIntegerSetGetNumDims(self)) - throw py::value_error( + throw nb::value_error( "Expected the number of dimension replacement expressions " "to match that of dimensions"); if (static_cast(symbolExprs.size()) != mlirIntegerSetGetNumSymbols(self)) - throw py::value_error( + throw nb::value_error( "Expected the number of symbol replacement expressions " "to match that of symbols"); @@ -940,30 +933,30 @@ void mlir::python::populateIRAffine(py::module &m) { numResultDims, numResultSymbols); return PyIntegerSet(self.getContext(), set); }, - py::arg("dim_exprs"), py::arg("symbol_exprs"), - py::arg("num_result_dims"), py::arg("num_result_symbols")) - .def_property_readonly("is_canonical_empty", - [](PyIntegerSet &self) { - return mlirIntegerSetIsCanonicalEmpty(self); - }) - .def_property_readonly( + nb::arg("dim_exprs"), nb::arg("symbol_exprs"), + nb::arg("num_result_dims"), nb::arg("num_result_symbols")) + .def_prop_ro("is_canonical_empty", + [](PyIntegerSet &self) { + return mlirIntegerSetIsCanonicalEmpty(self); + }) + .def_prop_ro( "n_dims", [](PyIntegerSet &self) { return 
mlirIntegerSetGetNumDims(self); }) - .def_property_readonly( + .def_prop_ro( "n_symbols", [](PyIntegerSet &self) { return mlirIntegerSetGetNumSymbols(self); }) - .def_property_readonly( + .def_prop_ro( "n_inputs", [](PyIntegerSet &self) { return mlirIntegerSetGetNumInputs(self); }) - .def_property_readonly("n_equalities", - [](PyIntegerSet &self) { - return mlirIntegerSetGetNumEqualities(self); - }) - .def_property_readonly("n_inequalities", - [](PyIntegerSet &self) { - return mlirIntegerSetGetNumInequalities(self); - }) - .def_property_readonly("constraints", [](PyIntegerSet &self) { + .def_prop_ro("n_equalities", + [](PyIntegerSet &self) { + return mlirIntegerSetGetNumEqualities(self); + }) + .def_prop_ro("n_inequalities", + [](PyIntegerSet &self) { + return mlirIntegerSetGetNumInequalities(self); + }) + .def_prop_ro("constraints", [](PyIntegerSet &self) { return PyIntegerSetConstraintList(self); }); PyIntegerSetConstraint::bind(m); diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index cc9532f4e..779af0950 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -6,23 +6,29 @@ // //===----------------------------------------------------------------------===// +#include +#include +#include +#include +#include +#include + +#include #include +#include #include #include #include "IRModule.h" - -#include "PybindUtils.h" -#include - -#include "llvm/ADT/ScopeExit.h" -#include "llvm/Support/raw_ostream.h" - +#include "NanobindUtils.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/Support/raw_ostream.h" -namespace py = pybind11; +namespace nb = nanobind; +using namespace nanobind::literals; using namespace mlir; using namespace mlir::python; @@ -123,10 +129,119 @@ subsequent processing. 
namespace { +struct nb_buffer_info { + void *ptr = nullptr; + ssize_t itemsize = 0; + ssize_t size = 0; + const char *format = nullptr; + ssize_t ndim = 0; + SmallVector shape; + SmallVector strides; + bool readonly = false; + + nb_buffer_info( + void *ptr, ssize_t itemsize, const char *format, ssize_t ndim, + SmallVector shape_in, SmallVector strides_in, + bool readonly = false, + std::unique_ptr owned_view_in = + std::unique_ptr(nullptr, nullptr)) + : ptr(ptr), itemsize(itemsize), format(format), ndim(ndim), + shape(std::move(shape_in)), strides(std::move(strides_in)), + readonly(readonly), owned_view(std::move(owned_view_in)) { + size = 1; + for (ssize_t i = 0; i < ndim; ++i) { + size *= shape[i]; + } + } + + explicit nb_buffer_info(Py_buffer *view) + : nb_buffer_info(view->buf, view->itemsize, view->format, view->ndim, + {view->shape, view->shape + view->ndim}, + // TODO(phawkins): check for null strides + {view->strides, view->strides + view->ndim}, + view->readonly != 0, + std::unique_ptr( + view, PyBuffer_Release)) {} + + nb_buffer_info(const nb_buffer_info &) = delete; + nb_buffer_info(nb_buffer_info &&) = default; + nb_buffer_info &operator=(const nb_buffer_info &) = delete; + nb_buffer_info &operator=(nb_buffer_info &&) = default; + +private: + std::unique_ptr owned_view; +}; + +class nb_buffer : public nb::object { + NB_OBJECT_DEFAULT(nb_buffer, object, "buffer", PyObject_CheckBuffer); + + nb_buffer_info request() const { + int flags = PyBUF_STRIDES | PyBUF_FORMAT; + auto *view = new Py_buffer(); + if (PyObject_GetBuffer(ptr(), view, flags) != 0) { + delete view; + throw nb::python_error(); + } + return nb_buffer_info(view); + } +}; + +template +struct nb_format_descriptor {}; + +template <> +struct nb_format_descriptor { + static const char *format() { return "?"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "b"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "B"; } 
+}; +template <> +struct nb_format_descriptor { + static const char *format() { return "h"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "H"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "i"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "I"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "q"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "Q"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "f"; } +}; +template <> +struct nb_format_descriptor { + static const char *format() { return "d"; } +}; + static MlirStringRef toMlirStringRef(const std::string &s) { return mlirStringRefCreate(s.data(), s.size()); } +static MlirStringRef toMlirStringRef(const nb::bytes &s) { + return mlirStringRefCreate(static_cast(s.data()), s.size()); +} + class PyAffineMapAttribute : public PyConcreteAttribute { public: static constexpr IsAFunctionTy isaFunction = mlirAttributeIsAAffineMap; @@ -142,9 +257,9 @@ class PyAffineMapAttribute : public PyConcreteAttribute { MlirAttribute attr = mlirAffineMapAttrGet(affineMap.get()); return PyAffineMapAttribute(affineMap.getContext(), attr); }, - py::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); - c.def_property_readonly("value", mlirAffineMapAttrGetValue, - "Returns the value of the AffineMap attribute"); + nb::arg("affine_map"), "Gets an attribute wrapping an AffineMap."); + c.def_prop_ro("value", mlirAffineMapAttrGetValue, + "Returns the value of the AffineMap attribute"); } }; @@ -164,25 +279,24 @@ class PyIntegerSetAttribute MlirAttribute attr = mlirIntegerSetAttrGet(integerSet.get()); return PyIntegerSetAttribute(integerSet.getContext(), attr); }, - py::arg("integer_set"), "Gets an attribute wrapping an IntegerSet."); + nb::arg("integer_set"), "Gets an attribute wrapping 
an IntegerSet."); } }; template -static T pyTryCast(py::handle object) { +static T pyTryCast(nb::handle object) { try { - return object.cast(); - } catch (py::cast_error &err) { - std::string msg = - std::string( - "Invalid attribute when attempting to create an ArrayAttribute (") + - err.what() + ")"; - throw py::cast_error(msg); - } catch (py::reference_cast_error &err) { + return nb::cast(object); + } catch (nb::cast_error &err) { + std::string msg = std::string("Invalid attribute when attempting to " + "create an ArrayAttribute (") + + err.what() + ")"; + throw std::runtime_error(msg.c_str()); + } catch (std::runtime_error &err) { std::string msg = std::string("Invalid attribute (None?) when attempting " "to create an ArrayAttribute (") + err.what() + ")"; - throw py::cast_error(msg); + throw std::runtime_error(msg.c_str()); } } @@ -205,14 +319,13 @@ class PyDenseArrayAttribute : public PyConcreteAttribute { EltTy dunderNext() { // Throw if the index has reached the end. if (nextIndex >= mlirDenseArrayGetNumElements(attr.get())) - throw py::stop_iteration(); + throw nb::stop_iteration(); return DerivedT::getElement(attr.get(), nextIndex++); } /// Bind the iterator class. - static void bind(py::module &m) { - py::class_(m, DerivedT::pyIteratorName, - py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, DerivedT::pyIteratorName) .def("__iter__", &PyDenseArrayIterator::dunderIter) .def("__next__", &PyDenseArrayIterator::dunderNext); } @@ -230,17 +343,35 @@ class PyDenseArrayAttribute : public PyConcreteAttribute { /// Bind the attribute class. static void bindDerived(typename PyConcreteAttribute::ClassTy &c) { // Bind the constructor. 
- c.def_static( - "get", - [](const std::vector &values, DefaultingPyMlirContext ctx) { - return getAttribute(values, ctx->getRef()); - }, - py::arg("values"), py::arg("context") = py::none(), - "Gets a uniqued dense array attribute"); + if constexpr (std::is_same_v) { + c.def_static( + "get", + [](const nb::sequence &py_values, DefaultingPyMlirContext ctx) { + std::vector values; + for (nb::handle py_value : py_values) { + int is_true = PyObject_IsTrue(py_value.ptr()); + if (is_true < 0) { + throw nb::python_error(); + } + values.push_back(is_true); + } + return getAttribute(values, ctx->getRef()); + }, + nb::arg("values"), nb::arg("context").none() = nb::none(), + "Gets a uniqued dense array attribute"); + } else { + c.def_static( + "get", + [](const std::vector &values, DefaultingPyMlirContext ctx) { + return getAttribute(values, ctx->getRef()); + }, + nb::arg("values"), nb::arg("context").none() = nb::none(), + "Gets a uniqued dense array attribute"); + } // Bind the array methods. c.def("__getitem__", [](DerivedT &arr, intptr_t i) { if (i >= mlirDenseArrayGetNumElements(arr)) - throw py::index_error("DenseArray index out of range"); + throw nb::index_error("DenseArray index out of range"); return arr.getItem(i); }); c.def("__len__", [](const DerivedT &arr) { @@ -248,13 +379,13 @@ class PyDenseArrayAttribute : public PyConcreteAttribute { }); c.def("__iter__", [](const DerivedT &arr) { return PyDenseArrayIterator(arr); }); - c.def("__add__", [](DerivedT &arr, const py::list &extras) { + c.def("__add__", [](DerivedT &arr, const nb::list &extras) { std::vector values; intptr_t numOldElements = mlirDenseArrayGetNumElements(arr); - values.reserve(numOldElements + py::len(extras)); + values.reserve(numOldElements + nb::len(extras)); for (intptr_t i = 0; i < numOldElements; ++i) values.push_back(arr.getItem(i)); - for (py::handle attr : extras) + for (nb::handle attr : extras) values.push_back(pyTryCast(attr)); return getAttribute(values, arr.getContext()); }); @@ 
-358,13 +489,12 @@ class PyArrayAttribute : public PyConcreteAttribute { MlirAttribute dunderNext() { // TODO: Throw is an inefficient way to stop iteration. if (nextIndex >= mlirArrayAttrGetNumElements(attr.get())) - throw py::stop_iteration(); + throw nb::stop_iteration(); return mlirArrayAttrGetElement(attr.get(), nextIndex++); } - static void bind(py::module &m) { - py::class_(m, "ArrayAttributeIterator", - py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "ArrayAttributeIterator") .def("__iter__", &PyArrayAttributeIterator::dunderIter) .def("__next__", &PyArrayAttributeIterator::dunderNext); } @@ -381,9 +511,9 @@ class PyArrayAttribute : public PyConcreteAttribute { static void bindDerived(ClassTy &c) { c.def_static( "get", - [](py::list attributes, DefaultingPyMlirContext context) { + [](nb::list attributes, DefaultingPyMlirContext context) { SmallVector mlirAttributes; - mlirAttributes.reserve(py::len(attributes)); + mlirAttributes.reserve(nb::len(attributes)); for (auto attribute : attributes) { mlirAttributes.push_back(pyTryCast(attribute)); } @@ -391,12 +521,12 @@ class PyArrayAttribute : public PyConcreteAttribute { context->get(), mlirAttributes.size(), mlirAttributes.data()); return PyArrayAttribute(context->getRef(), attr); }, - py::arg("attributes"), py::arg("context") = py::none(), + nb::arg("attributes"), nb::arg("context").none() = nb::none(), "Gets a uniqued Array attribute"); c.def("__getitem__", [](PyArrayAttribute &arr, intptr_t i) { if (i >= mlirArrayAttrGetNumElements(arr)) - throw py::index_error("ArrayAttribute index out of range"); + throw nb::index_error("ArrayAttribute index out of range"); return arr.getItem(i); }) .def("__len__", @@ -406,13 +536,13 @@ class PyArrayAttribute : public PyConcreteAttribute { .def("__iter__", [](const PyArrayAttribute &arr) { return PyArrayAttributeIterator(arr); }); - c.def("__add__", [](PyArrayAttribute arr, py::list extras) { + c.def("__add__", [](PyArrayAttribute arr, nb::list 
extras) { std::vector attributes; intptr_t numOldElements = mlirArrayAttrGetNumElements(arr); - attributes.reserve(numOldElements + py::len(extras)); + attributes.reserve(numOldElements + nb::len(extras)); for (intptr_t i = 0; i < numOldElements; ++i) attributes.push_back(arr.getItem(i)); - for (py::handle attr : extras) + for (nb::handle attr : extras) attributes.push_back(pyTryCast(attr)); MlirAttribute arrayAttr = mlirArrayAttrGet( arr.getContext()->get(), attributes.size(), attributes.data()); @@ -440,7 +570,7 @@ class PyFloatAttribute : public PyConcreteAttribute { throw MLIRError("Invalid attribute", errors.take()); return PyFloatAttribute(type.getContext(), attr); }, - py::arg("type"), py::arg("value"), py::arg("loc") = py::none(), + nb::arg("type"), nb::arg("value"), nb::arg("loc").none() = nb::none(), "Gets an uniqued float point attribute associated to a type"); c.def_static( "get_f32", @@ -449,7 +579,7 @@ class PyFloatAttribute : public PyConcreteAttribute { context->get(), mlirF32TypeGet(context->get()), value); return PyFloatAttribute(context->getRef(), attr); }, - py::arg("value"), py::arg("context") = py::none(), + nb::arg("value"), nb::arg("context").none() = nb::none(), "Gets an uniqued float point attribute associated to a f32 type"); c.def_static( "get_f64", @@ -458,10 +588,10 @@ class PyFloatAttribute : public PyConcreteAttribute { context->get(), mlirF64TypeGet(context->get()), value); return PyFloatAttribute(context->getRef(), attr); }, - py::arg("value"), py::arg("context") = py::none(), + nb::arg("value"), nb::arg("context").none() = nb::none(), "Gets an uniqued float point attribute associated to a f64 type"); - c.def_property_readonly("value", mlirFloatAttrGetValueDouble, - "Returns the value of the float attribute"); + c.def_prop_ro("value", mlirFloatAttrGetValueDouble, + "Returns the value of the float attribute"); c.def("__float__", mlirFloatAttrGetValueDouble, "Converts the value of the float attribute to a Python float"); } @@ -481,20 
+611,20 @@ class PyIntegerAttribute : public PyConcreteAttribute { MlirAttribute attr = mlirIntegerAttrGet(type, value); return PyIntegerAttribute(type.getContext(), attr); }, - py::arg("type"), py::arg("value"), + nb::arg("type"), nb::arg("value"), "Gets an uniqued integer attribute associated to a type"); - c.def_property_readonly("value", toPyInt, - "Returns the value of the integer attribute"); + c.def_prop_ro("value", toPyInt, + "Returns the value of the integer attribute"); c.def("__int__", toPyInt, "Converts the value of the integer attribute to a Python int"); - c.def_property_readonly_static("static_typeid", - [](py::object & /*class*/) -> MlirTypeID { - return mlirIntegerAttrGetTypeID(); - }); + c.def_prop_ro_static("static_typeid", + [](nb::object & /*class*/) -> MlirTypeID { + return mlirIntegerAttrGetTypeID(); + }); } private: - static py::int_ toPyInt(PyIntegerAttribute &self) { + static int64_t toPyInt(PyIntegerAttribute &self) { MlirType type = mlirAttributeGetType(self); if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type)) return mlirIntegerAttrGetValueInt(self); @@ -518,10 +648,10 @@ class PyBoolAttribute : public PyConcreteAttribute { MlirAttribute attr = mlirBoolAttrGet(context->get(), value); return PyBoolAttribute(context->getRef(), attr); }, - py::arg("value"), py::arg("context") = py::none(), + nb::arg("value"), nb::arg("context").none() = nb::none(), "Gets an uniqued bool attribute"); - c.def_property_readonly("value", mlirBoolAttrGetValue, - "Returns the value of the bool attribute"); + c.def_prop_ro("value", mlirBoolAttrGetValue, + "Returns the value of the bool attribute"); c.def("__bool__", mlirBoolAttrGetValue, "Converts the value of the bool attribute to a Python bool"); } @@ -555,9 +685,9 @@ class PySymbolRefAttribute : public PyConcreteAttribute { DefaultingPyMlirContext context) { return PySymbolRefAttribute::fromList(symbols, context.resolve()); }, - py::arg("symbols"), py::arg("context") = py::none(), + 
nb::arg("symbols"), nb::arg("context").none() = nb::none(), "Gets a uniqued SymbolRef attribute from a list of symbol names"); - c.def_property_readonly( + c.def_prop_ro( "value", [](PySymbolRefAttribute &self) { std::vector symbols = { @@ -589,13 +719,13 @@ class PyFlatSymbolRefAttribute mlirFlatSymbolRefAttrGet(context->get(), toMlirStringRef(value)); return PyFlatSymbolRefAttribute(context->getRef(), attr); }, - py::arg("value"), py::arg("context") = py::none(), + nb::arg("value"), nb::arg("context").none() = nb::none(), "Gets a uniqued FlatSymbolRef attribute"); - c.def_property_readonly( + c.def_prop_ro( "value", [](PyFlatSymbolRefAttribute &self) { MlirStringRef stringRef = mlirFlatSymbolRefAttrGetValue(self); - return py::str(stringRef.data, stringRef.length); + return nb::str(stringRef.data, stringRef.length); }, "Returns the value of the FlatSymbolRef attribute as a string"); } @@ -612,29 +742,29 @@ class PyOpaqueAttribute : public PyConcreteAttribute { static void bindDerived(ClassTy &c) { c.def_static( "get", - [](std::string dialectNamespace, py::buffer buffer, PyType &type, + [](std::string dialectNamespace, nb_buffer buffer, PyType &type, DefaultingPyMlirContext context) { - const py::buffer_info bufferInfo = buffer.request(); + const nb_buffer_info bufferInfo = buffer.request(); intptr_t bufferSize = bufferInfo.size; MlirAttribute attr = mlirOpaqueAttrGet( context->get(), toMlirStringRef(dialectNamespace), bufferSize, static_cast(bufferInfo.ptr), type); return PyOpaqueAttribute(context->getRef(), attr); }, - py::arg("dialect_namespace"), py::arg("buffer"), py::arg("type"), - py::arg("context") = py::none(), "Gets an Opaque attribute."); - c.def_property_readonly( + nb::arg("dialect_namespace"), nb::arg("buffer"), nb::arg("type"), + nb::arg("context").none() = nb::none(), "Gets an Opaque attribute."); + c.def_prop_ro( "dialect_namespace", [](PyOpaqueAttribute &self) { MlirStringRef stringRef = mlirOpaqueAttrGetDialectNamespace(self); - return 
py::str(stringRef.data, stringRef.length); + return nb::str(stringRef.data, stringRef.length); }, "Returns the dialect namespace for the Opaque attribute as a string"); - c.def_property_readonly( + c.def_prop_ro( "data", [](PyOpaqueAttribute &self) { MlirStringRef stringRef = mlirOpaqueAttrGetData(self); - return py::bytes(stringRef.data, stringRef.length); + return nb::bytes(stringRef.data, stringRef.length); }, "Returns the data for the Opaqued attributes as `bytes`"); } @@ -656,7 +786,16 @@ class PyStringAttribute : public PyConcreteAttribute { mlirStringAttrGet(context->get(), toMlirStringRef(value)); return PyStringAttribute(context->getRef(), attr); }, - py::arg("value"), py::arg("context") = py::none(), + nb::arg("value"), nb::arg("context").none() = nb::none(), + "Gets a uniqued string attribute"); + c.def_static( + "get", + [](nb::bytes value, DefaultingPyMlirContext context) { + MlirAttribute attr = + mlirStringAttrGet(context->get(), toMlirStringRef(value)); + return PyStringAttribute(context->getRef(), attr); + }, + nb::arg("value"), nb::arg("context").none() = nb::none(), "Gets a uniqued string attribute"); c.def_static( "get_typed", @@ -665,20 +804,20 @@ class PyStringAttribute : public PyConcreteAttribute { mlirStringAttrTypedGet(type, toMlirStringRef(value)); return PyStringAttribute(type.getContext(), attr); }, - py::arg("type"), py::arg("value"), + nb::arg("type"), nb::arg("value"), "Gets a uniqued string attribute associated to a type"); - c.def_property_readonly( + c.def_prop_ro( "value", [](PyStringAttribute &self) { MlirStringRef stringRef = mlirStringAttrGetValue(self); - return py::str(stringRef.data, stringRef.length); + return nb::str(stringRef.data, stringRef.length); }, "Returns the value of the string attribute"); - c.def_property_readonly( + c.def_prop_ro( "value_bytes", [](PyStringAttribute &self) { MlirStringRef stringRef = mlirStringAttrGetValue(self); - return py::bytes(stringRef.data, stringRef.length); + return 
nb::bytes(stringRef.data, stringRef.length); }, "Returns the value of the string attribute as `bytes`"); } @@ -693,12 +832,11 @@ class PyDenseElementsAttribute using PyConcreteAttribute::PyConcreteAttribute; static PyDenseElementsAttribute - getFromList(py::list attributes, std::optional explicitType, + getFromList(nb::list attributes, std::optional explicitType, DefaultingPyMlirContext contextWrapper) { - - const size_t numAttributes = py::len(attributes); + const size_t numAttributes = nb::len(attributes); if (numAttributes == 0) - throw py::value_error("Attributes list must be non-empty."); + throw nb::value_error("Attributes list must be non-empty."); MlirType shapedType; if (explicitType) { @@ -708,8 +846,8 @@ class PyDenseElementsAttribute std::string message; llvm::raw_string_ostream os(message); os << "Expected a static ShapedType for the shaped_type parameter: " - << py::repr(py::cast(*explicitType)); - throw py::value_error(message); + << nb::cast(nb::repr(nb::cast(*explicitType))); + throw nb::value_error(message.c_str()); } shapedType = *explicitType; } else { @@ -722,7 +860,7 @@ class PyDenseElementsAttribute SmallVector mlirAttributes; mlirAttributes.reserve(numAttributes); - for (const py::handle &attribute : attributes) { + for (const nb::handle &attribute : attributes) { MlirAttribute mlirAttribute = pyTryCast(attribute); MlirType attrType = mlirAttributeGetType(mlirAttribute); mlirAttributes.push_back(mlirAttribute); @@ -731,9 +869,11 @@ class PyDenseElementsAttribute std::string message; llvm::raw_string_ostream os(message); os << "All attributes must be of the same type and match " - << "the type parameter: expected=" << py::repr(py::cast(shapedType)) - << ", but got=" << py::repr(py::cast(attrType)); - throw py::value_error(message); + << "the type parameter: expected=" + << nb::cast(nb::repr(nb::cast(shapedType))) + << ", but got=" + << nb::cast(nb::repr(nb::cast(attrType))); + throw nb::value_error(message.c_str()); } } @@ -744,7 +884,7 @@ 
class PyDenseElementsAttribute } static PyDenseElementsAttribute - getFromBuffer(py::buffer array, bool signless, + getFromBuffer(nb_buffer array, bool signless, std::optional explicitType, std::optional> explicitShape, DefaultingPyMlirContext contextWrapper) { @@ -755,7 +895,7 @@ class PyDenseElementsAttribute } Py_buffer view; if (PyObject_GetBuffer(array.ptr(), &view, flags) != 0) { - throw py::error_already_set(); + throw nb::python_error(); } auto freeBuffer = llvm::make_scope_exit([&]() { PyBuffer_Release(&view); }); @@ -778,25 +918,25 @@ class PyDenseElementsAttribute if (!mlirAttributeIsAInteger(elementAttr) && !mlirAttributeIsAFloat(elementAttr)) { std::string message = "Illegal element type for DenseElementsAttr: "; - message.append(py::repr(py::cast(elementAttr))); - throw py::value_error(message); + message.append(nb::cast(nb::repr(nb::cast(elementAttr)))); + throw nb::value_error(message.c_str()); } if (!mlirTypeIsAShaped(shapedType) || !mlirShapedTypeHasStaticShape(shapedType)) { std::string message = "Expected a static ShapedType for the shaped_type parameter: "; - message.append(py::repr(py::cast(shapedType))); - throw py::value_error(message); + message.append(nb::cast(nb::repr(nb::cast(shapedType)))); + throw nb::value_error(message.c_str()); } MlirType shapedElementType = mlirShapedTypeGetElementType(shapedType); MlirType attrType = mlirAttributeGetType(elementAttr); if (!mlirTypeEqual(shapedElementType, attrType)) { std::string message = "Shaped element type and attribute type must be equal: shaped="; - message.append(py::repr(py::cast(shapedType))); + message.append(nb::cast(nb::repr(nb::cast(shapedType)))); message.append(", element="); - message.append(py::repr(py::cast(elementAttr))); - throw py::value_error(message); + message.append(nb::cast(nb::repr(nb::cast(elementAttr)))); + throw nb::value_error(message.c_str()); } MlirAttribute elements = @@ -806,7 +946,7 @@ class PyDenseElementsAttribute intptr_t dunderLen() { return 
mlirElementsAttrGetNumElements(*this); } - py::buffer_info accessBuffer() { + std::unique_ptr accessBuffer() { MlirType shapedType = mlirAttributeGetType(*this); MlirType elementType = mlirShapedTypeGetElementType(shapedType); std::string format; @@ -887,34 +1027,44 @@ class PyDenseElementsAttribute } static void bindDerived(ClassTy &c) { +#if PY_VERSION_HEX < 0x03090000 + PyTypeObject *tp = reinterpret_cast(c.ptr()); + tp->tp_as_buffer->bf_getbuffer = PyDenseElementsAttribute::bf_getbuffer; + tp->tp_as_buffer->bf_releasebuffer = + PyDenseElementsAttribute::bf_releasebuffer; +#endif c.def("__len__", &PyDenseElementsAttribute::dunderLen) .def_static("get", PyDenseElementsAttribute::getFromBuffer, - py::arg("array"), py::arg("signless") = true, - py::arg("type") = py::none(), py::arg("shape") = py::none(), - py::arg("context") = py::none(), + nb::arg("array"), nb::arg("signless") = true, + nb::arg("type").none() = nb::none(), + nb::arg("shape").none() = nb::none(), + nb::arg("context").none() = nb::none(), kDenseElementsAttrGetDocstring) .def_static("get", PyDenseElementsAttribute::getFromList, - py::arg("attrs"), py::arg("type") = py::none(), - py::arg("context") = py::none(), + nb::arg("attrs"), nb::arg("type").none() = nb::none(), + nb::arg("context").none() = nb::none(), kDenseElementsAttrGetFromListDocstring) .def_static("get_splat", PyDenseElementsAttribute::getSplat, - py::arg("shaped_type"), py::arg("element_attr"), + nb::arg("shaped_type"), nb::arg("element_attr"), "Gets a DenseElementsAttr where all values are the same") - .def_property_readonly("is_splat", - [](PyDenseElementsAttribute &self) -> bool { - return mlirDenseElementsAttrIsSplat(self); - }) - .def("get_splat_value", - [](PyDenseElementsAttribute &self) { - if (!mlirDenseElementsAttrIsSplat(self)) - throw py::value_error( - "get_splat_value called on a non-splat attribute"); - return mlirDenseElementsAttrGetSplatValue(self); - }) - .def_buffer(&PyDenseElementsAttribute::accessBuffer); + 
.def_prop_ro("is_splat", + [](PyDenseElementsAttribute &self) -> bool { + return mlirDenseElementsAttrIsSplat(self); + }) + .def("get_splat_value", [](PyDenseElementsAttribute &self) { + if (!mlirDenseElementsAttrIsSplat(self)) + throw nb::value_error( + "get_splat_value called on a non-splat attribute"); + return mlirDenseElementsAttrGetSplatValue(self); + }); } + static PyType_Slot slots[]; + private: + static int bf_getbuffer(PyObject *exporter, Py_buffer *view, int flags); + static void bf_releasebuffer(PyObject *, Py_buffer *buffer); + static bool isUnsignedIntegerFormat(std::string_view format) { if (format.empty()) return false; @@ -1039,27 +1189,27 @@ class PyDenseElementsAttribute return mlirDenseElementsAttrRawBufferGet(type, view.len, view.buf); } - // There is a complication for boolean numpy arrays, as numpy represents them - // as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 booleans - // per byte. + // There is a complication for boolean numpy arrays, as numpy represents + // them as 8 bits (1 byte) per boolean, whereas MLIR bitpacks them into 8 + // booleans per byte. 
static MlirAttribute getBitpackedAttributeFromBooleanBuffer( Py_buffer &view, std::optional> explicitShape, MlirContext &context) { if (llvm::endianness::native != llvm::endianness::little) { - // Given we have no good way of testing the behavior on big-endian systems - // we will throw - throw py::type_error("Constructing a bit-packed MLIR attribute is " + // Given we have no good way of testing the behavior on big-endian + // systems we will throw + throw nb::type_error("Constructing a bit-packed MLIR attribute is " "unsupported on big-endian systems"); } + nb::ndarray, nb::c_contig> unpackedArray( + /*data=*/static_cast(view.buf), + /*shape=*/{static_cast(view.len)}); - py::array_t unpackedArray(view.len, - static_cast(view.buf)); - - py::module numpy = py::module::import("numpy"); - py::object packbitsFunc = numpy.attr("packbits"); - py::object packedBooleans = - packbitsFunc(unpackedArray, "bitorder"_a = "little"); - py::buffer_info pythonBuffer = packedBooleans.cast().request(); + nb::module_ numpy = nb::module_::import_("numpy"); + nb::object packbitsFunc = numpy.attr("packbits"); + nb::object packedBooleans = + packbitsFunc(nb::cast(unpackedArray), "bitorder"_a = "little"); + nb_buffer_info pythonBuffer = nb::cast(packedBooleans).request(); MlirType bitpackedType = getShapedType(mlirIntegerTypeGet(context, 1), explicitShape, view); @@ -1073,11 +1223,11 @@ class PyDenseElementsAttribute // This does the opposite transformation of // `getBitpackedAttributeFromBooleanBuffer` - py::buffer_info getBooleanBufferFromBitpackedAttribute() { + std::unique_ptr getBooleanBufferFromBitpackedAttribute() { if (llvm::endianness::native != llvm::endianness::little) { - // Given we have no good way of testing the behavior on big-endian systems - // we will throw - throw py::type_error("Constructing a numpy array from a MLIR attribute " + // Given we have no good way of testing the behavior on big-endian + // systems we will throw + throw nb::type_error("Constructing a numpy 
array from a MLIR attribute " "is unsupported on big-endian systems"); } @@ -1085,21 +1235,24 @@ class PyDenseElementsAttribute int64_t numBitpackedBytes = llvm::divideCeil(numBooleans, 8); uint8_t *bitpackedData = static_cast( const_cast(mlirDenseElementsAttrGetRawData(*this))); - py::array_t packedArray(numBitpackedBytes, bitpackedData); + nb::ndarray, nb::c_contig> packedArray( + /*data=*/bitpackedData, + /*shape=*/{static_cast(numBitpackedBytes)}); - py::module numpy = py::module::import("numpy"); - py::object unpackbitsFunc = numpy.attr("unpackbits"); - py::object equalFunc = numpy.attr("equal"); - py::object reshapeFunc = numpy.attr("reshape"); - py::array unpackedBooleans = - unpackbitsFunc(packedArray, "bitorder"_a = "little"); + nb::module_ numpy = nb::module_::import_("numpy"); + nb::object unpackbitsFunc = numpy.attr("unpackbits"); + nb::object equalFunc = numpy.attr("equal"); + nb::object reshapeFunc = numpy.attr("reshape"); + nb::object unpackedBooleans = + unpackbitsFunc(nb::cast(packedArray), "bitorder"_a = "little"); // Unpackbits operates on bytes and gives back a flat 0 / 1 integer array. // We need to: // 1. Slice away the padded bits // 2. Make the boolean array have the correct shape // 3. 
Convert the array to a boolean array - unpackedBooleans = unpackedBooleans[py::slice(0, numBooleans, 1)]; + unpackedBooleans = unpackedBooleans[nb::slice( + nb::int_(0), nb::int_(numBooleans), nb::int_(1))]; unpackedBooleans = equalFunc(unpackedBooleans, 1); MlirType shapedType = mlirAttributeGetType(*this); @@ -1110,15 +1263,15 @@ class PyDenseElementsAttribute } unpackedBooleans = reshapeFunc(unpackedBooleans, shape); - // Make sure the returned py::buffer_view claims ownership of the data in + // Make sure the returned nb::buffer_view claims ownership of the data in // `pythonBuffer` so it remains valid when Python reads it - py::buffer pythonBuffer = unpackedBooleans.cast(); - return pythonBuffer.request(); + nb_buffer pythonBuffer = nb::cast(unpackedBooleans); + return std::make_unique(pythonBuffer.request()); } template - py::buffer_info bufferInfo(MlirType shapedType, - const char *explicitFormat = nullptr) { + std::unique_ptr + bufferInfo(MlirType shapedType, const char *explicitFormat = nullptr) { intptr_t rank = mlirShapedTypeGetRank(shapedType); // Prepare the data for the buffer_info. // Buffer is configured for read-only access below. @@ -1142,19 +1295,72 @@ class PyDenseElementsAttribute } strides.push_back(sizeof(Type)); } - std::string format; + const char *format; if (explicitFormat) { format = explicitFormat; } else { - format = py::format_descriptor::format(); + format = nb_format_descriptor::format(); } - return py::buffer_info(data, sizeof(Type), format, rank, shape, strides, - /*readonly=*/true); + return std::make_unique( + data, sizeof(Type), format, rank, std::move(shape), std::move(strides), + /*readonly=*/true); } }; // namespace -/// Refinement of the PyDenseElementsAttribute for attributes containing integer -/// (and boolean) values. Supports element access. +PyType_Slot PyDenseElementsAttribute::slots[] = { +// Python 3.8 doesn't allow setting the buffer protocol slots from a type spec. 
+#if PY_VERSION_HEX >= 0x03090000 + {Py_bf_getbuffer, + reinterpret_cast(PyDenseElementsAttribute::bf_getbuffer)}, + {Py_bf_releasebuffer, + reinterpret_cast(PyDenseElementsAttribute::bf_releasebuffer)}, +#endif + {0, nullptr}, +}; + +/*static*/ int PyDenseElementsAttribute::bf_getbuffer(PyObject *obj, + Py_buffer *view, + int flags) { + view->obj = nullptr; + std::unique_ptr info; + try { + auto *attr = nb::cast(nb::handle(obj)); + info = attr->accessBuffer(); + } catch (nb::python_error &e) { + e.restore(); + nb::chain_error(PyExc_BufferError, "Error converting attribute to buffer"); + return -1; + } + view->obj = obj; + view->ndim = 1; + view->buf = info->ptr; + view->itemsize = info->itemsize; + view->len = info->itemsize; + for (auto s : info->shape) { + view->len *= s; + } + view->readonly = info->readonly; + if ((flags & PyBUF_FORMAT) == PyBUF_FORMAT) { + view->format = const_cast(info->format); + } + if ((flags & PyBUF_STRIDES) == PyBUF_STRIDES) { + view->ndim = static_cast(info->ndim); + view->strides = info->strides.data(); + view->shape = info->shape.data(); + } + view->suboffsets = nullptr; + view->internal = info.release(); + Py_INCREF(obj); + return 0; +} + +/*static*/ void PyDenseElementsAttribute::bf_releasebuffer(PyObject *, + Py_buffer *view) { + delete reinterpret_cast(view->internal); +} + +/// Refinement of the PyDenseElementsAttribute for attributes containing +/// integer (and boolean) values. Supports element access. class PyDenseIntElementsAttribute : public PyConcreteAttribute { @@ -1163,11 +1369,11 @@ class PyDenseIntElementsAttribute static constexpr const char *pyClassName = "DenseIntElementsAttr"; using PyConcreteAttribute::PyConcreteAttribute; - /// Returns the element at the given linear position. Asserts if the index is - /// out of range. - py::int_ dunderGetItem(intptr_t pos) { + /// Returns the element at the given linear position. Asserts if the index + /// is out of range. 
+ nb::object dunderGetItem(intptr_t pos) { if (pos < 0 || pos >= dunderLen()) { - throw py::index_error("attempt to access out of bounds element"); + throw nb::index_error("attempt to access out of bounds element"); } MlirType type = mlirAttributeGetType(*this); @@ -1175,7 +1381,7 @@ class PyDenseIntElementsAttribute assert(mlirTypeIsAInteger(type) && "expected integer element type in dense int elements attribute"); // Dispatch element extraction to an appropriate C function based on the - // elemental type of the attribute. py::int_ is implicitly constructible + // elemental type of the attribute. nb::int_ is implicitly constructible // from any C++ integral type and handles bitwidth correctly. // TODO: consider caching the type properties in the constructor to avoid // querying them on each element access. @@ -1183,38 +1389,38 @@ class PyDenseIntElementsAttribute bool isUnsigned = mlirIntegerTypeIsUnsigned(type); if (isUnsigned) { if (width == 1) { - return mlirDenseElementsAttrGetBoolValue(*this, pos); + return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); } if (width == 8) { - return mlirDenseElementsAttrGetUInt8Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetUInt8Value(*this, pos)); } if (width == 16) { - return mlirDenseElementsAttrGetUInt16Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetUInt16Value(*this, pos)); } if (width == 32) { - return mlirDenseElementsAttrGetUInt32Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetUInt32Value(*this, pos)); } if (width == 64) { - return mlirDenseElementsAttrGetUInt64Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetUInt64Value(*this, pos)); } } else { if (width == 1) { - return mlirDenseElementsAttrGetBoolValue(*this, pos); + return nb::int_(int(mlirDenseElementsAttrGetBoolValue(*this, pos))); } if (width == 8) { - return mlirDenseElementsAttrGetInt8Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetInt8Value(*this, pos)); } if (width == 16) { - 
return mlirDenseElementsAttrGetInt16Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetInt16Value(*this, pos)); } if (width == 32) { - return mlirDenseElementsAttrGetInt32Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetInt32Value(*this, pos)); } if (width == 64) { - return mlirDenseElementsAttrGetInt64Value(*this, pos); + return nb::int_(mlirDenseElementsAttrGetInt64Value(*this, pos)); } } - throw py::type_error("Unsupported integer type"); + throw nb::type_error("Unsupported integer type"); } static void bindDerived(ClassTy &c) { @@ -1231,7 +1437,7 @@ class PyDenseResourceElementsAttribute using PyConcreteAttribute::PyConcreteAttribute; static PyDenseResourceElementsAttribute - getFromBuffer(py::buffer buffer, const std::string &name, const PyType &type, + getFromBuffer(nb_buffer buffer, const std::string &name, const PyType &type, std::optional alignment, bool isMutable, DefaultingPyMlirContext contextWrapper) { if (!mlirTypeIsAShaped(type)) { @@ -1244,7 +1450,7 @@ class PyDenseResourceElementsAttribute int flags = PyBUF_STRIDES; std::unique_ptr view = std::make_unique(); if (PyObject_GetBuffer(buffer.ptr(), view.get(), flags) != 0) { - throw py::error_already_set(); + throw nb::python_error(); } // This scope releaser will only release if we haven't yet transferred @@ -1289,12 +1495,12 @@ class PyDenseResourceElementsAttribute } static void bindDerived(ClassTy &c) { - c.def_static("get_from_buffer", - PyDenseResourceElementsAttribute::getFromBuffer, - py::arg("array"), py::arg("name"), py::arg("type"), - py::arg("alignment") = py::none(), - py::arg("is_mutable") = false, py::arg("context") = py::none(), - kDenseResourceElementsAttrGetFromBufferDocstring); + c.def_static( + "get_from_buffer", PyDenseResourceElementsAttribute::getFromBuffer, + nb::arg("array"), nb::arg("name"), nb::arg("type"), + nb::arg("alignment").none() = nb::none(), nb::arg("is_mutable") = false, + nb::arg("context").none() = nb::none(), + 
kDenseResourceElementsAttrGetFromBufferDocstring); } }; @@ -1318,12 +1524,12 @@ class PyDictAttribute : public PyConcreteAttribute { c.def("__len__", &PyDictAttribute::dunderLen); c.def_static( "get", - [](py::dict attributes, DefaultingPyMlirContext context) { + [](nb::dict attributes, DefaultingPyMlirContext context) { SmallVector mlirNamedAttributes; mlirNamedAttributes.reserve(attributes.size()); - for (auto &it : attributes) { - auto &mlirAttr = it.second.cast(); - auto name = it.first.cast(); + for (std::pair it : attributes) { + auto &mlirAttr = nb::cast(it.second); + auto name = nb::cast(it.first); mlirNamedAttributes.push_back(mlirNamedAttributeGet( mlirIdentifierGet(mlirAttributeGetContext(mlirAttr), toMlirStringRef(name)), @@ -1334,18 +1540,18 @@ class PyDictAttribute : public PyConcreteAttribute { mlirNamedAttributes.data()); return PyDictAttribute(context->getRef(), attr); }, - py::arg("value") = py::dict(), py::arg("context") = py::none(), + nb::arg("value") = nb::dict(), nb::arg("context").none() = nb::none(), "Gets an uniqued dict attribute"); c.def("__getitem__", [](PyDictAttribute &self, const std::string &name) { MlirAttribute attr = mlirDictionaryAttrGetElementByName(self, toMlirStringRef(name)); if (mlirAttributeIsNull(attr)) - throw py::key_error("attempt to access a non-existent attribute"); + throw nb::key_error("attempt to access a non-existent attribute"); return attr; }); c.def("__getitem__", [](PyDictAttribute &self, intptr_t index) { if (index < 0 || index >= self.dunderLen()) { - throw py::index_error("attempt to access out of bounds attribute"); + throw nb::index_error("attempt to access out of bounds attribute"); } MlirNamedAttribute namedAttr = mlirDictionaryAttrGetElement(self, index); return PyNamedAttribute( @@ -1365,25 +1571,25 @@ class PyDenseFPElementsAttribute static constexpr const char *pyClassName = "DenseFPElementsAttr"; using PyConcreteAttribute::PyConcreteAttribute; - py::float_ dunderGetItem(intptr_t pos) { + 
nb::float_ dunderGetItem(intptr_t pos) { if (pos < 0 || pos >= dunderLen()) { - throw py::index_error("attempt to access out of bounds element"); + throw nb::index_error("attempt to access out of bounds element"); } MlirType type = mlirAttributeGetType(*this); type = mlirShapedTypeGetElementType(type); // Dispatch element extraction to an appropriate C function based on the - // elemental type of the attribute. py::float_ is implicitly constructible + // elemental type of the attribute. nb::float_ is implicitly constructible // from float and double. // TODO: consider caching the type properties in the constructor to avoid // querying them on each element access. if (mlirTypeIsAF32(type)) { - return mlirDenseElementsAttrGetFloatValue(*this, pos); + return nb::float_(mlirDenseElementsAttrGetFloatValue(*this, pos)); } if (mlirTypeIsAF64(type)) { - return mlirDenseElementsAttrGetDoubleValue(*this, pos); + return nb::float_(mlirDenseElementsAttrGetDoubleValue(*this, pos)); } - throw py::type_error("Unsupported floating-point type"); + throw nb::type_error("Unsupported floating-point type"); } static void bindDerived(ClassTy &c) { @@ -1406,9 +1612,9 @@ class PyTypeAttribute : public PyConcreteAttribute { MlirAttribute attr = mlirTypeAttrGet(value.get()); return PyTypeAttribute(context->getRef(), attr); }, - py::arg("value"), py::arg("context") = py::none(), + nb::arg("value"), nb::arg("context").none() = nb::none(), "Gets a uniqued Type attribute"); - c.def_property_readonly("value", [](PyTypeAttribute &self) { + c.def_prop_ro("value", [](PyTypeAttribute &self) { return mlirTypeAttrGetValue(self.get()); }); } @@ -1430,7 +1636,7 @@ class PyUnitAttribute : public PyConcreteAttribute { return PyUnitAttribute(context->getRef(), mlirUnitAttrGet(context->get())); }, - py::arg("context") = py::none(), "Create a Unit attribute."); + nb::arg("context").none() = nb::none(), "Create a Unit attribute."); } }; @@ -1453,7 +1659,8 @@ class PyStridedLayoutAttribute ctx->get(), offset, 
strides.size(), strides.data()); return PyStridedLayoutAttribute(ctx->getRef(), attr); }, - py::arg("offset"), py::arg("strides"), py::arg("context") = py::none(), + nb::arg("offset"), nb::arg("strides"), + nb::arg("context").none() = nb::none(), "Gets a strided layout attribute."); c.def_static( "get_fully_dynamic", @@ -1465,16 +1672,17 @@ class PyStridedLayoutAttribute ctx->get(), dynamic, strides.size(), strides.data()); return PyStridedLayoutAttribute(ctx->getRef(), attr); }, - py::arg("rank"), py::arg("context") = py::none(), - "Gets a strided layout attribute with dynamic offset and strides of a " + nb::arg("rank"), nb::arg("context").none() = nb::none(), + "Gets a strided layout attribute with dynamic offset and strides of " + "a " "given rank."); - c.def_property_readonly( + c.def_prop_ro( "offset", [](PyStridedLayoutAttribute &self) { return mlirStridedLayoutAttrGetOffset(self); }, "Returns the value of the float point attribute"); - c.def_property_readonly( + c.def_prop_ro( "strides", [](PyStridedLayoutAttribute &self) { intptr_t size = mlirStridedLayoutAttrGetNumStrides(self); @@ -1488,63 +1696,64 @@ class PyStridedLayoutAttribute } }; -py::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { +nb::object denseArrayAttributeCaster(PyAttribute &pyAttribute) { if (PyDenseBoolArrayAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseBoolArrayAttribute(pyAttribute)); + return nb::cast(PyDenseBoolArrayAttribute(pyAttribute)); if (PyDenseI8ArrayAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseI8ArrayAttribute(pyAttribute)); + return nb::cast(PyDenseI8ArrayAttribute(pyAttribute)); if (PyDenseI16ArrayAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseI16ArrayAttribute(pyAttribute)); + return nb::cast(PyDenseI16ArrayAttribute(pyAttribute)); if (PyDenseI32ArrayAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseI32ArrayAttribute(pyAttribute)); + return nb::cast(PyDenseI32ArrayAttribute(pyAttribute)); if 
(PyDenseI64ArrayAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseI64ArrayAttribute(pyAttribute)); + return nb::cast(PyDenseI64ArrayAttribute(pyAttribute)); if (PyDenseF32ArrayAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseF32ArrayAttribute(pyAttribute)); + return nb::cast(PyDenseF32ArrayAttribute(pyAttribute)); if (PyDenseF64ArrayAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseF64ArrayAttribute(pyAttribute)); + return nb::cast(PyDenseF64ArrayAttribute(pyAttribute)); std::string msg = std::string("Can't cast unknown element type DenseArrayAttr (") + - std::string(py::repr(py::cast(pyAttribute))) + ")"; - throw py::cast_error(msg); + nb::cast(nb::repr(nb::cast(pyAttribute))) + ")"; + throw nb::type_error(msg.c_str()); } -py::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { +nb::object denseIntOrFPElementsAttributeCaster(PyAttribute &pyAttribute) { if (PyDenseFPElementsAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseFPElementsAttribute(pyAttribute)); + return nb::cast(PyDenseFPElementsAttribute(pyAttribute)); if (PyDenseIntElementsAttribute::isaFunction(pyAttribute)) - return py::cast(PyDenseIntElementsAttribute(pyAttribute)); + return nb::cast(PyDenseIntElementsAttribute(pyAttribute)); std::string msg = std::string( "Can't cast unknown element type DenseIntOrFPElementsAttr (") + - std::string(py::repr(py::cast(pyAttribute))) + ")"; - throw py::cast_error(msg); + nb::cast(nb::repr(nb::cast(pyAttribute))) + ")"; + throw nb::type_error(msg.c_str()); } -py::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { +nb::object integerOrBoolAttributeCaster(PyAttribute &pyAttribute) { if (PyBoolAttribute::isaFunction(pyAttribute)) - return py::cast(PyBoolAttribute(pyAttribute)); + return nb::cast(PyBoolAttribute(pyAttribute)); if (PyIntegerAttribute::isaFunction(pyAttribute)) - return py::cast(PyIntegerAttribute(pyAttribute)); + return nb::cast(PyIntegerAttribute(pyAttribute)); std::string 
msg = std::string("Can't cast unknown element type DenseArrayAttr (") + - std::string(py::repr(py::cast(pyAttribute))) + ")"; - throw py::cast_error(msg); + nb::cast(nb::repr(nb::cast(pyAttribute))) + ")"; + throw nb::type_error(msg.c_str()); } -py::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { +nb::object symbolRefOrFlatSymbolRefAttributeCaster(PyAttribute &pyAttribute) { if (PyFlatSymbolRefAttribute::isaFunction(pyAttribute)) - return py::cast(PyFlatSymbolRefAttribute(pyAttribute)); + return nb::cast(PyFlatSymbolRefAttribute(pyAttribute)); if (PySymbolRefAttribute::isaFunction(pyAttribute)) - return py::cast(PySymbolRefAttribute(pyAttribute)); + return nb::cast(PySymbolRefAttribute(pyAttribute)); std::string msg = std::string("Can't cast unknown SymbolRef attribute (") + - std::string(py::repr(py::cast(pyAttribute))) + ")"; - throw py::cast_error(msg); + nb::cast(nb::repr(nb::cast(pyAttribute))) + + ")"; + throw nb::type_error(msg.c_str()); } } // namespace -void mlir::python::populateIRAttributes(py::module &m) { +void mlir::python::populateIRAttributes(nb::module_ &m) { PyAffineMapAttribute::bind(m); PyDenseBoolArrayAttribute::bind(m); PyDenseBoolArrayAttribute::PyDenseArrayIterator::bind(m); @@ -1562,24 +1771,26 @@ void mlir::python::populateIRAttributes(py::module &m) { PyDenseF64ArrayAttribute::PyDenseArrayIterator::bind(m); PyGlobals::get().registerTypeCaster( mlirDenseArrayAttrGetTypeID(), - pybind11::cpp_function(denseArrayAttributeCaster)); + nb::cast(nb::cpp_function(denseArrayAttributeCaster))); PyArrayAttribute::bind(m); PyArrayAttribute::PyArrayAttributeIterator::bind(m); PyBoolAttribute::bind(m); - PyDenseElementsAttribute::bind(m); + PyDenseElementsAttribute::bind(m, PyDenseElementsAttribute::slots); PyDenseFPElementsAttribute::bind(m); PyDenseIntElementsAttribute::bind(m); PyGlobals::get().registerTypeCaster( mlirDenseIntOrFPElementsAttrGetTypeID(), - pybind11::cpp_function(denseIntOrFPElementsAttributeCaster)); + 
nb::cast( + nb::cpp_function(denseIntOrFPElementsAttributeCaster))); PyDenseResourceElementsAttribute::bind(m); PyDictAttribute::bind(m); PySymbolRefAttribute::bind(m); PyGlobals::get().registerTypeCaster( mlirSymbolRefAttrGetTypeID(), - pybind11::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster)); + nb::cast( + nb::cpp_function(symbolRefOrFlatSymbolRefAttributeCaster))); PyFlatSymbolRefAttribute::bind(m); PyOpaqueAttribute::bind(m); @@ -1590,7 +1801,7 @@ void mlir::python::populateIRAttributes(py::module &m) { PyTypeAttribute::bind(m); PyGlobals::get().registerTypeCaster( mlirIntegerAttrGetTypeID(), - pybind11::cpp_function(integerOrBoolAttributeCaster)); + nb::cast(nb::cpp_function(integerOrBoolAttributeCaster))); PyUnitAttribute::bind(m); PyStridedLayoutAttribute::bind(m); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 3e96f8c60..e1c56a398 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -6,26 +6,31 @@ // //===----------------------------------------------------------------------===// -#include "IRModule.h" +#include +#include +#include +#include +#include +#include -#include "Globals.h" -#include "PybindUtils.h" +#include +#include +#include "Globals.h" +#include "IRModule.h" +#include "NanobindUtils.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/Debug.h" #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include -#include - -namespace py = pybind11; -using namespace py::literals; +namespace nb = nanobind; +using namespace nb::literals; using namespace mlir; using namespace mlir::python; @@ -190,18 +195,18 @@ operations. /// Helper for creating an @classmethod. template -py::object classmethod(Func f, Args... 
args) { - py::object cf = py::cpp_function(f, args...); - return py::reinterpret_borrow((PyClassMethod_New(cf.ptr()))); +nb::object classmethod(Func f, Args... args) { + nb::object cf = nb::cpp_function(f, args...); + return nb::borrow((PyClassMethod_New(cf.ptr()))); } -static py::object +static nb::object createCustomDialectWrapper(const std::string &dialectNamespace, - py::object dialectDescriptor) { + nb::object dialectDescriptor) { auto dialectClass = PyGlobals::get().lookupDialectClass(dialectNamespace); if (!dialectClass) { // Use the base class. - return py::cast(PyDialect(std::move(dialectDescriptor))); + return nb::cast(PyDialect(std::move(dialectDescriptor))); } // Create the custom implementation. @@ -212,42 +217,47 @@ static MlirStringRef toMlirStringRef(const std::string &s) { return mlirStringRefCreate(s.data(), s.size()); } +static MlirStringRef toMlirStringRef(const nb::bytes &s) { + return mlirStringRefCreate(static_cast(s.data()), s.size()); +} + /// Create a block, using the current location context if no locations are /// specified. 
-static MlirBlock createBlock(const py::sequence &pyArgTypes, - const std::optional &pyArgLocs) { +static MlirBlock createBlock(const nb::sequence &pyArgTypes, + const std::optional &pyArgLocs) { SmallVector argTypes; - argTypes.reserve(pyArgTypes.size()); + argTypes.reserve(nb::len(pyArgTypes)); for (const auto &pyType : pyArgTypes) - argTypes.push_back(pyType.cast()); + argTypes.push_back(nb::cast(pyType)); SmallVector argLocs; if (pyArgLocs) { - argLocs.reserve(pyArgLocs->size()); + argLocs.reserve(nb::len(*pyArgLocs)); for (const auto &pyLoc : *pyArgLocs) - argLocs.push_back(pyLoc.cast()); + argLocs.push_back(nb::cast(pyLoc)); } else if (!argTypes.empty()) { argLocs.assign(argTypes.size(), DefaultingPyLocation::resolve()); } if (argTypes.size() != argLocs.size()) - throw py::value_error(("Expected " + Twine(argTypes.size()) + + throw nb::value_error(("Expected " + Twine(argTypes.size()) + " locations, got: " + Twine(argLocs.size())) - .str()); + .str() + .c_str()); return mlirBlockCreate(argTypes.size(), argTypes.data(), argLocs.data()); } /// Wrapper for the global LLVM debugging flag. struct PyGlobalDebugFlag { - static void set(py::object &o, bool enable) { mlirEnableGlobalDebug(enable); } + static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); } - static bool get(const py::object &) { return mlirIsGlobalDebugEnabled(); } + static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); } - static void bind(py::module &m) { + static void bind(nb::module_ &m) { // Debug flags. 
- py::class_(m, "_GlobalDebug", py::module_local()) - .def_property_static("flag", &PyGlobalDebugFlag::get, - &PyGlobalDebugFlag::set, "LLVM-wide debug flag") + nb::class_(m, "_GlobalDebug") + .def_prop_rw_static("flag", &PyGlobalDebugFlag::get, + &PyGlobalDebugFlag::set, "LLVM-wide debug flag") .def_static( "set_types", [](const std::string &type) { @@ -268,20 +278,20 @@ struct PyAttrBuilderMap { static bool dunderContains(const std::string &attributeKind) { return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value(); } - static py::function dundeGetItemNamed(const std::string &attributeKind) { + static nb::callable dundeGetItemNamed(const std::string &attributeKind) { auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind); if (!builder) - throw py::key_error(attributeKind); + throw nb::key_error(attributeKind.c_str()); return *builder; } static void dundeSetItemNamed(const std::string &attributeKind, - py::function func, bool replace) { + nb::callable func, bool replace) { PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), replace); } - static void bind(py::module &m) { - py::class_(m, "AttrBuilder", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "AttrBuilder") .def_static("contains", &PyAttrBuilderMap::dunderContains) .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed) .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed, @@ -295,8 +305,8 @@ struct PyAttrBuilderMap { // PyBlock //------------------------------------------------------------------------------ -py::object PyBlock::getCapsule() { - return py::reinterpret_steal(mlirPythonBlockToCapsule(get())); +nb::object PyBlock::getCapsule() { + return nb::steal(mlirPythonBlockToCapsule(get())); } //------------------------------------------------------------------------------ @@ -315,14 +325,14 @@ class PyRegionIterator { PyRegion dunderNext() { operation->checkValid(); if (nextIndex >= 
mlirOperationGetNumRegions(operation->get())) { - throw py::stop_iteration(); + throw nb::stop_iteration(); } MlirRegion region = mlirOperationGetRegion(operation->get(), nextIndex++); return PyRegion(operation, region); } - static void bind(py::module &m) { - py::class_(m, "RegionIterator", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "RegionIterator") .def("__iter__", &PyRegionIterator::dunderIter) .def("__next__", &PyRegionIterator::dunderNext); } @@ -351,14 +361,14 @@ class PyRegionList { PyRegion dunderGetItem(intptr_t index) { // dunderLen checks validity. if (index < 0 || index >= dunderLen()) { - throw py::index_error("attempt to access out of bounds region"); + throw nb::index_error("attempt to access out of bounds region"); } MlirRegion region = mlirOperationGetRegion(operation->get(), index); return PyRegion(operation, region); } - static void bind(py::module &m) { - py::class_(m, "RegionSequence", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "RegionSequence") .def("__len__", &PyRegionList::dunderLen) .def("__iter__", &PyRegionList::dunderIter) .def("__getitem__", &PyRegionList::dunderGetItem); @@ -378,7 +388,7 @@ class PyBlockIterator { PyBlock dunderNext() { operation->checkValid(); if (mlirBlockIsNull(next)) { - throw py::stop_iteration(); + throw nb::stop_iteration(); } PyBlock returnBlock(operation, next); @@ -386,8 +396,8 @@ class PyBlockIterator { return returnBlock; } - static void bind(py::module &m) { - py::class_(m, "BlockIterator", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "BlockIterator") .def("__iter__", &PyBlockIterator::dunderIter) .def("__next__", &PyBlockIterator::dunderNext); } @@ -424,7 +434,7 @@ class PyBlockList { PyBlock dunderGetItem(intptr_t index) { operation->checkValid(); if (index < 0) { - throw py::index_error("attempt to access out of bounds block"); + throw nb::index_error("attempt to access out of bounds block"); } MlirBlock block = 
mlirRegionGetFirstBlock(region); while (!mlirBlockIsNull(block)) { @@ -434,24 +444,26 @@ class PyBlockList { block = mlirBlockGetNextInRegion(block); index -= 1; } - throw py::index_error("attempt to access out of bounds block"); + throw nb::index_error("attempt to access out of bounds block"); } - PyBlock appendBlock(const py::args &pyArgTypes, - const std::optional &pyArgLocs) { + PyBlock appendBlock(const nb::args &pyArgTypes, + const std::optional &pyArgLocs) { operation->checkValid(); - MlirBlock block = createBlock(pyArgTypes, pyArgLocs); + MlirBlock block = + createBlock(nb::cast(pyArgTypes), pyArgLocs); mlirRegionAppendOwnedBlock(region, block); return PyBlock(operation, block); } - static void bind(py::module &m) { - py::class_(m, "BlockList", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "BlockList") .def("__getitem__", &PyBlockList::dunderGetItem) .def("__iter__", &PyBlockList::dunderIter) .def("__len__", &PyBlockList::dunderLen) .def("append", &PyBlockList::appendBlock, kAppendBlockDocstring, - py::arg("arg_locs") = std::nullopt); + nb::arg("args"), nb::kw_only(), + nb::arg("arg_locs") = std::nullopt); } private: @@ -466,10 +478,10 @@ class PyOperationIterator { PyOperationIterator &dunderIter() { return *this; } - py::object dunderNext() { + nb::object dunderNext() { parentOperation->checkValid(); if (mlirOperationIsNull(next)) { - throw py::stop_iteration(); + throw nb::stop_iteration(); } PyOperationRef returnOperation = @@ -478,8 +490,8 @@ class PyOperationIterator { return returnOperation->createOpView(); } - static void bind(py::module &m) { - py::class_(m, "OperationIterator", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "OperationIterator") .def("__iter__", &PyOperationIterator::dunderIter) .def("__next__", &PyOperationIterator::dunderNext); } @@ -515,10 +527,10 @@ class PyOperationList { return count; } - py::object dunderGetItem(intptr_t index) { + nb::object dunderGetItem(intptr_t index) 
{ parentOperation->checkValid(); if (index < 0) { - throw py::index_error("attempt to access out of bounds operation"); + throw nb::index_error("attempt to access out of bounds operation"); } MlirOperation childOp = mlirBlockGetFirstOperation(block); while (!mlirOperationIsNull(childOp)) { @@ -529,11 +541,11 @@ class PyOperationList { childOp = mlirOperationGetNextInBlock(childOp); index -= 1; } - throw py::index_error("attempt to access out of bounds operation"); + throw nb::index_error("attempt to access out of bounds operation"); } - static void bind(py::module &m) { - py::class_(m, "OperationList", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "OperationList") .def("__getitem__", &PyOperationList::dunderGetItem) .def("__iter__", &PyOperationList::dunderIter) .def("__len__", &PyOperationList::dunderLen); @@ -548,7 +560,7 @@ class PyOpOperand { public: PyOpOperand(MlirOpOperand opOperand) : opOperand(opOperand) {} - py::object getOwner() { + nb::object getOwner() { MlirOperation owner = mlirOpOperandGetOwner(opOperand); PyMlirContextRef context = PyMlirContext::forContext(mlirOperationGetContext(owner)); @@ -557,11 +569,10 @@ class PyOpOperand { size_t getOperandNumber() { return mlirOpOperandGetOperandNumber(opOperand); } - static void bind(py::module &m) { - py::class_(m, "OpOperand", py::module_local()) - .def_property_readonly("owner", &PyOpOperand::getOwner) - .def_property_readonly("operand_number", - &PyOpOperand::getOperandNumber); + static void bind(nb::module_ &m) { + nb::class_(m, "OpOperand") + .def_prop_ro("owner", &PyOpOperand::getOwner) + .def_prop_ro("operand_number", &PyOpOperand::getOperandNumber); } private: @@ -576,15 +587,15 @@ class PyOpOperandIterator { PyOpOperand dunderNext() { if (mlirOpOperandIsNull(opOperand)) - throw py::stop_iteration(); + throw nb::stop_iteration(); PyOpOperand returnOpOperand(opOperand); opOperand = mlirOpOperandGetNextUse(opOperand); return returnOpOperand; } - static void 
bind(py::module &m) { - py::class_(m, "OpOperandIterator", py::module_local()) + static void bind(nb::module_ &m) { + nb::class_(m, "OpOperandIterator") .def("__iter__", &PyOpOperandIterator::dunderIter) .def("__next__", &PyOpOperandIterator::dunderNext); } @@ -600,7 +611,7 @@ class PyOpOperandIterator { //------------------------------------------------------------------------------ PyMlirContext::PyMlirContext(MlirContext context) : context(context) { - py::gil_scoped_acquire acquire; + nb::gil_scoped_acquire acquire; auto &liveContexts = getLiveContexts(); liveContexts[context.ptr] = this; } @@ -609,41 +620,36 @@ PyMlirContext::~PyMlirContext() { // Note that the only public way to construct an instance is via the // forContext method, which always puts the associated handle into // liveContexts. - py::gil_scoped_acquire acquire; + nb::gil_scoped_acquire acquire; getLiveContexts().erase(context.ptr); mlirContextDestroy(context); } -py::object PyMlirContext::getCapsule() { - return py::reinterpret_steal(mlirPythonContextToCapsule(get())); +nb::object PyMlirContext::getCapsule() { + return nb::steal(mlirPythonContextToCapsule(get())); } -py::object PyMlirContext::createFromCapsule(py::object capsule) { +nb::object PyMlirContext::createFromCapsule(nb::object capsule) { MlirContext rawContext = mlirPythonCapsuleToContext(capsule.ptr()); if (mlirContextIsNull(rawContext)) - throw py::error_already_set(); + throw nb::python_error(); return forContext(rawContext).releaseObject(); } -PyMlirContext *PyMlirContext::createNewContextForInit() { - MlirContext context = mlirContextCreateWithThreading(false); - return new PyMlirContext(context); -} - PyMlirContextRef PyMlirContext::forContext(MlirContext context) { - py::gil_scoped_acquire acquire; + nb::gil_scoped_acquire acquire; auto &liveContexts = getLiveContexts(); auto it = liveContexts.find(context.ptr); if (it == liveContexts.end()) { // Create. 
PyMlirContext *unownedContextWrapper = new PyMlirContext(context); - py::object pyRef = py::cast(unownedContextWrapper); - assert(pyRef && "cast to py::object failed"); + nb::object pyRef = nb::cast(unownedContextWrapper); + assert(pyRef && "cast to nb::object failed"); liveContexts[context.ptr] = unownedContextWrapper; return PyMlirContextRef(unownedContextWrapper, std::move(pyRef)); } // Use existing. - py::object pyRef = py::cast(it->second); + nb::object pyRef = nb::cast(it->second); return PyMlirContextRef(it->second, std::move(pyRef)); } @@ -717,23 +723,23 @@ void PyMlirContext::clearOperationAndInside(PyOperationBase &op) { size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } -pybind11::object PyMlirContext::contextEnter() { - return PyThreadContextEntry::pushContext(*this); +nb::object PyMlirContext::contextEnter(nb::object context) { + return PyThreadContextEntry::pushContext(context); } -void PyMlirContext::contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb) { +void PyMlirContext::contextExit(const nb::object &excType, + const nb::object &excVal, + const nb::object &excTb) { PyThreadContextEntry::popContext(*this); } -py::object PyMlirContext::attachDiagnosticHandler(py::object callback) { +nb::object PyMlirContext::attachDiagnosticHandler(nb::object callback) { // Note that ownership is transferred to the delete callback below by way of // an explicit inc_ref (borrow). 
PyDiagnosticHandler *pyHandler = new PyDiagnosticHandler(get(), std::move(callback)); - py::object pyHandlerObject = - py::cast(pyHandler, py::return_value_policy::take_ownership); + nb::object pyHandlerObject = + nb::cast(pyHandler, nb::rv_policy::take_ownership); pyHandlerObject.inc_ref(); // In these C callbacks, the userData is a PyDiagnosticHandler* that is @@ -741,17 +747,17 @@ py::object PyMlirContext::attachDiagnosticHandler(py::object callback) { auto handlerCallback = +[](MlirDiagnostic diagnostic, void *userData) -> MlirLogicalResult { PyDiagnostic *pyDiagnostic = new PyDiagnostic(diagnostic); - py::object pyDiagnosticObject = - py::cast(pyDiagnostic, py::return_value_policy::take_ownership); + nb::object pyDiagnosticObject = + nb::cast(pyDiagnostic, nb::rv_policy::take_ownership); auto *pyHandler = static_cast(userData); bool result = false; { // Since this can be called from arbitrary C++ contexts, always get the // gil. - py::gil_scoped_acquire gil; + nb::gil_scoped_acquire gil; try { - result = py::cast(pyHandler->callback(pyDiagnostic)); + result = nb::cast(pyHandler->callback(pyDiagnostic)); } catch (std::exception &e) { fprintf(stderr, "MLIR Python Diagnostic handler raised exception: %s\n", e.what()); @@ -768,8 +774,7 @@ py::object PyMlirContext::attachDiagnosticHandler(py::object callback) { pyHandler->registeredID.reset(); // Decrement reference, balancing the inc_ref() above. 
- py::object pyHandlerObject = - py::cast(pyHandler, py::return_value_policy::reference); + nb::object pyHandlerObject = nb::cast(pyHandler, nb::rv_policy::reference); pyHandlerObject.dec_ref(); }; @@ -819,9 +824,9 @@ PyThreadContextEntry *PyThreadContextEntry::getTopOfStack() { return &stack.back(); } -void PyThreadContextEntry::push(FrameKind frameKind, py::object context, - py::object insertionPoint, - py::object location) { +void PyThreadContextEntry::push(FrameKind frameKind, nb::object context, + nb::object insertionPoint, + nb::object location) { auto &stack = getStack(); stack.emplace_back(frameKind, std::move(context), std::move(insertionPoint), std::move(location)); @@ -844,19 +849,19 @@ void PyThreadContextEntry::push(FrameKind frameKind, py::object context, PyMlirContext *PyThreadContextEntry::getContext() { if (!context) return nullptr; - return py::cast(context); + return nb::cast(context); } PyInsertionPoint *PyThreadContextEntry::getInsertionPoint() { if (!insertionPoint) return nullptr; - return py::cast(insertionPoint); + return nb::cast(insertionPoint); } PyLocation *PyThreadContextEntry::getLocation() { if (!location) return nullptr; - return py::cast(location); + return nb::cast(location); } PyMlirContext *PyThreadContextEntry::getDefaultContext() { @@ -874,12 +879,11 @@ PyLocation *PyThreadContextEntry::getDefaultLocation() { return tos ? 
tos->getLocation() : nullptr; } -py::object PyThreadContextEntry::pushContext(PyMlirContext &context) { - py::object contextObj = py::cast(context); - push(FrameKind::Context, /*context=*/contextObj, - /*insertionPoint=*/py::object(), - /*location=*/py::object()); - return contextObj; +nb::object PyThreadContextEntry::pushContext(nb::object context) { + push(FrameKind::Context, /*context=*/context, + /*insertionPoint=*/nb::object(), + /*location=*/nb::object()); + return context; } void PyThreadContextEntry::popContext(PyMlirContext &context) { @@ -892,15 +896,16 @@ void PyThreadContextEntry::popContext(PyMlirContext &context) { stack.pop_back(); } -py::object -PyThreadContextEntry::pushInsertionPoint(PyInsertionPoint &insertionPoint) { - py::object contextObj = +nb::object +PyThreadContextEntry::pushInsertionPoint(nb::object insertionPointObj) { + PyInsertionPoint &insertionPoint = + nb::cast(insertionPointObj); + nb::object contextObj = insertionPoint.getBlock().getParentOperation()->getContext().getObject(); - py::object insertionPointObj = py::cast(insertionPoint); push(FrameKind::InsertionPoint, /*context=*/contextObj, /*insertionPoint=*/insertionPointObj, - /*location=*/py::object()); + /*location=*/nb::object()); return insertionPointObj; } @@ -915,11 +920,11 @@ void PyThreadContextEntry::popInsertionPoint(PyInsertionPoint &insertionPoint) { stack.pop_back(); } -py::object PyThreadContextEntry::pushLocation(PyLocation &location) { - py::object contextObj = location.getContext().getObject(); - py::object locationObj = py::cast(location); +nb::object PyThreadContextEntry::pushLocation(nb::object locationObj) { + PyLocation &location = nb::cast(locationObj); + nb::object contextObj = location.getContext().getObject(); push(FrameKind::Location, /*context=*/contextObj, - /*insertionPoint=*/py::object(), + /*insertionPoint=*/nb::object(), /*location=*/locationObj); return locationObj; } @@ -941,15 +946,15 @@ void PyThreadContextEntry::popLocation(PyLocation 
&location) { void PyDiagnostic::invalidate() { valid = false; if (materializedNotes) { - for (auto ¬eObject : *materializedNotes) { - PyDiagnostic *note = py::cast(noteObject); + for (nb::handle noteObject : *materializedNotes) { + PyDiagnostic *note = nb::cast(noteObject); note->invalidate(); } } } PyDiagnosticHandler::PyDiagnosticHandler(MlirContext context, - py::object callback) + nb::object callback) : context(context), callback(std::move(callback)) {} PyDiagnosticHandler::~PyDiagnosticHandler() = default; @@ -984,32 +989,36 @@ PyLocation PyDiagnostic::getLocation() { return PyLocation(PyMlirContext::forContext(context), loc); } -py::str PyDiagnostic::getMessage() { +nb::str PyDiagnostic::getMessage() { checkValid(); - py::object fileObject = py::module::import("io").attr("StringIO")(); + nb::object fileObject = nb::module_::import_("io").attr("StringIO")(); PyFileAccumulator accum(fileObject, /*binary=*/false); mlirDiagnosticPrint(diagnostic, accum.getCallback(), accum.getUserData()); - return fileObject.attr("getvalue")(); + return nb::cast(fileObject.attr("getvalue")()); } -py::tuple PyDiagnostic::getNotes() { +nb::tuple PyDiagnostic::getNotes() { checkValid(); if (materializedNotes) return *materializedNotes; intptr_t numNotes = mlirDiagnosticGetNumNotes(diagnostic); - materializedNotes = py::tuple(numNotes); + nb::tuple notes = nb::steal(PyTuple_New(numNotes)); for (intptr_t i = 0; i < numNotes; ++i) { MlirDiagnostic noteDiag = mlirDiagnosticGetNote(diagnostic, i); - (*materializedNotes)[i] = PyDiagnostic(noteDiag); + nb::object diagnostic = nb::cast(PyDiagnostic(noteDiag)); + PyTuple_SET_ITEM(notes.ptr(), i, diagnostic.release().ptr()); } + materializedNotes = std::move(notes); + return *materializedNotes; } PyDiagnostic::DiagnosticInfo PyDiagnostic::getInfo() { std::vector notes; - for (py::handle n : getNotes()) - notes.emplace_back(n.cast().getInfo()); - return {getSeverity(), getLocation(), getMessage(), std::move(notes)}; + for (nb::handle n : 
getNotes()) + notes.emplace_back(nb::cast(n).getInfo()); + return {getSeverity(), getLocation(), nb::cast(getMessage()), + std::move(notes)}; } //------------------------------------------------------------------------------ @@ -1023,22 +1032,21 @@ MlirDialect PyDialects::getDialectForKey(const std::string &key, if (mlirDialectIsNull(dialect)) { std::string msg = (Twine("Dialect '") + key + "' not found").str(); if (attrError) - throw py::attribute_error(msg); - throw py::index_error(msg); + throw nb::attribute_error(msg.c_str()); + throw nb::index_error(msg.c_str()); } return dialect; } -py::object PyDialectRegistry::getCapsule() { - return py::reinterpret_steal( - mlirPythonDialectRegistryToCapsule(*this)); +nb::object PyDialectRegistry::getCapsule() { + return nb::steal(mlirPythonDialectRegistryToCapsule(*this)); } -PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) { +PyDialectRegistry PyDialectRegistry::createFromCapsule(nb::object capsule) { MlirDialectRegistry rawRegistry = mlirPythonCapsuleToDialectRegistry(capsule.ptr()); if (mlirDialectRegistryIsNull(rawRegistry)) - throw py::error_already_set(); + throw nb::python_error(); return PyDialectRegistry(rawRegistry); } @@ -1046,25 +1054,25 @@ PyDialectRegistry PyDialectRegistry::createFromCapsule(py::object capsule) { // PyLocation //------------------------------------------------------------------------------ -py::object PyLocation::getCapsule() { - return py::reinterpret_steal(mlirPythonLocationToCapsule(*this)); +nb::object PyLocation::getCapsule() { + return nb::steal(mlirPythonLocationToCapsule(*this)); } -PyLocation PyLocation::createFromCapsule(py::object capsule) { +PyLocation PyLocation::createFromCapsule(nb::object capsule) { MlirLocation rawLoc = mlirPythonCapsuleToLocation(capsule.ptr()); if (mlirLocationIsNull(rawLoc)) - throw py::error_already_set(); + throw nb::python_error(); return PyLocation(PyMlirContext::forContext(mlirLocationGetContext(rawLoc)), rawLoc); } 
-py::object PyLocation::contextEnter() { - return PyThreadContextEntry::pushLocation(*this); +nb::object PyLocation::contextEnter(nb::object locationObj) { + return PyThreadContextEntry::pushLocation(locationObj); } -void PyLocation::contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb) { +void PyLocation::contextExit(const nb::object &excType, + const nb::object &excVal, + const nb::object &excTb) { PyThreadContextEntry::popLocation(*this); } @@ -1087,7 +1095,7 @@ PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) : BaseContextObject(std::move(contextRef)), module(module) {} PyModule::~PyModule() { - py::gil_scoped_acquire acquire; + nb::gil_scoped_acquire acquire; auto &liveModules = getContext()->liveModules; assert(liveModules.count(module.ptr) == 1 && "destroying module not in live map"); @@ -1099,7 +1107,7 @@ PyModuleRef PyModule::forModule(MlirModule module) { MlirContext context = mlirModuleGetContext(module); PyMlirContextRef contextRef = PyMlirContext::forContext(context); - py::gil_scoped_acquire acquire; + nb::gil_scoped_acquire acquire; auto &liveModules = contextRef->liveModules; auto it = liveModules.find(module.ptr); if (it == liveModules.end()) { @@ -1108,8 +1116,7 @@ PyModuleRef PyModule::forModule(MlirModule module) { // Note that the default return value policy on cast is automatic_reference, // which does not take ownership (delete will not be called). // Just be explicit. - py::object pyRef = - py::cast(unownedModule, py::return_value_policy::take_ownership); + nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership); unownedModule->handle = pyRef; liveModules[module.ptr] = std::make_pair(unownedModule->handle, unownedModule); @@ -1117,19 +1124,19 @@ PyModuleRef PyModule::forModule(MlirModule module) { } // Use existing. 
PyModule *existing = it->second.second; - py::object pyRef = py::reinterpret_borrow(it->second.first); + nb::object pyRef = nb::borrow(it->second.first); return PyModuleRef(existing, std::move(pyRef)); } -py::object PyModule::createFromCapsule(py::object capsule) { +nb::object PyModule::createFromCapsule(nb::object capsule) { MlirModule rawModule = mlirPythonCapsuleToModule(capsule.ptr()); if (mlirModuleIsNull(rawModule)) - throw py::error_already_set(); + throw nb::python_error(); return forModule(rawModule).releaseObject(); } -py::object PyModule::getCapsule() { - return py::reinterpret_steal(mlirPythonModuleToCapsule(get())); +nb::object PyModule::getCapsule() { + return nb::steal(mlirPythonModuleToCapsule(get())); } //------------------------------------------------------------------------------ @@ -1158,7 +1165,7 @@ PyOperation::~PyOperation() { PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, MlirOperation operation, - py::object parentKeepAlive) { + nb::object parentKeepAlive) { auto &liveOperations = contextRef->liveOperations; // Create. PyOperation *unownedOperation = @@ -1166,8 +1173,7 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, // Note that the default return value policy on cast is automatic_reference, // which does not take ownership (delete will not be called). // Just be explicit. 
- py::object pyRef = - py::cast(unownedOperation, py::return_value_policy::take_ownership); + nb::object pyRef = nb::cast(unownedOperation, nb::rv_policy::take_ownership); unownedOperation->handle = pyRef; if (parentKeepAlive) { unownedOperation->parentKeepAlive = std::move(parentKeepAlive); @@ -1178,7 +1184,7 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, MlirOperation operation, - py::object parentKeepAlive) { + nb::object parentKeepAlive) { auto &liveOperations = contextRef->liveOperations; auto it = liveOperations.find(operation.ptr); if (it == liveOperations.end()) { @@ -1188,13 +1194,13 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, } // Use existing. PyOperation *existing = it->second.second; - py::object pyRef = py::reinterpret_borrow(it->second.first); + nb::object pyRef = nb::borrow(it->second.first); return PyOperationRef(existing, std::move(pyRef)); } PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, MlirOperation operation, - py::object parentKeepAlive) { + nb::object parentKeepAlive) { auto &liveOperations = contextRef->liveOperations; assert(liveOperations.count(operation.ptr) == 0 && "cannot create detached operation that already exists"); @@ -1227,12 +1233,12 @@ void PyOperation::checkValid() const { void PyOperationBase::print(std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified, py::object fileObject, + bool assumeVerified, nb::object fileObject, bool binary, bool skipRegions) { PyOperation &operation = getOperation(); operation.checkValid(); if (fileObject.is_none()) - fileObject = py::module::import("sys").attr("stdout"); + fileObject = nb::module_::import_("sys").attr("stdout"); MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); if (largeElementsLimit) @@ -1255,18 +1261,18 @@ void 
PyOperationBase::print(std::optional largeElementsLimit, mlirOpPrintingFlagsDestroy(flags); } -void PyOperationBase::print(PyAsmState &state, py::object fileObject, +void PyOperationBase::print(PyAsmState &state, nb::object fileObject, bool binary) { PyOperation &operation = getOperation(); operation.checkValid(); if (fileObject.is_none()) - fileObject = py::module::import("sys").attr("stdout"); + fileObject = nb::module_::import_("sys").attr("stdout"); PyFileAccumulator accum(fileObject, binary); mlirOperationPrintWithState(operation, state.get(), accum.getCallback(), accum.getUserData()); } -void PyOperationBase::writeBytecode(const py::object &fileObject, +void PyOperationBase::writeBytecode(const nb::object &fileObject, std::optional bytecodeVersion) { PyOperation &operation = getOperation(); operation.checkValid(); @@ -1282,9 +1288,10 @@ void PyOperationBase::writeBytecode(const py::object &fileObject, operation, config, accum.getCallback(), accum.getUserData()); mlirBytecodeWriterConfigDestroy(config); if (mlirLogicalResultIsFailure(res)) - throw py::value_error((Twine("Unable to honor desired bytecode version ") + + throw nb::value_error((Twine("Unable to honor desired bytecode version ") + Twine(*bytecodeVersion)) - .str()); + .str() + .c_str()); } void PyOperationBase::walk( @@ -1296,7 +1303,7 @@ void PyOperationBase::walk( std::function callback; bool gotException; std::string exceptionWhat; - py::object exceptionType; + nb::object exceptionType; }; UserData userData{callback, false, {}, {}}; MlirOperationWalkCallback walkCallback = [](MlirOperation op, @@ -1304,10 +1311,10 @@ void PyOperationBase::walk( UserData *calleeUserData = static_cast(userData); try { return (calleeUserData->callback)(op); - } catch (py::error_already_set &e) { + } catch (nb::python_error &e) { calleeUserData->gotException = true; - calleeUserData->exceptionWhat = e.what(); - calleeUserData->exceptionType = e.type(); + calleeUserData->exceptionWhat = std::string(e.what()); + 
calleeUserData->exceptionType = nb::borrow(e.type()); return MlirWalkResult::MlirWalkResultInterrupt; } }; @@ -1319,16 +1326,16 @@ void PyOperationBase::walk( } } -py::object PyOperationBase::getAsm(bool binary, +nb::object PyOperationBase::getAsm(bool binary, std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified, bool skipRegions) { - py::object fileObject; + nb::object fileObject; if (binary) { - fileObject = py::module::import("io").attr("BytesIO")(); + fileObject = nb::module_::import_("io").attr("BytesIO")(); } else { - fileObject = py::module::import("io").attr("StringIO")(); + fileObject = nb::module_::import_("io").attr("StringIO")(); } print(/*largeElementsLimit=*/largeElementsLimit, /*enableDebugInfo=*/enableDebugInfo, @@ -1372,7 +1379,7 @@ bool PyOperationBase::verify() { std::optional PyOperation::getParentOperation() { checkValid(); if (!isAttached()) - throw py::value_error("Detached operations have no parent"); + throw nb::value_error("Detached operations have no parent"); MlirOperation operation = mlirOperationGetParentOperation(get()); if (mlirOperationIsNull(operation)) return {}; @@ -1388,42 +1395,42 @@ PyBlock PyOperation::getBlock() { return PyBlock{std::move(*parentOperation), block}; } -py::object PyOperation::getCapsule() { +nb::object PyOperation::getCapsule() { checkValid(); - return py::reinterpret_steal(mlirPythonOperationToCapsule(get())); + return nb::steal(mlirPythonOperationToCapsule(get())); } -py::object PyOperation::createFromCapsule(py::object capsule) { +nb::object PyOperation::createFromCapsule(nb::object capsule) { MlirOperation rawOperation = mlirPythonCapsuleToOperation(capsule.ptr()); if (mlirOperationIsNull(rawOperation)) - throw py::error_already_set(); + throw nb::python_error(); MlirContext rawCtxt = mlirOperationGetContext(rawOperation); return forOperation(PyMlirContext::forContext(rawCtxt), rawOperation) .releaseObject(); } 
static void maybeInsertOperation(PyOperationRef &op, - const py::object &maybeIp) { + const nb::object &maybeIp) { // InsertPoint active? - if (!maybeIp.is(py::cast(false))) { + if (!maybeIp.is(nb::cast(false))) { PyInsertionPoint *ip; if (maybeIp.is_none()) { ip = PyThreadContextEntry::getDefaultInsertionPoint(); } else { - ip = py::cast(maybeIp); + ip = nb::cast(maybeIp); } if (ip) ip->insert(*op.get()); } } -py::object PyOperation::create(const std::string &name, +nb::object PyOperation::create(const std::string &name, std::optional> results, std::optional> operands, - std::optional attributes, + std::optional attributes, std::optional> successors, int regions, DefaultingPyLocation location, - const py::object &maybeIp, bool inferType) { + const nb::object &maybeIp, bool inferType) { llvm::SmallVector mlirOperands; llvm::SmallVector mlirResults; llvm::SmallVector mlirSuccessors; @@ -1431,14 +1438,14 @@ py::object PyOperation::create(const std::string &name, // General parameter validation. if (regions < 0) - throw py::value_error("number of regions must be >= 0"); + throw nb::value_error("number of regions must be >= 0"); // Unpack/validate operands. if (operands) { mlirOperands.reserve(operands->size()); for (PyValue *operand : *operands) { if (!operand) - throw py::value_error("operand value cannot be None"); + throw nb::value_error("operand value cannot be None"); mlirOperands.push_back(operand->get()); } } @@ -1449,38 +1456,38 @@ py::object PyOperation::create(const std::string &name, for (PyType *result : *results) { // TODO: Verify result type originate from the same context. if (!result) - throw py::value_error("result type cannot be None"); + throw nb::value_error("result type cannot be None"); mlirResults.push_back(*result); } } // Unpack/validate attributes. 
if (attributes) { mlirAttributes.reserve(attributes->size()); - for (auto &it : *attributes) { + for (std::pair it : *attributes) { std::string key; try { - key = it.first.cast(); - } catch (py::cast_error &err) { + key = nb::cast(it.first); + } catch (nb::cast_error &err) { std::string msg = "Invalid attribute key (not a string) when " "attempting to create the operation \"" + name + "\" (" + err.what() + ")"; - throw py::cast_error(msg); + throw nb::type_error(msg.c_str()); } try { - auto &attribute = it.second.cast(); + auto &attribute = nb::cast(it.second); // TODO: Verify attribute originates from the same context. mlirAttributes.emplace_back(std::move(key), attribute); - } catch (py::reference_cast_error &) { + } catch (nb::cast_error &err) { + std::string msg = "Invalid attribute value for the key \"" + key + + "\" when attempting to create the operation \"" + + name + "\" (" + err.what() + ")"; + throw nb::type_error(msg.c_str()); + } catch (std::runtime_error &) { // This exception seems thrown when the value is "None". std::string msg = "Found an invalid (`None`?) attribute value for the key \"" + key + "\" when attempting to create the operation \"" + name + "\""; - throw py::cast_error(msg); - } catch (py::cast_error &err) { - std::string msg = "Invalid attribute value for the key \"" + key + - "\" when attempting to create the operation \"" + - name + "\" (" + err.what() + ")"; - throw py::cast_error(msg); + throw std::runtime_error(msg); } } } @@ -1490,7 +1497,7 @@ py::object PyOperation::create(const std::string &name, for (auto *successor : *successors) { // TODO: Verify successor originate from the same context. if (!successor) - throw py::value_error("successor block cannot be None"); + throw nb::value_error("successor block cannot be None"); mlirSuccessors.push_back(successor->get()); } } @@ -1535,7 +1542,7 @@ py::object PyOperation::create(const std::string &name, // Construct the operation. 
MlirOperation operation = mlirOperationCreate(&state); if (!operation.ptr) - throw py::value_error("Operation creation failed"); + throw nb::value_error("Operation creation failed"); PyOperationRef created = PyOperation::createDetached(location->getContext(), operation); maybeInsertOperation(created, maybeIp); @@ -1543,7 +1550,7 @@ py::object PyOperation::create(const std::string &name, return created.getObject(); } -py::object PyOperation::clone(const py::object &maybeIp) { +nb::object PyOperation::clone(const nb::object &maybeIp) { MlirOperation clonedOperation = mlirOperationClone(operation); PyOperationRef cloned = PyOperation::createDetached(getContext(), clonedOperation); @@ -1552,15 +1559,15 @@ py::object PyOperation::clone(const py::object &maybeIp) { return cloned->createOpView(); } -py::object PyOperation::createOpView() { +nb::object PyOperation::createOpView() { checkValid(); MlirIdentifier ident = mlirOperationGetName(get()); MlirStringRef identStr = mlirIdentifierStr(ident); auto operationCls = PyGlobals::get().lookupOperationClass( StringRef(identStr.data, identStr.length)); if (operationCls) - return PyOpView::constructDerived(*operationCls, *getRef().get()); - return py::cast(PyOpView(getRef().getObject())); + return PyOpView::constructDerived(*operationCls, getRef().getObject()); + return nb::cast(PyOpView(getRef().getObject())); } void PyOperation::erase() { @@ -1573,8 +1580,8 @@ void PyOperation::erase() { // PyOpView //------------------------------------------------------------------------------ -static void populateResultTypes(StringRef name, py::list resultTypeList, - const py::object &resultSegmentSpecObj, +static void populateResultTypes(StringRef name, nb::list resultTypeList, + const nb::object &resultSegmentSpecObj, std::vector &resultSegmentLengths, std::vector &resultTypes) { resultTypes.reserve(resultTypeList.size()); @@ -1582,26 +1589,28 @@ static void populateResultTypes(StringRef name, py::list resultTypeList, // Non-variadic 
result unpacking. for (const auto &it : llvm::enumerate(resultTypeList)) { try { - resultTypes.push_back(py::cast(it.value())); + resultTypes.push_back(nb::cast(it.value())); if (!resultTypes.back()) - throw py::cast_error(); - } catch (py::cast_error &err) { - throw py::value_error((llvm::Twine("Result ") + + throw nb::cast_error(); + } catch (nb::cast_error &err) { + throw nb::value_error((llvm::Twine("Result ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Type (" + err.what() + ")") - .str()); + .str() + .c_str()); } } } else { // Sized result unpacking. - auto resultSegmentSpec = py::cast>(resultSegmentSpecObj); + auto resultSegmentSpec = nb::cast>(resultSegmentSpecObj); if (resultSegmentSpec.size() != resultTypeList.size()) { - throw py::value_error((llvm::Twine("Operation \"") + name + + throw nb::value_error((llvm::Twine("Operation \"") + name + "\" requires " + llvm::Twine(resultSegmentSpec.size()) + " result segments but was provided " + llvm::Twine(resultTypeList.size())) - .str()); + .str() + .c_str()); } resultSegmentLengths.reserve(resultTypeList.size()); for (const auto &it : @@ -1610,7 +1619,7 @@ static void populateResultTypes(StringRef name, py::list resultTypeList, if (segmentSpec == 1 || segmentSpec == 0) { // Unpack unary element. try { - auto *resultType = py::cast(std::get<0>(it.value())); + auto *resultType = nb::cast(std::get<0>(it.value())); if (resultType) { resultTypes.push_back(resultType); resultSegmentLengths.push_back(1); @@ -1618,14 +1627,20 @@ static void populateResultTypes(StringRef name, py::list resultTypeList, // Allowed to be optional. 
resultSegmentLengths.push_back(0); } else { - throw py::cast_error("was None and result is not optional"); + throw nb::value_error( + (llvm::Twine("Result ") + llvm::Twine(it.index()) + + " of operation \"" + name + + "\" must be a Type (was None and result is not optional)") + .str() + .c_str()); } - } catch (py::cast_error &err) { - throw py::value_error((llvm::Twine("Result ") + + } catch (nb::cast_error &err) { + throw nb::value_error((llvm::Twine("Result ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Type (" + err.what() + ")") - .str()); + .str() + .c_str()); } } else if (segmentSpec == -1) { // Unpack sequence by appending. @@ -1635,72 +1650,75 @@ static void populateResultTypes(StringRef name, py::list resultTypeList, resultSegmentLengths.push_back(0); } else { // Unpack the list. - auto segment = py::cast(std::get<0>(it.value())); - for (py::object segmentItem : segment) { - resultTypes.push_back(py::cast(segmentItem)); + auto segment = nb::cast(std::get<0>(it.value())); + for (nb::handle segmentItem : segment) { + resultTypes.push_back(nb::cast(segmentItem)); if (!resultTypes.back()) { - throw py::cast_error("contained a None item"); + throw nb::type_error("contained a None item"); } } - resultSegmentLengths.push_back(segment.size()); + resultSegmentLengths.push_back(nb::len(segment)); } } catch (std::exception &err) { // NOTE: Sloppy to be using a catch-all here, but there are at least // three different unrelated exceptions that can be thrown in the // above "casts". Just keep the scope above small and catch them all. 
- throw py::value_error((llvm::Twine("Result ") + + throw nb::value_error((llvm::Twine("Result ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Sequence of Types (" + err.what() + ")") - .str()); + .str() + .c_str()); } } else { - throw py::value_error("Unexpected segment spec"); + throw nb::value_error("Unexpected segment spec"); } } } } -py::object PyOpView::buildGeneric( - const py::object &cls, std::optional resultTypeList, - py::list operandList, std::optional attributes, +nb::object PyOpView::buildGeneric( + const nb::object &cls, std::optional resultTypeList, + nb::list operandList, std::optional attributes, std::optional> successors, std::optional regions, DefaultingPyLocation location, - const py::object &maybeIp) { + const nb::object &maybeIp) { PyMlirContextRef context = location->getContext(); // Class level operation construction metadata. - std::string name = py::cast(cls.attr("OPERATION_NAME")); + std::string name = nb::cast(cls.attr("OPERATION_NAME")); // Operand and result segment specs are either none, which does no // variadic unpacking, or a list of ints with segment sizes, where each // element is either a positive number (typically 1 for a scalar) or -1 to // indicate that it is derived from the length of the same-indexed operand // or result (implying that it is a list at that position). - py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); - py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); + nb::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); + nb::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); std::vector operandSegmentLengths; std::vector resultSegmentLengths; // Validate/determine region count. 
- auto opRegionSpec = py::cast>(cls.attr("_ODS_REGIONS")); + auto opRegionSpec = nb::cast>(cls.attr("_ODS_REGIONS")); int opMinRegionCount = std::get<0>(opRegionSpec); bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); if (!regions) { regions = opMinRegionCount; } if (*regions < opMinRegionCount) { - throw py::value_error( + throw nb::value_error( (llvm::Twine("Operation \"") + name + "\" requires a minimum of " + llvm::Twine(opMinRegionCount) + " regions but was built with regions=" + llvm::Twine(*regions)) - .str()); + .str() + .c_str()); } if (opHasNoVariadicRegions && *regions > opMinRegionCount) { - throw py::value_error( + throw nb::value_error( (llvm::Twine("Operation \"") + name + "\" requires a maximum of " + llvm::Twine(opMinRegionCount) + " regions but was built with regions=" + llvm::Twine(*regions)) - .str()); + .str() + .c_str()); } // Unpack results. @@ -1717,26 +1735,28 @@ py::object PyOpView::buildGeneric( // Non-sized operand unpacking. for (const auto &it : llvm::enumerate(operandList)) { try { - operands.push_back(py::cast(it.value())); + operands.push_back(nb::cast(it.value())); if (!operands.back()) - throw py::cast_error(); - } catch (py::cast_error &err) { - throw py::value_error((llvm::Twine("Operand ") + + throw nb::cast_error(); + } catch (nb::cast_error &err) { + throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Value (" + err.what() + ")") - .str()); + .str() + .c_str()); } } } else { // Sized operand unpacking. 
- auto operandSegmentSpec = py::cast>(operandSegmentSpecObj); + auto operandSegmentSpec = nb::cast>(operandSegmentSpecObj); if (operandSegmentSpec.size() != operandList.size()) { - throw py::value_error((llvm::Twine("Operation \"") + name + + throw nb::value_error((llvm::Twine("Operation \"") + name + "\" requires " + llvm::Twine(operandSegmentSpec.size()) + "operand segments but was provided " + llvm::Twine(operandList.size())) - .str()); + .str() + .c_str()); } operandSegmentLengths.reserve(operandList.size()); for (const auto &it : @@ -1745,7 +1765,7 @@ py::object PyOpView::buildGeneric( if (segmentSpec == 1 || segmentSpec == 0) { // Unpack unary element. try { - auto *operandValue = py::cast(std::get<0>(it.value())); + auto *operandValue = nb::cast(std::get<0>(it.value())); if (operandValue) { operands.push_back(operandValue); operandSegmentLengths.push_back(1); @@ -1753,14 +1773,20 @@ py::object PyOpView::buildGeneric( // Allowed to be optional. operandSegmentLengths.push_back(0); } else { - throw py::cast_error("was None and operand is not optional"); + throw nb::value_error( + (llvm::Twine("Operand ") + llvm::Twine(it.index()) + + " of operation \"" + name + + "\" must be a Value (was None and operand is not optional)") + .str() + .c_str()); } - } catch (py::cast_error &err) { - throw py::value_error((llvm::Twine("Operand ") + + } catch (nb::cast_error &err) { + throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Value (" + err.what() + ")") - .str()); + .str() + .c_str()); } } else if (segmentSpec == -1) { // Unpack sequence by appending. @@ -1770,27 +1796,28 @@ py::object PyOpView::buildGeneric( operandSegmentLengths.push_back(0); } else { // Unpack the list. 
- auto segment = py::cast(std::get<0>(it.value())); - for (py::object segmentItem : segment) { - operands.push_back(py::cast(segmentItem)); + auto segment = nb::cast(std::get<0>(it.value())); + for (nb::handle segmentItem : segment) { + operands.push_back(nb::cast(segmentItem)); if (!operands.back()) { - throw py::cast_error("contained a None item"); + throw nb::type_error("contained a None item"); } } - operandSegmentLengths.push_back(segment.size()); + operandSegmentLengths.push_back(nb::len(segment)); } } catch (std::exception &err) { // NOTE: Sloppy to be using a catch-all here, but there are at least // three different unrelated exceptions that can be thrown in the // above "casts". Just keep the scope above small and catch them all. - throw py::value_error((llvm::Twine("Operand ") + + throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Sequence of Values (" + err.what() + ")") - .str()); + .str() + .c_str()); } } else { - throw py::value_error("Unexpected segment spec"); + throw nb::value_error("Unexpected segment spec"); } } } @@ -1799,13 +1826,13 @@ py::object PyOpView::buildGeneric( if (!operandSegmentLengths.empty() || !resultSegmentLengths.empty()) { // Dup. if (attributes) { - attributes = py::dict(*attributes); + attributes = nb::dict(*attributes); } else { - attributes = py::dict(); + attributes = nb::dict(); } if (attributes->contains("resultSegmentSizes") || attributes->contains("operandSegmentSizes")) { - throw py::value_error("Manually setting a 'resultSegmentSizes' or " + throw nb::value_error("Manually setting a 'resultSegmentSizes' or " "'operandSegmentSizes' attribute is unsupported. " "Use Operation.create for such low-level access."); } @@ -1839,21 +1866,18 @@ py::object PyOpView::buildGeneric( !resultTypeList); } -pybind11::object PyOpView::constructDerived(const pybind11::object &cls, - const PyOperation &operation) { - // TODO: pybind11 2.6 supports a more direct form. 
- // Upgrade many years from now. - // auto opViewType = py::type::of(); - py::handle opViewType = py::detail::get_type_handle(typeid(PyOpView), true); - py::object instance = cls.attr("__new__")(cls); +nb::object PyOpView::constructDerived(const nb::object &cls, + const nb::object &operation) { + nb::handle opViewType = nb::type(); + nb::object instance = cls.attr("__new__")(cls); opViewType.attr("__init__")(instance, operation); return instance; } -PyOpView::PyOpView(const py::object &operationObject) +PyOpView::PyOpView(const nb::object &operationObject) // Casting through the PyOperationBase base-class and then back to the // Operation lets us accept any PyOperationBase subclass. - : operation(py::cast(operationObject).getOperation()), + : operation(nb::cast(operationObject).getOperation()), operationObject(operation.getRef().getObject()) {} //------------------------------------------------------------------------------ @@ -1869,7 +1893,7 @@ PyInsertionPoint::PyInsertionPoint(PyOperationBase &beforeOperationBase) void PyInsertionPoint::insert(PyOperationBase &operationBase) { PyOperation &operation = operationBase.getOperation(); if (operation.isAttached()) - throw py::value_error( + throw nb::value_error( "Attempt to insert operation that is already attached"); block.getParentOperation()->checkValid(); MlirOperation beforeOp = {nullptr}; @@ -1882,7 +1906,7 @@ void PyInsertionPoint::insert(PyOperationBase &operationBase) { // already end in a known terminator (violating this will cause assertion // failures later). if (!mlirOperationIsNull(mlirBlockGetTerminator(block.get()))) { - throw py::index_error("Cannot insert operation at the end of a block " + throw nb::index_error("Cannot insert operation at the end of a block " "that already has a terminator. 
Did you mean to " "use 'InsertionPoint.at_block_terminator(block)' " "versus 'InsertionPoint(block)'?"); @@ -1908,19 +1932,19 @@ PyInsertionPoint PyInsertionPoint::atBlockBegin(PyBlock &block) { PyInsertionPoint PyInsertionPoint::atBlockTerminator(PyBlock &block) { MlirOperation terminator = mlirBlockGetTerminator(block.get()); if (mlirOperationIsNull(terminator)) - throw py::value_error("Block has no terminator"); + throw nb::value_error("Block has no terminator"); PyOperationRef terminatorOpRef = PyOperation::forOperation( block.getParentOperation()->getContext(), terminator); return PyInsertionPoint{block, std::move(terminatorOpRef)}; } -py::object PyInsertionPoint::contextEnter() { - return PyThreadContextEntry::pushInsertionPoint(*this); +nb::object PyInsertionPoint::contextEnter(nb::object insertPoint) { + return PyThreadContextEntry::pushInsertionPoint(insertPoint); } -void PyInsertionPoint::contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb) { +void PyInsertionPoint::contextExit(const nb::object &excType, + const nb::object &excVal, + const nb::object &excTb) { PyThreadContextEntry::popInsertionPoint(*this); } @@ -1932,14 +1956,14 @@ bool PyAttribute::operator==(const PyAttribute &other) const { return mlirAttributeEqual(attr, other.attr); } -py::object PyAttribute::getCapsule() { - return py::reinterpret_steal(mlirPythonAttributeToCapsule(*this)); +nb::object PyAttribute::getCapsule() { + return nb::steal(mlirPythonAttributeToCapsule(*this)); } -PyAttribute PyAttribute::createFromCapsule(py::object capsule) { +PyAttribute PyAttribute::createFromCapsule(nb::object capsule) { MlirAttribute rawAttr = mlirPythonCapsuleToAttribute(capsule.ptr()); if (mlirAttributeIsNull(rawAttr)) - throw py::error_already_set(); + throw nb::python_error(); return PyAttribute( PyMlirContext::forContext(mlirAttributeGetContext(rawAttr)), rawAttr); } @@ -1964,14 +1988,14 @@ bool PyType::operator==(const PyType &other) 
const { return mlirTypeEqual(type, other.type); } -py::object PyType::getCapsule() { - return py::reinterpret_steal(mlirPythonTypeToCapsule(*this)); +nb::object PyType::getCapsule() { + return nb::steal(mlirPythonTypeToCapsule(*this)); } -PyType PyType::createFromCapsule(py::object capsule) { +PyType PyType::createFromCapsule(nb::object capsule) { MlirType rawType = mlirPythonCapsuleToType(capsule.ptr()); if (mlirTypeIsNull(rawType)) - throw py::error_already_set(); + throw nb::python_error(); return PyType(PyMlirContext::forContext(mlirTypeGetContext(rawType)), rawType); } @@ -1980,14 +2004,14 @@ PyType PyType::createFromCapsule(py::object capsule) { // PyTypeID. //------------------------------------------------------------------------------ -py::object PyTypeID::getCapsule() { - return py::reinterpret_steal(mlirPythonTypeIDToCapsule(*this)); +nb::object PyTypeID::getCapsule() { + return nb::steal(mlirPythonTypeIDToCapsule(*this)); } -PyTypeID PyTypeID::createFromCapsule(py::object capsule) { +PyTypeID PyTypeID::createFromCapsule(nb::object capsule) { MlirTypeID mlirTypeID = mlirPythonCapsuleToTypeID(capsule.ptr()); if (mlirTypeIDIsNull(mlirTypeID)) - throw py::error_already_set(); + throw nb::python_error(); return PyTypeID(mlirTypeID); } bool PyTypeID::operator==(const PyTypeID &other) const { @@ -1998,36 +2022,36 @@ bool PyTypeID::operator==(const PyTypeID &other) const { // PyValue and subclasses. 
//------------------------------------------------------------------------------ -pybind11::object PyValue::getCapsule() { - return py::reinterpret_steal(mlirPythonValueToCapsule(get())); +nb::object PyValue::getCapsule() { + return nb::steal(mlirPythonValueToCapsule(get())); } -pybind11::object PyValue::maybeDownCast() { +nb::object PyValue::maybeDownCast() { MlirType type = mlirValueGetType(get()); MlirTypeID mlirTypeID = mlirTypeGetTypeID(type); assert(!mlirTypeIDIsNull(mlirTypeID) && "mlirTypeID was expected to be non-null."); - std::optional valueCaster = + std::optional valueCaster = PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type)); - // py::return_value_policy::move means use std::move to move the return value + // nb::rv_policy::move means use std::move to move the return value // contents into a new instance that will be owned by Python. - py::object thisObj = py::cast(this, py::return_value_policy::move); + nb::object thisObj = nb::cast(this, nb::rv_policy::move); if (!valueCaster) return thisObj; return valueCaster.value()(thisObj); } -PyValue PyValue::createFromCapsule(pybind11::object capsule) { +PyValue PyValue::createFromCapsule(nb::object capsule) { MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); if (mlirValueIsNull(value)) - throw py::error_already_set(); + throw nb::python_error(); MlirOperation owner; if (mlirValueIsAOpResult(value)) owner = mlirOpResultGetOwner(value); if (mlirValueIsABlockArgument(value)) owner = mlirBlockGetParentOperation(mlirBlockArgumentGetOwner(value)); if (mlirOperationIsNull(owner)) - throw py::error_already_set(); + throw nb::python_error(); MlirContext ctx = mlirOperationGetContext(owner); PyOperationRef ownerRef = PyOperation::forOperation(PyMlirContext::forContext(ctx), owner); @@ -2042,16 +2066,17 @@ PySymbolTable::PySymbolTable(PyOperationBase &operation) : operation(operation.getOperation().getRef()) { symbolTable = mlirSymbolTableCreate(operation.getOperation().get()); if 
(mlirSymbolTableIsNull(symbolTable)) { - throw py::cast_error("Operation is not a Symbol Table."); + throw nb::type_error("Operation is not a Symbol Table."); } } -py::object PySymbolTable::dunderGetItem(const std::string &name) { +nb::object PySymbolTable::dunderGetItem(const std::string &name) { operation->checkValid(); MlirOperation symbol = mlirSymbolTableLookup( symbolTable, mlirStringRefCreate(name.data(), name.length())); if (mlirOperationIsNull(symbol)) - throw py::key_error("Symbol '" + name + "' not in the symbol table."); + throw nb::key_error( + ("Symbol '" + name + "' not in the symbol table.").c_str()); return PyOperation::forOperation(operation->getContext(), symbol, operation.getObject()) @@ -2069,8 +2094,8 @@ void PySymbolTable::erase(PyOperationBase &symbol) { } void PySymbolTable::dunderDel(const std::string &name) { - py::object operation = dunderGetItem(name); - erase(py::cast(operation)); + nb::object operation = dunderGetItem(name); + erase(nb::cast(operation)); } MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) { @@ -2079,7 +2104,7 @@ MlirAttribute PySymbolTable::insert(PyOperationBase &symbol) { MlirAttribute symbolAttr = mlirOperationGetAttributeByName( symbol.getOperation().get(), mlirSymbolTableGetSymbolAttributeName()); if (mlirAttributeIsNull(symbolAttr)) - throw py::value_error("Expected operation to have a symbol name."); + throw nb::value_error("Expected operation to have a symbol name."); return mlirSymbolTableInsert(symbolTable, symbol.getOperation().get()); } @@ -2091,7 +2116,7 @@ MlirAttribute PySymbolTable::getSymbolName(PyOperationBase &symbol) { MlirAttribute existingNameAttr = mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingNameAttr)) - throw py::value_error("Expected operation to have a symbol name."); + throw nb::value_error("Expected operation to have a symbol name."); return existingNameAttr; } @@ -2104,7 +2129,7 @@ void PySymbolTable::setSymbolName(PyOperationBase 
&symbol, MlirAttribute existingNameAttr = mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingNameAttr)) - throw py::value_error("Expected operation to have a symbol name."); + throw nb::value_error("Expected operation to have a symbol name."); MlirAttribute newNameAttr = mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(name)); mlirOperationSetAttributeByName(operation.get(), attrName, newNameAttr); @@ -2117,7 +2142,7 @@ MlirAttribute PySymbolTable::getVisibility(PyOperationBase &symbol) { MlirAttribute existingVisAttr = mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingVisAttr)) - throw py::value_error("Expected operation to have a symbol visibility."); + throw nb::value_error("Expected operation to have a symbol visibility."); return existingVisAttr; } @@ -2125,7 +2150,7 @@ void PySymbolTable::setVisibility(PyOperationBase &symbol, const std::string &visibility) { if (visibility != "public" && visibility != "private" && visibility != "nested") - throw py::value_error( + throw nb::value_error( "Expected visibility to be 'public', 'private' or 'nested'"); PyOperation &operation = symbol.getOperation(); operation.checkValid(); @@ -2133,7 +2158,7 @@ void PySymbolTable::setVisibility(PyOperationBase &symbol, MlirAttribute existingVisAttr = mlirOperationGetAttributeByName(operation.get(), attrName); if (mlirAttributeIsNull(existingVisAttr)) - throw py::value_error("Expected operation to have a symbol visibility."); + throw nb::value_error("Expected operation to have a symbol visibility."); MlirAttribute newVisAttr = mlirStringAttrGet(operation.getContext()->get(), toMlirStringRef(visibility)); mlirOperationSetAttributeByName(operation.get(), attrName, newVisAttr); @@ -2148,20 +2173,20 @@ void PySymbolTable::replaceAllSymbolUses(const std::string &oldSymbol, toMlirStringRef(oldSymbol), toMlirStringRef(newSymbol), from.getOperation()))) - throw py::value_error("Symbol rename 
failed"); + throw nb::value_error("Symbol rename failed"); } void PySymbolTable::walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, - py::object callback) { + nb::object callback) { PyOperation &fromOperation = from.getOperation(); fromOperation.checkValid(); struct UserData { PyMlirContextRef context; - py::object callback; + nb::object callback; bool gotException; std::string exceptionWhat; - py::object exceptionType; + nb::object exceptionType; }; UserData userData{ fromOperation.getContext(), std::move(callback), false, {}, {}}; @@ -2175,10 +2200,10 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from, return; try { calleeUserData->callback(pyFoundOp.getObject(), isVisible); - } catch (py::error_already_set &e) { + } catch (nb::python_error &e) { calleeUserData->gotException = true; calleeUserData->exceptionWhat = e.what(); - calleeUserData->exceptionType = e.type(); + calleeUserData->exceptionType = nb::borrow(e.type()); } }, static_cast(&userData)); @@ -2200,7 +2225,7 @@ class PyConcreteValue : public PyValue { // IsAFunctionTy isaFunction // const char *pyClassName // and redefine bindDerived. - using ClassTy = py::class_; + using ClassTy = nb::class_; using IsAFunctionTy = bool (*)(MlirValue); PyConcreteValue() = default; @@ -2213,25 +2238,26 @@ class PyConcreteValue : public PyValue { /// type mismatches. static MlirValue castFrom(PyValue &orig) { if (!DerivedTy::isaFunction(orig.get())) { - auto origRepr = py::repr(py::cast(orig)).cast(); - throw py::value_error((Twine("Cannot cast value to ") + + auto origRepr = nb::cast(nb::repr(nb::cast(orig))); + throw nb::value_error((Twine("Cannot cast value to ") + DerivedTy::pyClassName + " (from " + origRepr + ")") - .str()); + .str() + .c_str()); } return orig.get(); } /// Binds the Python module objects to functions of this class. 
- static void bind(py::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName, py::module_local()); - cls.def(py::init(), py::keep_alive<0, 1>(), py::arg("value")); + static void bind(nb::module_ &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName); + cls.def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value")); cls.def_static( "isinstance", [](PyValue &otherValue) -> bool { return DerivedTy::isaFunction(otherValue); }, - py::arg("other_value")); + nb::arg("other_value")); cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](DerivedTy &self) { return self.maybeDownCast(); }); DerivedTy::bindDerived(cls); @@ -2249,11 +2275,11 @@ class PyBlockArgument : public PyConcreteValue { using PyConcreteValue::PyConcreteValue; static void bindDerived(ClassTy &c) { - c.def_property_readonly("owner", [](PyBlockArgument &self) { + c.def_prop_ro("owner", [](PyBlockArgument &self) { return PyBlock(self.getParentOperation(), mlirBlockArgumentGetOwner(self.get())); }); - c.def_property_readonly("arg_number", [](PyBlockArgument &self) { + c.def_prop_ro("arg_number", [](PyBlockArgument &self) { return mlirBlockArgumentGetArgNumber(self.get()); }); c.def( @@ -2261,7 +2287,7 @@ class PyBlockArgument : public PyConcreteValue { [](PyBlockArgument &self, PyType type) { return mlirBlockArgumentSetType(self.get(), type); }, - py::arg("type")); + nb::arg("type")); } }; @@ -2273,14 +2299,14 @@ class PyOpResult : public PyConcreteValue { using PyConcreteValue::PyConcreteValue; static void bindDerived(ClassTy &c) { - c.def_property_readonly("owner", [](PyOpResult &self) { + c.def_prop_ro("owner", [](PyOpResult &self) { assert( mlirOperationEqual(self.getParentOperation()->get(), mlirOpResultGetOwner(self.get())) && "expected the owner of the value in Python to match that in the IR"); return self.getParentOperation().getObject(); }); - c.def_property_readonly("result_number", [](PyOpResult &self) { + c.def_prop_ro("result_number", [](PyOpResult &self) { return 
mlirOpResultGetResultNumber(self.get()); }); } @@ -2317,7 +2343,7 @@ class PyBlockArgumentList operation(std::move(operation)), block(block) {} static void bindDerived(ClassTy &c) { - c.def_property_readonly("types", [](PyBlockArgumentList &self) { + c.def_prop_ro("types", [](PyBlockArgumentList &self) { return getValueTypes(self, self.operation->getContext()); }); } @@ -2422,10 +2448,10 @@ class PyOpResultList : public Sliceable { operation(std::move(operation)) {} static void bindDerived(ClassTy &c) { - c.def_property_readonly("types", [](PyOpResultList &self) { + c.def_prop_ro("types", [](PyOpResultList &self) { return getValueTypes(self, self.operation->getContext()); }); - c.def_property_readonly("owner", [](PyOpResultList &self) { + c.def_prop_ro("owner", [](PyOpResultList &self) { return self.operation->createOpView(); }); } @@ -2508,14 +2534,14 @@ class PyOpAttributeMap { MlirAttribute attr = mlirOperationGetAttributeByName(operation->get(), toMlirStringRef(name)); if (mlirAttributeIsNull(attr)) { - throw py::key_error("attempt to access a non-existent attribute"); + throw nb::key_error("attempt to access a non-existent attribute"); } return attr; } PyNamedAttribute dunderGetItemIndexed(intptr_t index) { if (index < 0 || index >= dunderLen()) { - throw py::index_error("attempt to access out of bounds attribute"); + throw nb::index_error("attempt to access out of bounds attribute"); } MlirNamedAttribute namedAttr = mlirOperationGetAttribute(operation->get(), index); @@ -2534,7 +2560,7 @@ class PyOpAttributeMap { int removed = mlirOperationRemoveAttributeByName(operation->get(), toMlirStringRef(name)); if (!removed) - throw py::key_error("attempt to delete a non-existent attribute"); + throw nb::key_error("attempt to delete a non-existent attribute"); } intptr_t dunderLen() { @@ -2546,8 +2572,8 @@ class PyOpAttributeMap { operation->get(), toMlirStringRef(name))); } - static void bind(py::module &m) { - py::class_(m, "OpAttributeMap", py::module_local()) + 
static void bind(nb::module_ &m) { + nb::class_(m, "OpAttributeMap") .def("__contains__", &PyOpAttributeMap::dunderContains) .def("__len__", &PyOpAttributeMap::dunderLen) .def("__getitem__", &PyOpAttributeMap::dunderGetItemNamed) @@ -2566,21 +2592,21 @@ class PyOpAttributeMap { // Populates the core exports of the 'ir' submodule. //------------------------------------------------------------------------------ -void mlir::python::populateIRCore(py::module &m) { +void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- // Enums. //---------------------------------------------------------------------------- - py::enum_(m, "DiagnosticSeverity", py::module_local()) + nb::enum_(m, "DiagnosticSeverity") .value("ERROR", MlirDiagnosticError) .value("WARNING", MlirDiagnosticWarning) .value("NOTE", MlirDiagnosticNote) .value("REMARK", MlirDiagnosticRemark); - py::enum_(m, "WalkOrder", py::module_local()) + nb::enum_(m, "WalkOrder") .value("PRE_ORDER", MlirWalkPreOrder) .value("POST_ORDER", MlirWalkPostOrder); - py::enum_(m, "WalkResult", py::module_local()) + nb::enum_(m, "WalkResult") .value("ADVANCE", MlirWalkResultAdvance) .value("INTERRUPT", MlirWalkResultInterrupt) .value("SKIP", MlirWalkResultSkip); @@ -2588,33 +2614,37 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of Diagnostics. 
//---------------------------------------------------------------------------- - py::class_(m, "Diagnostic", py::module_local()) - .def_property_readonly("severity", &PyDiagnostic::getSeverity) - .def_property_readonly("location", &PyDiagnostic::getLocation) - .def_property_readonly("message", &PyDiagnostic::getMessage) - .def_property_readonly("notes", &PyDiagnostic::getNotes) - .def("__str__", [](PyDiagnostic &self) -> py::str { + nb::class_(m, "Diagnostic") + .def_prop_ro("severity", &PyDiagnostic::getSeverity) + .def_prop_ro("location", &PyDiagnostic::getLocation) + .def_prop_ro("message", &PyDiagnostic::getMessage) + .def_prop_ro("notes", &PyDiagnostic::getNotes) + .def("__str__", [](PyDiagnostic &self) -> nb::str { if (!self.isValid()) - return ""; + return nb::str(""); return self.getMessage(); }); - py::class_(m, "DiagnosticInfo", - py::module_local()) - .def(py::init<>([](PyDiagnostic diag) { return diag.getInfo(); })) - .def_readonly("severity", &PyDiagnostic::DiagnosticInfo::severity) - .def_readonly("location", &PyDiagnostic::DiagnosticInfo::location) - .def_readonly("message", &PyDiagnostic::DiagnosticInfo::message) - .def_readonly("notes", &PyDiagnostic::DiagnosticInfo::notes) + nb::class_(m, "DiagnosticInfo") + .def("__init__", + [](PyDiagnostic::DiagnosticInfo &self, PyDiagnostic diag) { + new (&self) PyDiagnostic::DiagnosticInfo(diag.getInfo()); + }) + .def_ro("severity", &PyDiagnostic::DiagnosticInfo::severity) + .def_ro("location", &PyDiagnostic::DiagnosticInfo::location) + .def_ro("message", &PyDiagnostic::DiagnosticInfo::message) + .def_ro("notes", &PyDiagnostic::DiagnosticInfo::notes) .def("__str__", [](PyDiagnostic::DiagnosticInfo &self) { return self.message; }); - py::class_(m, "DiagnosticHandler", py::module_local()) + nb::class_(m, "DiagnosticHandler") .def("detach", &PyDiagnosticHandler::detach) - .def_property_readonly("attached", &PyDiagnosticHandler::isAttached) - .def_property_readonly("had_error", &PyDiagnosticHandler::getHadError) 
+ .def_prop_ro("attached", &PyDiagnosticHandler::isAttached) + .def_prop_ro("had_error", &PyDiagnosticHandler::getHadError) .def("__enter__", &PyDiagnosticHandler::contextEnter) - .def("__exit__", &PyDiagnosticHandler::contextExit); + .def("__exit__", &PyDiagnosticHandler::contextExit, + nb::arg("exc_type").none(), nb::arg("exc_value").none(), + nb::arg("traceback").none()); //---------------------------------------------------------------------------- // Mapping of MlirContext. @@ -2622,8 +2652,12 @@ void mlir::python::populateIRCore(py::module &m) { // __init__.py will subclass it with site-specific functionality and set a // "Context" attribute on this module. //---------------------------------------------------------------------------- - py::class_(m, "_BaseContext", py::module_local()) - .def(py::init<>(&PyMlirContext::createNewContextForInit)) + nb::class_(m, "_BaseContext") + .def("__init__", + [](PyMlirContext &self) { + MlirContext context = mlirContextCreateWithThreading(false); + new (&self) PyMlirContext(context); + }) .def_static("_get_live_count", &PyMlirContext::getLiveCount) .def("_get_context_again", [](PyMlirContext &self) { @@ -2635,28 +2669,28 @@ void mlir::python::populateIRCore(py::module &m) { &PyMlirContext::getLiveOperationObjects) .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) .def("_clear_live_operations_inside", - py::overload_cast( + nb::overload_cast( &PyMlirContext::clearOperationsInside)) .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyMlirContext::getCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) .def("__enter__", &PyMlirContext::contextEnter) - .def("__exit__", &PyMlirContext::contextExit) - .def_property_readonly_static( + .def("__exit__", &PyMlirContext::contextExit, nb::arg("exc_type").none(), + nb::arg("exc_value").none(), 
nb::arg("traceback").none()) + .def_prop_ro_static( "current", - [](py::object & /*class*/) { + [](nb::object & /*class*/) { auto *context = PyThreadContextEntry::getDefaultContext(); if (!context) - return py::none().cast(); - return py::cast(context); + return nb::none(); + return nb::cast(context); }, "Gets the Context bound to the current thread or raises ValueError") - .def_property_readonly( + .def_prop_ro( "dialects", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, "Gets a container for accessing dialects by name") - .def_property_readonly( + .def_prop_ro( "d", [](PyMlirContext &self) { return PyDialects(self.getRef()); }, "Alias for 'dialect'") .def( @@ -2665,14 +2699,14 @@ void mlir::python::populateIRCore(py::module &m) { MlirDialect dialect = mlirContextGetOrLoadDialect( self.get(), {name.data(), name.size()}); if (mlirDialectIsNull(dialect)) { - throw py::value_error( - (Twine("Dialect '") + name + "' not found").str()); + throw nb::value_error( + (Twine("Dialect '") + name + "' not found").str().c_str()); } return PyDialectDescriptor(self.getRef(), dialect); }, - py::arg("dialect_name"), + nb::arg("dialect_name"), "Gets or loads a dialect by name, returning its descriptor object") - .def_property( + .def_prop_rw( "allow_unregistered_dialects", [](PyMlirContext &self) -> bool { return mlirContextGetAllowUnregisteredDialects(self.get()); @@ -2681,32 +2715,32 @@ void mlir::python::populateIRCore(py::module &m) { mlirContextSetAllowUnregisteredDialects(self.get(), value); }) .def("attach_diagnostic_handler", &PyMlirContext::attachDiagnosticHandler, - py::arg("callback"), + nb::arg("callback"), "Attaches a diagnostic handler that will receive callbacks") .def( "enable_multithreading", [](PyMlirContext &self, bool enable) { mlirContextEnableMultithreading(self.get(), enable); }, - py::arg("enable")) + nb::arg("enable")) .def( "is_registered_operation", [](PyMlirContext &self, std::string &name) { return mlirContextIsRegisteredOperation( 
self.get(), MlirStringRef{name.data(), name.size()}); }, - py::arg("operation_name")) + nb::arg("operation_name")) .def( "append_dialect_registry", [](PyMlirContext &self, PyDialectRegistry ®istry) { mlirContextAppendDialectRegistry(self.get(), registry); }, - py::arg("registry")) - .def_property("emit_error_diagnostics", nullptr, - &PyMlirContext::setEmitErrorDiagnostics, - "Emit error diagnostics to diagnostic handlers. By default " - "error diagnostics are captured and reported through " - "MLIRError exceptions.") + nb::arg("registry")) + .def_prop_rw("emit_error_diagnostics", nullptr, + &PyMlirContext::setEmitErrorDiagnostics, + "Emit error diagnostics to diagnostic handlers. By default " + "error diagnostics are captured and reported through " + "MLIRError exceptions.") .def("load_all_available_dialects", [](PyMlirContext &self) { mlirContextLoadAllAvailableDialects(self.get()); }); @@ -2714,13 +2748,12 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of PyDialectDescriptor //---------------------------------------------------------------------------- - py::class_(m, "DialectDescriptor", py::module_local()) - .def_property_readonly("namespace", - [](PyDialectDescriptor &self) { - MlirStringRef ns = - mlirDialectGetNamespace(self.get()); - return py::str(ns.data, ns.length); - }) + nb::class_(m, "DialectDescriptor") + .def_prop_ro("namespace", + [](PyDialectDescriptor &self) { + MlirStringRef ns = mlirDialectGetNamespace(self.get()); + return nb::str(ns.data, ns.length); + }) .def("__repr__", [](PyDialectDescriptor &self) { MlirStringRef ns = mlirDialectGetNamespace(self.get()); std::string repr("(m, "Dialects", py::module_local()) + nb::class_(m, "Dialects") .def("__getitem__", [=](PyDialects &self, std::string keyName) { MlirDialect dialect = self.getDialectForKey(keyName, /*attrError=*/false); - py::object descriptor = - 
py::cast(PyDialectDescriptor{self.getContext(), dialect}); + nb::object descriptor = + nb::cast(PyDialectDescriptor{self.getContext(), dialect}); return createCustomDialectWrapper(keyName, std::move(descriptor)); }) .def("__getattr__", [=](PyDialects &self, std::string attrName) { MlirDialect dialect = self.getDialectForKey(attrName, /*attrError=*/true); - py::object descriptor = - py::cast(PyDialectDescriptor{self.getContext(), dialect}); + nb::object descriptor = + nb::cast(PyDialectDescriptor{self.getContext(), dialect}); return createCustomDialectWrapper(attrName, std::move(descriptor)); }); //---------------------------------------------------------------------------- // Mapping of PyDialect //---------------------------------------------------------------------------- - py::class_(m, "Dialect", py::module_local()) - .def(py::init(), py::arg("descriptor")) - .def_property_readonly( - "descriptor", [](PyDialect &self) { return self.getDescriptor(); }) - .def("__repr__", [](py::object self) { + nb::class_(m, "Dialect") + .def(nb::init(), nb::arg("descriptor")) + .def_prop_ro("descriptor", + [](PyDialect &self) { return self.getDescriptor(); }) + .def("__repr__", [](nb::object self) { auto clazz = self.attr("__class__"); - return py::str(""); + return nb::str(""); }); //---------------------------------------------------------------------------- // Mapping of PyDialectRegistry //---------------------------------------------------------------------------- - py::class_(m, "DialectRegistry", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyDialectRegistry::getCapsule) + nb::class_(m, "DialectRegistry") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyDialectRegistry::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyDialectRegistry::createFromCapsule) - .def(py::init<>()); + .def(nb::init<>()); //---------------------------------------------------------------------------- // Mapping of Location 
//---------------------------------------------------------------------------- - py::class_(m, "Location", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) + nb::class_(m, "Location") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyLocation::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyLocation::createFromCapsule) .def("__enter__", &PyLocation::contextEnter) - .def("__exit__", &PyLocation::contextExit) + .def("__exit__", &PyLocation::contextExit, nb::arg("exc_type").none(), + nb::arg("exc_value").none(), nb::arg("traceback").none()) .def("__eq__", [](PyLocation &self, PyLocation &other) -> bool { return mlirLocationEqual(self, other); }) - .def("__eq__", [](PyLocation &self, py::object other) { return false; }) - .def_property_readonly_static( + .def("__eq__", [](PyLocation &self, nb::object other) { return false; }) + .def_prop_ro_static( "current", - [](py::object & /*class*/) { + [](nb::object & /*class*/) { auto *loc = PyThreadContextEntry::getDefaultLocation(); if (!loc) - throw py::value_error("No current Location"); + throw nb::value_error("No current Location"); return loc; }, "Gets the Location bound to the current thread or raises ValueError") @@ -2801,14 +2834,14 @@ void mlir::python::populateIRCore(py::module &m) { return PyLocation(context->getRef(), mlirLocationUnknownGet(context->get())); }, - py::arg("context") = py::none(), + nb::arg("context").none() = nb::none(), "Gets a Location representing an unknown location") .def_static( "callsite", [](PyLocation callee, const std::vector &frames, DefaultingPyMlirContext context) { if (frames.empty()) - throw py::value_error("No caller frames provided"); + throw nb::value_error("No caller frames provided"); MlirLocation caller = frames.back().get(); for (const PyLocation &frame : llvm::reverse(llvm::ArrayRef(frames).drop_back())) @@ -2816,7 +2849,8 @@ void mlir::python::populateIRCore(py::module &m) { return PyLocation(context->getRef(), 
mlirLocationCallSiteGet(callee.get(), caller)); }, - py::arg("callee"), py::arg("frames"), py::arg("context") = py::none(), + nb::arg("callee"), nb::arg("frames"), + nb::arg("context").none() = nb::none(), kContextGetCallSiteLocationDocstring) .def_static( "file", @@ -2827,8 +2861,9 @@ void mlir::python::populateIRCore(py::module &m) { mlirLocationFileLineColGet( context->get(), toMlirStringRef(filename), line, col)); }, - py::arg("filename"), py::arg("line"), py::arg("col"), - py::arg("context") = py::none(), kContextGetFileLocationDocstring) + nb::arg("filename"), nb::arg("line"), nb::arg("col"), + nb::arg("context").none() = nb::none(), + kContextGetFileLocationDocstring) .def_static( "fused", [](const std::vector &pyLocations, @@ -2843,8 +2878,9 @@ void mlir::python::populateIRCore(py::module &m) { metadata ? metadata->get() : MlirAttribute{0}); return PyLocation(context->getRef(), location); }, - py::arg("locations"), py::arg("metadata") = py::none(), - py::arg("context") = py::none(), kContextGetFusedLocationDocstring) + nb::arg("locations"), nb::arg("metadata").none() = nb::none(), + nb::arg("context").none() = nb::none(), + kContextGetFusedLocationDocstring) .def_static( "name", [](std::string name, std::optional childLoc, @@ -2856,21 +2892,22 @@ void mlir::python::populateIRCore(py::module &m) { childLoc ? 
childLoc->get() : mlirLocationUnknownGet(context->get()))); }, - py::arg("name"), py::arg("childLoc") = py::none(), - py::arg("context") = py::none(), kContextGetNameLocationDocString) + nb::arg("name"), nb::arg("childLoc").none() = nb::none(), + nb::arg("context").none() = nb::none(), + kContextGetNameLocationDocString) .def_static( "from_attr", [](PyAttribute &attribute, DefaultingPyMlirContext context) { return PyLocation(context->getRef(), mlirLocationFromAttribute(attribute)); }, - py::arg("attribute"), py::arg("context") = py::none(), + nb::arg("attribute"), nb::arg("context").none() = nb::none(), "Gets a Location from a LocationAttr") - .def_property_readonly( + .def_prop_ro( "context", [](PyLocation &self) { return self.getContext().getObject(); }, "Context that owns the Location") - .def_property_readonly( + .def_prop_ro( "attr", [](PyLocation &self) { return mlirLocationGetAttribute(self); }, "Get the underlying LocationAttr") @@ -2879,7 +2916,7 @@ void mlir::python::populateIRCore(py::module &m) { [](PyLocation &self, std::string message) { mlirEmitError(self, message.c_str()); }, - py::arg("message"), "Emits an error at this location") + nb::arg("message"), "Emits an error at this location") .def("__repr__", [](PyLocation &self) { PyPrintAccumulator printAccum; mlirLocationPrint(self, printAccum.getCallback(), @@ -2890,8 +2927,8 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of Module //---------------------------------------------------------------------------- - py::class_(m, "Module", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) + nb::class_(m, "Module", nb::is_weak_referenceable()) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) .def_static( "parse", @@ -2903,7 +2940,19 @@ void mlir::python::populateIRCore(py::module &m) { 
throw MLIRError("Unable to parse module assembly", errors.take()); return PyModule::forModule(module).releaseObject(); }, - py::arg("asm"), py::arg("context") = py::none(), + nb::arg("asm"), nb::arg("context").none() = nb::none(), + kModuleParseDocstring) + .def_static( + "parse", + [](nb::bytes moduleAsm, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirModule module = mlirModuleCreateParse( + context->get(), toMlirStringRef(moduleAsm)); + if (mlirModuleIsNull(module)) + throw MLIRError("Unable to parse module assembly", errors.take()); + return PyModule::forModule(module).releaseObject(); + }, + nb::arg("asm"), nb::arg("context").none() = nb::none(), kModuleParseDocstring) .def_static( "create", @@ -2911,12 +2960,12 @@ void mlir::python::populateIRCore(py::module &m) { MlirModule module = mlirModuleCreateEmpty(loc); return PyModule::forModule(module).releaseObject(); }, - py::arg("loc") = py::none(), "Creates an empty module") - .def_property_readonly( + nb::arg("loc").none() = nb::none(), "Creates an empty module") + .def_prop_ro( "context", [](PyModule &self) { return self.getContext().getObject(); }, "Context that created the Module") - .def_property_readonly( + .def_prop_ro( "operation", [](PyModule &self) { return PyOperation::forOperation(self.getContext(), @@ -2925,7 +2974,7 @@ void mlir::python::populateIRCore(py::module &m) { .releaseObject(); }, "Accesses the module as an operation") - .def_property_readonly( + .def_prop_ro( "body", [](PyModule &self) { PyOperationRef moduleOp = PyOperation::forOperation( @@ -2943,7 +2992,7 @@ void mlir::python::populateIRCore(py::module &m) { kDumpDocstring) .def( "__str__", - [](py::object self) { + [](nb::object self) { // Defer to the operation's __str__. 
return self.attr("operation").attr("__str__")(); }, @@ -2952,27 +3001,26 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of Operation. //---------------------------------------------------------------------------- - py::class_(m, "_OperationBase", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - [](PyOperationBase &self) { - return self.getOperation().getCapsule(); - }) + nb::class_(m, "_OperationBase") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, + [](PyOperationBase &self) { + return self.getOperation().getCapsule(); + }) .def("__eq__", [](PyOperationBase &self, PyOperationBase &other) { return &self.getOperation() == &other.getOperation(); }) .def("__eq__", - [](PyOperationBase &self, py::object other) { return false; }) + [](PyOperationBase &self, nb::object other) { return false; }) .def("__hash__", [](PyOperationBase &self) { return static_cast(llvm::hash_value(&self.getOperation())); }) - .def_property_readonly("attributes", - [](PyOperationBase &self) { - return PyOpAttributeMap( - self.getOperation().getRef()); - }) - .def_property_readonly( + .def_prop_ro("attributes", + [](PyOperationBase &self) { + return PyOpAttributeMap(self.getOperation().getRef()); + }) + .def_prop_ro( "context", [](PyOperationBase &self) { PyOperation &concreteOperation = self.getOperation(); @@ -2980,46 +3028,44 @@ void mlir::python::populateIRCore(py::module &m) { return concreteOperation.getContext().getObject(); }, "Context that owns the Operation") - .def_property_readonly("name", - [](PyOperationBase &self) { - auto &concreteOperation = self.getOperation(); - concreteOperation.checkValid(); - MlirOperation operation = - concreteOperation.get(); - MlirStringRef name = mlirIdentifierStr( - mlirOperationGetName(operation)); - return py::str(name.data, name.length); - }) - .def_property_readonly("operands", - [](PyOperationBase &self) { - return PyOpOperandList( - 
self.getOperation().getRef()); - }) - .def_property_readonly("regions", - [](PyOperationBase &self) { - return PyRegionList( - self.getOperation().getRef()); - }) - .def_property_readonly( + .def_prop_ro("name", + [](PyOperationBase &self) { + auto &concreteOperation = self.getOperation(); + concreteOperation.checkValid(); + MlirOperation operation = concreteOperation.get(); + MlirStringRef name = + mlirIdentifierStr(mlirOperationGetName(operation)); + return nb::str(name.data, name.length); + }) + .def_prop_ro("operands", + [](PyOperationBase &self) { + return PyOpOperandList(self.getOperation().getRef()); + }) + .def_prop_ro("regions", + [](PyOperationBase &self) { + return PyRegionList(self.getOperation().getRef()); + }) + .def_prop_ro( "results", [](PyOperationBase &self) { return PyOpResultList(self.getOperation().getRef()); }, "Returns the list of Operation results.") - .def_property_readonly( + .def_prop_ro( "result", [](PyOperationBase &self) { auto &operation = self.getOperation(); auto numResults = mlirOperationGetNumResults(operation); if (numResults != 1) { auto name = mlirIdentifierStr(mlirOperationGetName(operation)); - throw py::value_error( + throw nb::value_error( (Twine("Cannot call .result on operation ") + StringRef(name.data, name.length) + " which has " + Twine(numResults) + " results (it is only valid for operations with a " "single result)") - .str()); + .str() + .c_str()); } return PyOpResult(operation.getRef(), mlirOperationGetResult(operation, 0)) @@ -3027,7 +3073,7 @@ void mlir::python::populateIRCore(py::module &m) { }, "Shortcut to get an op result if it has only one (throws an error " "otherwise).") - .def_property_readonly( + .def_prop_ro( "location", [](PyOperationBase &self) { PyOperation &operation = self.getOperation(); @@ -3036,14 +3082,13 @@ void mlir::python::populateIRCore(py::module &m) { }, "Returns the source location the operation was defined or derived " "from.") - .def_property_readonly("parent", - [](PyOperationBase 
&self) -> py::object { - auto parent = - self.getOperation().getParentOperation(); - if (parent) - return parent->getObject(); - return py::none(); - }) + .def_prop_ro("parent", + [](PyOperationBase &self) -> nb::object { + auto parent = self.getOperation().getParentOperation(); + if (parent) + return parent->getObject(); + return nb::none(); + }) .def( "__str__", [](PyOperationBase &self) { @@ -3058,75 +3103,76 @@ void mlir::python::populateIRCore(py::module &m) { }, "Returns the assembly form of the operation.") .def("print", - py::overload_cast( + nb::overload_cast( &PyOperationBase::print), - py::arg("state"), py::arg("file") = py::none(), - py::arg("binary") = false, kOperationPrintStateDocstring) + nb::arg("state"), nb::arg("file").none() = nb::none(), + nb::arg("binary") = false, kOperationPrintStateDocstring) .def("print", - py::overload_cast, bool, bool, bool, bool, - bool, py::object, bool, bool>( + nb::overload_cast, bool, bool, bool, bool, + bool, nb::object, bool, bool>( &PyOperationBase::print), // Careful: Lots of arguments must match up with print method. 
- py::arg("large_elements_limit") = py::none(), - py::arg("enable_debug_info") = false, - py::arg("pretty_debug_info") = false, - py::arg("print_generic_op_form") = false, - py::arg("use_local_scope") = false, - py::arg("assume_verified") = false, py::arg("file") = py::none(), - py::arg("binary") = false, py::arg("skip_regions") = false, - kOperationPrintDocstring) - .def("write_bytecode", &PyOperationBase::writeBytecode, py::arg("file"), - py::arg("desired_version") = py::none(), + nb::arg("large_elements_limit").none() = nb::none(), + nb::arg("enable_debug_info") = false, + nb::arg("pretty_debug_info") = false, + nb::arg("print_generic_op_form") = false, + nb::arg("use_local_scope") = false, + nb::arg("assume_verified") = false, + nb::arg("file").none() = nb::none(), nb::arg("binary") = false, + nb::arg("skip_regions") = false, kOperationPrintDocstring) + .def("write_bytecode", &PyOperationBase::writeBytecode, nb::arg("file"), + nb::arg("desired_version").none() = nb::none(), kOperationPrintBytecodeDocstring) .def("get_asm", &PyOperationBase::getAsm, // Careful: Lots of arguments must match up with get_asm method. - py::arg("binary") = false, - py::arg("large_elements_limit") = py::none(), - py::arg("enable_debug_info") = false, - py::arg("pretty_debug_info") = false, - py::arg("print_generic_op_form") = false, - py::arg("use_local_scope") = false, - py::arg("assume_verified") = false, py::arg("skip_regions") = false, + nb::arg("binary") = false, + nb::arg("large_elements_limit").none() = nb::none(), + nb::arg("enable_debug_info") = false, + nb::arg("pretty_debug_info") = false, + nb::arg("print_generic_op_form") = false, + nb::arg("use_local_scope") = false, + nb::arg("assume_verified") = false, nb::arg("skip_regions") = false, kOperationGetAsmDocstring) .def("verify", &PyOperationBase::verify, "Verify the operation. 
Raises MLIRError if verification fails, and " "returns true otherwise.") - .def("move_after", &PyOperationBase::moveAfter, py::arg("other"), + .def("move_after", &PyOperationBase::moveAfter, nb::arg("other"), "Puts self immediately after the other operation in its parent " "block.") - .def("move_before", &PyOperationBase::moveBefore, py::arg("other"), + .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"), "Puts self immediately before the other operation in its parent " "block.") .def( "clone", - [](PyOperationBase &self, py::object ip) { + [](PyOperationBase &self, nb::object ip) { return self.getOperation().clone(ip); }, - py::arg("ip") = py::none()) + nb::arg("ip").none() = nb::none()) .def( "detach_from_parent", [](PyOperationBase &self) { PyOperation &operation = self.getOperation(); operation.checkValid(); if (!operation.isAttached()) - throw py::value_error("Detached operation has no parent."); + throw nb::value_error("Detached operation has no parent."); operation.detachFromParent(); return operation.createOpView(); }, "Detaches the operation from its parent block.") .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); }) - .def("walk", &PyOperationBase::walk, py::arg("callback"), - py::arg("walk_order") = MlirWalkPostOrder); - - py::class_(m, "Operation", py::module_local()) - .def_static("create", &PyOperation::create, py::arg("name"), - py::arg("results") = py::none(), - py::arg("operands") = py::none(), - py::arg("attributes") = py::none(), - py::arg("successors") = py::none(), py::arg("regions") = 0, - py::arg("loc") = py::none(), py::arg("ip") = py::none(), - py::arg("infer_type") = false, kOperationCreateDocstring) + .def("walk", &PyOperationBase::walk, nb::arg("callback"), + nb::arg("walk_order") = MlirWalkPostOrder); + + nb::class_(m, "Operation") + .def_static("create", &PyOperation::create, nb::arg("name"), + nb::arg("results").none() = nb::none(), + nb::arg("operands").none() = nb::none(), + 
nb::arg("attributes").none() = nb::none(), + nb::arg("successors").none() = nb::none(), + nb::arg("regions") = 0, nb::arg("loc").none() = nb::none(), + nb::arg("ip").none() = nb::none(), + nb::arg("infer_type") = false, kOperationCreateDocstring) .def_static( "parse", [](const std::string &sourceStr, const std::string &sourceName, @@ -3134,16 +3180,15 @@ void mlir::python::populateIRCore(py::module &m) { return PyOperation::parse(context->getRef(), sourceStr, sourceName) ->createOpView(); }, - py::arg("source"), py::kw_only(), py::arg("source_name") = "", - py::arg("context") = py::none(), + nb::arg("source"), nb::kw_only(), nb::arg("source_name") = "", + nb::arg("context").none() = nb::none(), "Parses an operation. Supports both text assembly format and binary " "bytecode format.") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyOperation::getCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyOperation::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) - .def_property_readonly("operation", [](py::object self) { return self; }) - .def_property_readonly("opview", &PyOperation::createOpView) - .def_property_readonly( + .def_prop_ro("operation", [](nb::object self) { return self; }) + .def_prop_ro("opview", &PyOperation::createOpView) + .def_prop_ro( "successors", [](PyOperationBase &self) { return PyOpSuccessors(self.getOperation().getRef()); @@ -3151,30 +3196,33 @@ void mlir::python::populateIRCore(py::module &m) { "Returns the list of Operation successors."); auto opViewClass = - py::class_(m, "OpView", py::module_local()) - .def(py::init(), py::arg("operation")) - .def_property_readonly("operation", &PyOpView::getOperationObject) - .def_property_readonly("opview", [](py::object self) { return self; }) + nb::class_(m, "OpView") + .def(nb::init(), nb::arg("operation")) + .def_prop_ro("operation", &PyOpView::getOperationObject) + .def_prop_ro("opview", [](nb::object self) { return self; }) .def( "__str__", - [](PyOpView 
&self) { return py::str(self.getOperationObject()); }) - .def_property_readonly( + [](PyOpView &self) { return nb::str(self.getOperationObject()); }) + .def_prop_ro( "successors", [](PyOperationBase &self) { return PyOpSuccessors(self.getOperation().getRef()); }, "Returns the list of Operation successors."); - opViewClass.attr("_ODS_REGIONS") = py::make_tuple(0, true); - opViewClass.attr("_ODS_OPERAND_SEGMENTS") = py::none(); - opViewClass.attr("_ODS_RESULT_SEGMENTS") = py::none(); + opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true); + opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none(); + opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none(); opViewClass.attr("build_generic") = classmethod( - &PyOpView::buildGeneric, py::arg("cls"), py::arg("results") = py::none(), - py::arg("operands") = py::none(), py::arg("attributes") = py::none(), - py::arg("successors") = py::none(), py::arg("regions") = py::none(), - py::arg("loc") = py::none(), py::arg("ip") = py::none(), + &PyOpView::buildGeneric, nb::arg("cls"), + nb::arg("results").none() = nb::none(), + nb::arg("operands").none() = nb::none(), + nb::arg("attributes").none() = nb::none(), + nb::arg("successors").none() = nb::none(), + nb::arg("regions").none() = nb::none(), + nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(), "Builds a specific, generated OpView based on class level attributes."); opViewClass.attr("parse") = classmethod( - [](const py::object &cls, const std::string &sourceStr, + [](const nb::object &cls, const std::string &sourceStr, const std::string &sourceName, DefaultingPyMlirContext context) { PyOperationRef parsed = PyOperation::parse(context->getRef(), sourceStr, sourceName); @@ -3185,30 +3233,30 @@ void mlir::python::populateIRCore(py::module &m) { // `OpView` subclasses, and is not intended to be used on `OpView` // directly. 
std::string clsOpName = - py::cast(cls.attr("OPERATION_NAME")); + nb::cast(cls.attr("OPERATION_NAME")); MlirStringRef identifier = mlirIdentifierStr(mlirOperationGetName(*parsed.get())); std::string_view parsedOpName(identifier.data, identifier.length); if (clsOpName != parsedOpName) throw MLIRError(Twine("Expected a '") + clsOpName + "' op, got: '" + parsedOpName + "'"); - return PyOpView::constructDerived(cls, *parsed.get()); + return PyOpView::constructDerived(cls, parsed.getObject()); }, - py::arg("cls"), py::arg("source"), py::kw_only(), - py::arg("source_name") = "", py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("source"), nb::kw_only(), + nb::arg("source_name") = "", nb::arg("context").none() = nb::none(), "Parses a specific, generated OpView based on class level attributes"); //---------------------------------------------------------------------------- // Mapping of PyRegion. //---------------------------------------------------------------------------- - py::class_(m, "Region", py::module_local()) - .def_property_readonly( + nb::class_(m, "Region") + .def_prop_ro( "blocks", [](PyRegion &self) { return PyBlockList(self.getParentOperation(), self.get()); }, "Returns a forward-optimized sequence of blocks.") - .def_property_readonly( + .def_prop_ro( "owner", [](PyRegion &self) { return self.getParentOperation()->createOpView(); @@ -3226,27 +3274,27 @@ void mlir::python::populateIRCore(py::module &m) { [](PyRegion &self, PyRegion &other) { return self.get().ptr == other.get().ptr; }) - .def("__eq__", [](PyRegion &self, py::object &other) { return false; }); + .def("__eq__", [](PyRegion &self, nb::object &other) { return false; }); //---------------------------------------------------------------------------- // Mapping of PyBlock. 
//---------------------------------------------------------------------------- - py::class_(m, "Block", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule) - .def_property_readonly( + nb::class_(m, "Block") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyBlock::getCapsule) + .def_prop_ro( "owner", [](PyBlock &self) { return self.getParentOperation()->createOpView(); }, "Returns the owning operation of this block.") - .def_property_readonly( + .def_prop_ro( "region", [](PyBlock &self) { MlirRegion region = mlirBlockGetParentRegion(self.get()); return PyRegion(self.getParentOperation(), region); }, "Returns the owning region of this block.") - .def_property_readonly( + .def_prop_ro( "arguments", [](PyBlock &self) { return PyBlockArgumentList(self.getParentOperation(), self.get()); @@ -3265,7 +3313,7 @@ void mlir::python::populateIRCore(py::module &m) { return mlirBlockEraseArgument(self.get(), index); }, "Erase the argument at 'index' and remove it from the argument list.") - .def_property_readonly( + .def_prop_ro( "operations", [](PyBlock &self) { return PyOperationList(self.getParentOperation(), self.get()); @@ -3273,15 +3321,15 @@ void mlir::python::populateIRCore(py::module &m) { "Returns a forward-optimized sequence of operations.") .def_static( "create_at_start", - [](PyRegion &parent, const py::list &pyArgTypes, - const std::optional &pyArgLocs) { + [](PyRegion &parent, const nb::sequence &pyArgTypes, + const std::optional &pyArgLocs) { parent.checkValid(); MlirBlock block = createBlock(pyArgTypes, pyArgLocs); mlirRegionInsertOwnedBlock(parent, 0, block); return PyBlock(parent.getParentOperation(), block); }, - py::arg("parent"), py::arg("arg_types") = py::list(), - py::arg("arg_locs") = std::nullopt, + nb::arg("parent"), nb::arg("arg_types") = nb::list(), + nb::arg("arg_locs") = std::nullopt, "Creates and returns a new Block at the beginning of the given " "region (with given argument types and locations).") .def( @@ 
-3295,28 +3343,32 @@ void mlir::python::populateIRCore(py::module &m) { "Append this block to a region, transferring ownership if necessary") .def( "create_before", - [](PyBlock &self, const py::args &pyArgTypes, - const std::optional &pyArgLocs) { + [](PyBlock &self, const nb::args &pyArgTypes, + const std::optional &pyArgLocs) { self.checkValid(); - MlirBlock block = createBlock(pyArgTypes, pyArgLocs); + MlirBlock block = + createBlock(nb::cast(pyArgTypes), pyArgLocs); MlirRegion region = mlirBlockGetParentRegion(self.get()); mlirRegionInsertOwnedBlockBefore(region, self.get(), block); return PyBlock(self.getParentOperation(), block); }, - py::arg("arg_locs") = std::nullopt, + nb::arg("arg_types"), nb::kw_only(), + nb::arg("arg_locs") = std::nullopt, "Creates and returns a new Block before this block " "(with given argument types and locations).") .def( "create_after", - [](PyBlock &self, const py::args &pyArgTypes, - const std::optional &pyArgLocs) { + [](PyBlock &self, const nb::args &pyArgTypes, + const std::optional &pyArgLocs) { self.checkValid(); - MlirBlock block = createBlock(pyArgTypes, pyArgLocs); + MlirBlock block = + createBlock(nb::cast(pyArgTypes), pyArgLocs); MlirRegion region = mlirBlockGetParentRegion(self.get()); mlirRegionInsertOwnedBlockAfter(region, self.get(), block); return PyBlock(self.getParentOperation(), block); }, - py::arg("arg_locs") = std::nullopt, + nb::arg("arg_types"), nb::kw_only(), + nb::arg("arg_locs") = std::nullopt, "Creates and returns a new Block after this block " "(with given argument types and locations).") .def( @@ -3333,7 +3385,7 @@ void mlir::python::populateIRCore(py::module &m) { [](PyBlock &self, PyBlock &other) { return self.get().ptr == other.get().ptr; }) - .def("__eq__", [](PyBlock &self, py::object &other) { return false; }) + .def("__eq__", [](PyBlock &self, nb::object &other) { return false; }) .def("__hash__", [](PyBlock &self) { return static_cast(llvm::hash_value(self.get().ptr)); @@ -3359,7 +3411,7 @@ 
void mlir::python::populateIRCore(py::module &m) { operation.getOperation().setAttached( self.getParentOperation().getObject()); }, - py::arg("operation"), + nb::arg("operation"), "Appends an operation to this block. If the operation is currently " "in another block, it will be moved."); @@ -3367,39 +3419,41 @@ void mlir::python::populateIRCore(py::module &m) { // Mapping of PyInsertionPoint. //---------------------------------------------------------------------------- - py::class_(m, "InsertionPoint", py::module_local()) - .def(py::init(), py::arg("block"), + nb::class_(m, "InsertionPoint") + .def(nb::init(), nb::arg("block"), "Inserts after the last operation but still inside the block.") .def("__enter__", &PyInsertionPoint::contextEnter) - .def("__exit__", &PyInsertionPoint::contextExit) - .def_property_readonly_static( + .def("__exit__", &PyInsertionPoint::contextExit, + nb::arg("exc_type").none(), nb::arg("exc_value").none(), + nb::arg("traceback").none()) + .def_prop_ro_static( "current", - [](py::object & /*class*/) { + [](nb::object & /*class*/) { auto *ip = PyThreadContextEntry::getDefaultInsertionPoint(); if (!ip) - throw py::value_error("No current InsertionPoint"); + throw nb::value_error("No current InsertionPoint"); return ip; }, "Gets the InsertionPoint bound to the current thread or raises " "ValueError if none has been set") - .def(py::init(), py::arg("beforeOperation"), + .def(nb::init(), nb::arg("beforeOperation"), "Inserts before a referenced operation.") .def_static("at_block_begin", &PyInsertionPoint::atBlockBegin, - py::arg("block"), "Inserts at the beginning of the block.") + nb::arg("block"), "Inserts at the beginning of the block.") .def_static("at_block_terminator", &PyInsertionPoint::atBlockTerminator, - py::arg("block"), "Inserts before the block terminator.") - .def("insert", &PyInsertionPoint::insert, py::arg("operation"), + nb::arg("block"), "Inserts before the block terminator.") + .def("insert", &PyInsertionPoint::insert, 
nb::arg("operation"), "Inserts an operation.") - .def_property_readonly( + .def_prop_ro( "block", [](PyInsertionPoint &self) { return self.getBlock(); }, "Returns the block that this InsertionPoint points to.") - .def_property_readonly( + .def_prop_ro( "ref_operation", - [](PyInsertionPoint &self) -> py::object { + [](PyInsertionPoint &self) -> nb::object { auto refOperation = self.getRefOperation(); if (refOperation) return refOperation->getObject(); - return py::none(); + return nb::none(); }, "The reference operation before which new operations are " "inserted, or None if the insertion point is at the end of " @@ -3408,13 +3462,12 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of PyAttribute. //---------------------------------------------------------------------------- - py::class_(m, "Attribute", py::module_local()) + nb::class_(m, "Attribute") // Delegate to the PyAttribute copy constructor, which will also lifetime // extend the backing context which owns the MlirAttribute. - .def(py::init(), py::arg("cast_from_type"), + .def(nb::init(), nb::arg("cast_from_type"), "Casts the passed attribute to the generic Attribute") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyAttribute::getCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyAttribute::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyAttribute::createFromCapsule) .def_static( "parse", @@ -3426,24 +3479,24 @@ void mlir::python::populateIRCore(py::module &m) { throw MLIRError("Unable to parse attribute", errors.take()); return attr; }, - py::arg("asm"), py::arg("context") = py::none(), + nb::arg("asm"), nb::arg("context").none() = nb::none(), "Parses an attribute from an assembly form. 
Raises an MLIRError on " "failure.") - .def_property_readonly( + .def_prop_ro( "context", [](PyAttribute &self) { return self.getContext().getObject(); }, "Context that owns the Attribute") - .def_property_readonly( - "type", [](PyAttribute &self) { return mlirAttributeGetType(self); }) + .def_prop_ro("type", + [](PyAttribute &self) { return mlirAttributeGetType(self); }) .def( "get_named", [](PyAttribute &self, std::string name) { return PyNamedAttribute(self, std::move(name)); }, - py::keep_alive<0, 1>(), "Binds a name to the attribute") + nb::keep_alive<0, 1>(), "Binds a name to the attribute") .def("__eq__", [](PyAttribute &self, PyAttribute &other) { return self == other; }) - .def("__eq__", [](PyAttribute &self, py::object &other) { return false; }) + .def("__eq__", [](PyAttribute &self, nb::object &other) { return false; }) .def("__hash__", [](PyAttribute &self) { return static_cast(llvm::hash_value(self.get().ptr)); @@ -3474,36 +3527,35 @@ void mlir::python::populateIRCore(py::module &m) { printAccum.parts.append(")"); return printAccum.join(); }) - .def_property_readonly( - "typeid", - [](PyAttribute &self) -> MlirTypeID { - MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); - assert(!mlirTypeIDIsNull(mlirTypeID) && - "mlirTypeID was expected to be non-null."); - return mlirTypeID; - }) + .def_prop_ro("typeid", + [](PyAttribute &self) -> MlirTypeID { + MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + return mlirTypeID; + }) .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyAttribute &self) { MlirTypeID mlirTypeID = mlirAttributeGetTypeID(self); assert(!mlirTypeIDIsNull(mlirTypeID) && "mlirTypeID was expected to be non-null."); - std::optional typeCaster = + std::optional typeCaster = PyGlobals::get().lookupTypeCaster(mlirTypeID, mlirAttributeGetDialect(self)); if (!typeCaster) - return py::cast(self); + return nb::cast(self); return typeCaster.value()(self); }); 
//---------------------------------------------------------------------------- // Mapping of PyNamedAttribute //---------------------------------------------------------------------------- - py::class_(m, "NamedAttribute", py::module_local()) + nb::class_(m, "NamedAttribute") .def("__repr__", [](PyNamedAttribute &self) { PyPrintAccumulator printAccum; printAccum.parts.append("NamedAttribute("); printAccum.parts.append( - py::str(mlirIdentifierStr(self.namedAttr.name).data, + nb::str(mlirIdentifierStr(self.namedAttr.name).data, mlirIdentifierStr(self.namedAttr.name).length)); printAccum.parts.append("="); mlirAttributePrint(self.namedAttr.attribute, @@ -3512,28 +3564,28 @@ void mlir::python::populateIRCore(py::module &m) { printAccum.parts.append(")"); return printAccum.join(); }) - .def_property_readonly( + .def_prop_ro( "name", [](PyNamedAttribute &self) { - return py::str(mlirIdentifierStr(self.namedAttr.name).data, + return nb::str(mlirIdentifierStr(self.namedAttr.name).data, mlirIdentifierStr(self.namedAttr.name).length); }, "The name of the NamedAttribute binding") - .def_property_readonly( + .def_prop_ro( "attr", [](PyNamedAttribute &self) { return self.namedAttr.attribute; }, - py::keep_alive<0, 1>(), + nb::keep_alive<0, 1>(), "The underlying generic attribute of the NamedAttribute binding"); //---------------------------------------------------------------------------- // Mapping of PyType. //---------------------------------------------------------------------------- - py::class_(m, "Type", py::module_local()) + nb::class_(m, "Type") // Delegate to the PyType copy constructor, which will also lifetime // extend the backing context which owns the MlirType. 
- .def(py::init(), py::arg("cast_from_type"), + .def(nb::init(), nb::arg("cast_from_type"), "Casts the passed type to the generic Type") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyType::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyType::createFromCapsule) .def_static( "parse", @@ -3545,13 +3597,15 @@ void mlir::python::populateIRCore(py::module &m) { throw MLIRError("Unable to parse type", errors.take()); return type; }, - py::arg("asm"), py::arg("context") = py::none(), + nb::arg("asm"), nb::arg("context").none() = nb::none(), kContextParseTypeDocstring) - .def_property_readonly( + .def_prop_ro( "context", [](PyType &self) { return self.getContext().getObject(); }, "Context that owns the Type") .def("__eq__", [](PyType &self, PyType &other) { return self == other; }) - .def("__eq__", [](PyType &self, py::object &other) { return false; }) + .def( + "__eq__", [](PyType &self, nb::object &other) { return false; }, + nb::arg("other").none()) .def("__hash__", [](PyType &self) { return static_cast(llvm::hash_value(self.get().ptr)); @@ -3585,28 +3639,27 @@ void mlir::python::populateIRCore(py::module &m) { MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); assert(!mlirTypeIDIsNull(mlirTypeID) && "mlirTypeID was expected to be non-null."); - std::optional typeCaster = + std::optional typeCaster = PyGlobals::get().lookupTypeCaster(mlirTypeID, mlirTypeGetDialect(self)); if (!typeCaster) - return py::cast(self); + return nb::cast(self); return typeCaster.value()(self); }) - .def_property_readonly("typeid", [](PyType &self) -> MlirTypeID { + .def_prop_ro("typeid", [](PyType &self) -> MlirTypeID { MlirTypeID mlirTypeID = mlirTypeGetTypeID(self); if (!mlirTypeIDIsNull(mlirTypeID)) return mlirTypeID; - auto origRepr = - pybind11::repr(pybind11::cast(self)).cast(); - throw py::value_error( - (origRepr + llvm::Twine(" has no typeid.")).str()); + auto origRepr = nb::cast(nb::repr(nb::cast(self))); + 
throw nb::value_error( + (origRepr + llvm::Twine(" has no typeid.")).str().c_str()); }); //---------------------------------------------------------------------------- // Mapping of PyTypeID. //---------------------------------------------------------------------------- - py::class_(m, "TypeID", py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule) + nb::class_(m, "TypeID") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyTypeID::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyTypeID::createFromCapsule) // Note, this tests whether the underlying TypeIDs are the same, // not whether the wrapper MlirTypeIDs are the same, nor whether @@ -3614,7 +3667,7 @@ void mlir::python::populateIRCore(py::module &m) { .def("__eq__", [](PyTypeID &self, PyTypeID &other) { return self == other; }) .def("__eq__", - [](PyTypeID &self, const py::object &other) { return false; }) + [](PyTypeID &self, const nb::object &other) { return false; }) // Note, this gives the hash value of the underlying TypeID, not the // hash value of the Python object, nor the hash value of the // MlirTypeID wrapper. @@ -3625,20 +3678,20 @@ void mlir::python::populateIRCore(py::module &m) { //---------------------------------------------------------------------------- // Mapping of Value. 
//---------------------------------------------------------------------------- - py::class_(m, "Value", py::module_local()) - .def(py::init(), py::keep_alive<0, 1>(), py::arg("value")) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) + nb::class_(m, "Value") + .def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value")) + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyValue::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyValue::createFromCapsule) - .def_property_readonly( + .def_prop_ro( "context", [](PyValue &self) { return self.getParentOperation()->getContext(); }, "Context in which the value lives.") .def( "dump", [](PyValue &self) { mlirValueDump(self.get()); }, kDumpDocstring) - .def_property_readonly( + .def_prop_ro( "owner", - [](PyValue &self) -> py::object { + [](PyValue &self) -> nb::object { MlirValue v = self.get(); if (mlirValueIsAOpResult(v)) { assert( @@ -3651,22 +3704,22 @@ void mlir::python::populateIRCore(py::module &m) { if (mlirValueIsABlockArgument(v)) { MlirBlock block = mlirBlockArgumentGetOwner(self.get()); - return py::cast(PyBlock(self.getParentOperation(), block)); + return nb::cast(PyBlock(self.getParentOperation(), block)); } assert(false && "Value must be a block argument or an op result"); - return py::none(); + return nb::none(); }) - .def_property_readonly("uses", - [](PyValue &self) { - return PyOpOperandIterator( - mlirValueGetFirstUse(self.get())); - }) + .def_prop_ro("uses", + [](PyValue &self) { + return PyOpOperandIterator( + mlirValueGetFirstUse(self.get())); + }) .def("__eq__", [](PyValue &self, PyValue &other) { return self.get().ptr == other.get().ptr; }) - .def("__eq__", [](PyValue &self, py::object other) { return false; }) + .def("__eq__", [](PyValue &self, nb::object other) { return false; }) .def("__hash__", [](PyValue &self) { return static_cast(llvm::hash_value(self.get().ptr)); @@ -3698,26 +3751,26 @@ void mlir::python::populateIRCore(py::module &m) { mlirAsmStateDestroy(valueState); return 
printAccum.join(); }, - py::arg("use_local_scope") = false) + nb::arg("use_local_scope") = false) .def( "get_name", - [](PyValue &self, std::reference_wrapper state) { + [](PyValue &self, PyAsmState &state) { PyPrintAccumulator printAccum; - MlirAsmState valueState = state.get().get(); + MlirAsmState valueState = state.get(); mlirValuePrintAsOperand(self.get(), valueState, printAccum.getCallback(), printAccum.getUserData()); return printAccum.join(); }, - py::arg("state"), kGetNameAsOperand) - .def_property_readonly( - "type", [](PyValue &self) { return mlirValueGetType(self.get()); }) + nb::arg("state"), kGetNameAsOperand) + .def_prop_ro("type", + [](PyValue &self) { return mlirValueGetType(self.get()); }) .def( "set_type", [](PyValue &self, const PyType &type) { return mlirValueSetType(self.get(), type); }, - py::arg("type")) + nb::arg("type")) .def( "replace_all_uses_with", [](PyValue &self, PyValue &with) { @@ -3730,22 +3783,22 @@ void mlir::python::populateIRCore(py::module &m) { MlirOperation exceptedUser = exception.get(); mlirValueReplaceAllUsesExcept(self, with, 1, &exceptedUser); }, - py::arg("with"), py::arg("exceptions"), + nb::arg("with"), nb::arg("exceptions"), kValueReplaceAllUsesExceptDocstring) .def( "replace_all_uses_except", - [](MlirValue self, MlirValue with, py::list exceptions) { + [](MlirValue self, MlirValue with, nb::list exceptions) { // Convert Python list to a SmallVector of MlirOperations llvm::SmallVector exceptionOps; - for (py::handle exception : exceptions) { - exceptionOps.push_back(exception.cast().get()); + for (nb::handle exception : exceptions) { + exceptionOps.push_back(nb::cast(exception).get()); } mlirValueReplaceAllUsesExcept( self, with, static_cast(exceptionOps.size()), exceptionOps.data()); }, - py::arg("with"), py::arg("exceptions"), + nb::arg("with"), nb::arg("exceptions"), kValueReplaceAllUsesExceptDocstring) .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, [](PyValue &self) { return self.maybeDownCast(); }); @@ -3753,20 
+3806,20 @@ void mlir::python::populateIRCore(py::module &m) { PyOpResult::bind(m); PyOpOperand::bind(m); - py::class_(m, "AsmState", py::module_local()) - .def(py::init(), py::arg("value"), - py::arg("use_local_scope") = false) - .def(py::init(), py::arg("op"), - py::arg("use_local_scope") = false); + nb::class_(m, "AsmState") + .def(nb::init(), nb::arg("value"), + nb::arg("use_local_scope") = false) + .def(nb::init(), nb::arg("op"), + nb::arg("use_local_scope") = false); //---------------------------------------------------------------------------- // Mapping of SymbolTable. //---------------------------------------------------------------------------- - py::class_(m, "SymbolTable", py::module_local()) - .def(py::init()) + nb::class_(m, "SymbolTable") + .def(nb::init()) .def("__getitem__", &PySymbolTable::dunderGetItem) - .def("insert", &PySymbolTable::insert, py::arg("operation")) - .def("erase", &PySymbolTable::erase, py::arg("operation")) + .def("insert", &PySymbolTable::insert, nb::arg("operation")) + .def("erase", &PySymbolTable::erase, nb::arg("operation")) .def("__delitem__", &PySymbolTable::dunderDel) .def("__contains__", [](PySymbolTable &table, const std::string &name) { @@ -3775,19 +3828,19 @@ void mlir::python::populateIRCore(py::module &m) { }) // Static helpers. 
.def_static("set_symbol_name", &PySymbolTable::setSymbolName, - py::arg("symbol"), py::arg("name")) + nb::arg("symbol"), nb::arg("name")) .def_static("get_symbol_name", &PySymbolTable::getSymbolName, - py::arg("symbol")) + nb::arg("symbol")) .def_static("get_visibility", &PySymbolTable::getVisibility, - py::arg("symbol")) + nb::arg("symbol")) .def_static("set_visibility", &PySymbolTable::setVisibility, - py::arg("symbol"), py::arg("visibility")) + nb::arg("symbol"), nb::arg("visibility")) .def_static("replace_all_symbol_uses", - &PySymbolTable::replaceAllSymbolUses, py::arg("old_symbol"), - py::arg("new_symbol"), py::arg("from_op")) + &PySymbolTable::replaceAllSymbolUses, nb::arg("old_symbol"), + nb::arg("new_symbol"), nb::arg("from_op")) .def_static("walk_symbol_tables", &PySymbolTable::walkSymbolTables, - py::arg("from_op"), py::arg("all_sym_uses_visible"), - py::arg("callback")); + nb::arg("from_op"), nb::arg("all_sym_uses_visible"), + nb::arg("callback")); // Container bindings. PyBlockArgumentList::bind(m); @@ -3809,14 +3862,15 @@ void mlir::python::populateIRCore(py::module &m) { // Attribute builder getter. PyAttrBuilderMap::bind(m); - py::register_local_exception_translator([](std::exception_ptr p) { + nb::register_exception_translator([](const std::exception_ptr &p, + void *payload) { // We can't define exceptions with custom fields through pybind, so instead // the exception class is defined in python and imported here. 
try { if (p) std::rethrow_exception(p); } catch (const MLIRError &e) { - py::object obj = py::module_::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + nb::object obj = nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr("MLIRError")(e.message, e.errorDiagnostics); PyErr_SetObject(PyExc_Exception, obj.ptr()); } diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index 54cfa5606..c339a93e3 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -6,12 +6,12 @@ // //===----------------------------------------------------------------------===// +#include +#include +#include + #include #include -#include -#include -#include -#include #include #include #include @@ -24,7 +24,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -namespace py = pybind11; +namespace nb = nanobind; namespace mlir { namespace python { @@ -53,10 +53,10 @@ namespace { /// Takes in an optional ist of operands and converts them into a SmallVector /// of MlirVlaues. Returns an empty SmallVector if the list is empty. -llvm::SmallVector wrapOperands(std::optional operandList) { +llvm::SmallVector wrapOperands(std::optional operandList) { llvm::SmallVector mlirOperands; - if (!operandList || operandList->empty()) { + if (!operandList || operandList->size() == 0) { return mlirOperands; } @@ -68,40 +68,42 @@ llvm::SmallVector wrapOperands(std::optional operandList) { PyValue *val; try { - val = py::cast(it.value()); + val = nb::cast(it.value()); if (!val) - throw py::cast_error(); + throw nb::cast_error(); mlirOperands.push_back(val->get()); continue; - } catch (py::cast_error &err) { + } catch (nb::cast_error &err) { // Intentionally unhandled to try sequence below first. 
(void)err; } try { - auto vals = py::cast(it.value()); - for (py::object v : vals) { + auto vals = nb::cast(it.value()); + for (nb::handle v : vals) { try { - val = py::cast(v); + val = nb::cast(v); if (!val) - throw py::cast_error(); + throw nb::cast_error(); mlirOperands.push_back(val->get()); - } catch (py::cast_error &err) { - throw py::value_error( + } catch (nb::cast_error &err) { + throw nb::value_error( (llvm::Twine("Operand ") + llvm::Twine(it.index()) + " must be a Value or Sequence of Values (" + err.what() + ")") - .str()); + .str() + .c_str()); } } continue; - } catch (py::cast_error &err) { - throw py::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + + } catch (nb::cast_error &err) { + throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + " must be a Value or Sequence of Values (" + err.what() + ")") - .str()); + .str() + .c_str()); } - throw py::cast_error(); + throw nb::cast_error(); } return mlirOperands; @@ -144,24 +146,24 @@ wrapRegions(std::optional> regions) { template class PyConcreteOpInterface { protected: - using ClassTy = py::class_; + using ClassTy = nb::class_; using GetTypeIDFunctionTy = MlirTypeID (*)(); public: /// Constructs an interface instance from an object that is either an /// operation or a subclass of OpView. In the latter case, only the static /// methods of the interface are accessible to the caller. - PyConcreteOpInterface(py::object object, DefaultingPyMlirContext context) + PyConcreteOpInterface(nb::object object, DefaultingPyMlirContext context) : obj(std::move(object)) { try { - operation = &py::cast(obj); - } catch (py::cast_error &) { + operation = &nb::cast(obj); + } catch (nb::cast_error &) { // Do nothing. } try { - operation = &py::cast(obj).getOperation(); - } catch (py::cast_error &) { + operation = &nb::cast(obj).getOperation(); + } catch (nb::cast_error &) { // Do nothing. 
} @@ -169,7 +171,7 @@ class PyConcreteOpInterface { if (!mlirOperationImplementsInterface(*operation, ConcreteIface::getInterfaceID())) { std::string msg = "the operation does not implement "; - throw py::value_error(msg + ConcreteIface::pyClassName); + throw nb::value_error((msg + ConcreteIface::pyClassName).c_str()); } MlirIdentifier identifier = mlirOperationGetName(*operation); @@ -177,9 +179,9 @@ class PyConcreteOpInterface { opName = std::string(stringRef.data, stringRef.length); } else { try { - opName = obj.attr("OPERATION_NAME").template cast(); - } catch (py::cast_error &) { - throw py::type_error( + opName = nb::cast(obj.attr("OPERATION_NAME")); + } catch (nb::cast_error &) { + throw nb::type_error( "Op interface does not refer to an operation or OpView class"); } @@ -187,22 +189,19 @@ class PyConcreteOpInterface { mlirStringRefCreate(opName.data(), opName.length()), context.resolve().get(), ConcreteIface::getInterfaceID())) { std::string msg = "the operation does not implement "; - throw py::value_error(msg + ConcreteIface::pyClassName); + throw nb::value_error((msg + ConcreteIface::pyClassName).c_str()); } } } /// Creates the Python bindings for this class in the given module. 
- static void bind(py::module &m) { - py::class_ cls(m, ConcreteIface::pyClassName, - py::module_local()); - cls.def(py::init(), py::arg("object"), - py::arg("context") = py::none(), constructorDoc) - .def_property_readonly("operation", - &PyConcreteOpInterface::getOperationObject, - operationDoc) - .def_property_readonly("opview", &PyConcreteOpInterface::getOpView, - opviewDoc); + static void bind(nb::module_ &m) { + nb::class_ cls(m, ConcreteIface::pyClassName); + cls.def(nb::init(), nb::arg("object"), + nb::arg("context").none() = nb::none(), constructorDoc) + .def_prop_ro("operation", &PyConcreteOpInterface::getOperationObject, + operationDoc) + .def_prop_ro("opview", &PyConcreteOpInterface::getOpView, opviewDoc); ConcreteIface::bindDerived(cls); } @@ -216,9 +215,9 @@ class PyConcreteOpInterface { /// Returns the operation instance from which this object was constructed. /// Throws a type error if this object was constructed from a subclass of /// OpView. - py::object getOperationObject() { + nb::object getOperationObject() { if (operation == nullptr) { - throw py::type_error("Cannot get an operation from a static interface"); + throw nb::type_error("Cannot get an operation from a static interface"); } return operation->getRef().releaseObject(); @@ -227,9 +226,9 @@ class PyConcreteOpInterface { /// Returns the opview of the operation instance from which this object was /// constructed. Throws a type error if this object was constructed form a /// subclass of OpView. - py::object getOpView() { + nb::object getOpView() { if (operation == nullptr) { - throw py::type_error("Cannot get an opview from a static interface"); + throw nb::type_error("Cannot get an opview from a static interface"); } return operation->createOpView(); @@ -242,7 +241,7 @@ class PyConcreteOpInterface { private: PyOperation *operation = nullptr; std::string opName; - py::object obj; + nb::object obj; }; /// Python wrapper for InferTypeOpInterface. 
This interface has only static @@ -276,7 +275,7 @@ class PyInferTypeOpInterface /// Given the arguments required to build an operation, attempts to infer its /// return types. Throws value_error on failure. std::vector - inferReturnTypes(std::optional operandList, + inferReturnTypes(std::optional operandList, std::optional attributes, void *properties, std::optional> regions, DefaultingPyMlirContext context, @@ -299,7 +298,7 @@ class PyInferTypeOpInterface mlirRegions.data(), &appendResultsCallback, &data); if (mlirLogicalResultIsFailure(result)) { - throw py::value_error("Failed to infer result types"); + throw nb::value_error("Failed to infer result types"); } return inferredTypes; @@ -307,11 +306,12 @@ class PyInferTypeOpInterface static void bindDerived(ClassTy &cls) { cls.def("inferReturnTypes", &PyInferTypeOpInterface::inferReturnTypes, - py::arg("operands") = py::none(), - py::arg("attributes") = py::none(), - py::arg("properties") = py::none(), py::arg("regions") = py::none(), - py::arg("context") = py::none(), py::arg("loc") = py::none(), - inferReturnTypesDoc); + nb::arg("operands").none() = nb::none(), + nb::arg("attributes").none() = nb::none(), + nb::arg("properties").none() = nb::none(), + nb::arg("regions").none() = nb::none(), + nb::arg("context").none() = nb::none(), + nb::arg("loc").none() = nb::none(), inferReturnTypesDoc); } }; @@ -319,9 +319,9 @@ class PyInferTypeOpInterface class PyShapedTypeComponents { public: PyShapedTypeComponents(MlirType elementType) : elementType(elementType) {} - PyShapedTypeComponents(py::list shape, MlirType elementType) + PyShapedTypeComponents(nb::list shape, MlirType elementType) : shape(std::move(shape)), elementType(elementType), ranked(true) {} - PyShapedTypeComponents(py::list shape, MlirType elementType, + PyShapedTypeComponents(nb::list shape, MlirType elementType, MlirAttribute attribute) : shape(std::move(shape)), elementType(elementType), attribute(attribute), ranked(true) {} @@ -330,10 +330,9 @@ class 
PyShapedTypeComponents { : shape(other.shape), elementType(other.elementType), attribute(other.attribute), ranked(other.ranked) {} - static void bind(py::module &m) { - py::class_(m, "ShapedTypeComponents", - py::module_local()) - .def_property_readonly( + static void bind(nb::module_ &m) { + nb::class_(m, "ShapedTypeComponents") + .def_prop_ro( "element_type", [](PyShapedTypeComponents &self) { return self.elementType; }, "Returns the element type of the shaped type components.") @@ -342,57 +341,57 @@ class PyShapedTypeComponents { [](PyType &elementType) { return PyShapedTypeComponents(elementType); }, - py::arg("element_type"), + nb::arg("element_type"), "Create an shaped type components object with only the element " "type.") .def_static( "get", - [](py::list shape, PyType &elementType) { + [](nb::list shape, PyType &elementType) { return PyShapedTypeComponents(std::move(shape), elementType); }, - py::arg("shape"), py::arg("element_type"), + nb::arg("shape"), nb::arg("element_type"), "Create a ranked shaped type components object.") .def_static( "get", - [](py::list shape, PyType &elementType, PyAttribute &attribute) { + [](nb::list shape, PyType &elementType, PyAttribute &attribute) { return PyShapedTypeComponents(std::move(shape), elementType, attribute); }, - py::arg("shape"), py::arg("element_type"), py::arg("attribute"), + nb::arg("shape"), nb::arg("element_type"), nb::arg("attribute"), "Create a ranked shaped type components object with attribute.") - .def_property_readonly( + .def_prop_ro( "has_rank", [](PyShapedTypeComponents &self) -> bool { return self.ranked; }, "Returns whether the given shaped type component is ranked.") - .def_property_readonly( + .def_prop_ro( "rank", - [](PyShapedTypeComponents &self) -> py::object { + [](PyShapedTypeComponents &self) -> nb::object { if (!self.ranked) { - return py::none(); + return nb::none(); } - return py::int_(self.shape.size()); + return nb::int_(self.shape.size()); }, "Returns the rank of the given ranked 
shaped type components. If " "the shaped type components does not have a rank, None is " "returned.") - .def_property_readonly( + .def_prop_ro( "shape", - [](PyShapedTypeComponents &self) -> py::object { + [](PyShapedTypeComponents &self) -> nb::object { if (!self.ranked) { - return py::none(); + return nb::none(); } - return py::list(self.shape); + return nb::list(self.shape); }, "Returns the shape of the ranked shaped type components as a list " "of integers. Returns none if the shaped type component does not " "have a rank."); } - pybind11::object getCapsule(); - static PyShapedTypeComponents createFromCapsule(pybind11::object capsule); + nb::object getCapsule(); + static PyShapedTypeComponents createFromCapsule(nb::object capsule); private: - py::list shape; + nb::list shape; MlirType elementType; MlirAttribute attribute; bool ranked{false}; @@ -424,7 +423,7 @@ class PyInferShapedTypeOpInterface if (!hasRank) { data->inferredShapedTypeComponents.emplace_back(elementType); } else { - py::list shapeList; + nb::list shapeList; for (intptr_t i = 0; i < rank; ++i) { shapeList.append(shape[i]); } @@ -436,7 +435,7 @@ class PyInferShapedTypeOpInterface /// Given the arguments required to build an operation, attempts to infer the /// shaped type components. Throws value_error on failure. 
std::vector inferReturnTypeComponents( - std::optional operandList, + std::optional operandList, std::optional attributes, void *properties, std::optional> regions, DefaultingPyMlirContext context, DefaultingPyLocation location) { @@ -458,7 +457,7 @@ class PyInferShapedTypeOpInterface mlirRegions.data(), &appendResultsCallback, &data); if (mlirLogicalResultIsFailure(result)) { - throw py::value_error("Failed to infer result shape type components"); + throw nb::value_error("Failed to infer result shape type components"); } return inferredShapedTypeComponents; @@ -467,14 +466,16 @@ class PyInferShapedTypeOpInterface static void bindDerived(ClassTy &cls) { cls.def("inferReturnTypeComponents", &PyInferShapedTypeOpInterface::inferReturnTypeComponents, - py::arg("operands") = py::none(), - py::arg("attributes") = py::none(), py::arg("regions") = py::none(), - py::arg("properties") = py::none(), py::arg("context") = py::none(), - py::arg("loc") = py::none(), inferReturnTypeComponentsDoc); + nb::arg("operands").none() = nb::none(), + nb::arg("attributes").none() = nb::none(), + nb::arg("regions").none() = nb::none(), + nb::arg("properties").none() = nb::none(), + nb::arg("context").none() = nb::none(), + nb::arg("loc").none() = nb::none(), inferReturnTypeComponentsDoc); } }; -void populateIRInterfaces(py::module &m) { +void populateIRInterfaces(nb::module_ &m) { PyInferTypeOpInterface::bind(m); PyShapedTypeComponents::bind(m); PyInferShapedTypeOpInterface::bind(m); diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index 6727860c0..416a14218 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -7,16 +7,19 @@ //===----------------------------------------------------------------------===// #include "IRModule.h" -#include "Globals.h" -#include "PybindUtils.h" -#include "mlir-c/Bindings/Python/Interop.h" -#include "mlir-c/Support.h" +#include +#include #include #include -namespace py = 
pybind11; +#include "Globals.h" +#include "NanobindUtils.h" +#include "mlir-c/Bindings/Python/Interop.h" +#include "mlir-c/Support.h" + +namespace nb = nanobind; using namespace mlir; using namespace mlir::python; @@ -41,14 +44,14 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { return true; // Since re-entrancy is possible, make a copy of the search prefixes. std::vector localSearchPrefixes = dialectSearchPrefixes; - py::object loaded = py::none(); + nb::object loaded = nb::none(); for (std::string moduleName : localSearchPrefixes) { moduleName.push_back('.'); moduleName.append(dialectNamespace.data(), dialectNamespace.size()); try { - loaded = py::module::import(moduleName.c_str()); - } catch (py::error_already_set &e) { + loaded = nb::module_::import_(moduleName.c_str()); + } catch (nb::python_error &e) { if (e.matches(PyExc_ModuleNotFoundError)) { continue; } @@ -66,41 +69,39 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { } void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, - py::function pyFunc, bool replace) { - py::object &found = attributeBuilderMap[attributeKind]; + nb::callable pyFunc, bool replace) { + nb::object &found = attributeBuilderMap[attributeKind]; if (found && !replace) { throw std::runtime_error((llvm::Twine("Attribute builder for '") + attributeKind + "' is already registered with func: " + - py::str(found).operator std::string()) + nb::cast(nb::str(found))) .str()); } found = std::move(pyFunc); } void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, - pybind11::function typeCaster, - bool replace) { - pybind11::object &found = typeCasterMap[mlirTypeID]; + nb::callable typeCaster, bool replace) { + nb::object &found = typeCasterMap[mlirTypeID]; if (found && !replace) throw std::runtime_error("Type caster is already registered with caster: " + - py::str(found).operator std::string()); + nb::cast(nb::str(found))); found = std::move(typeCaster); } void 
PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, - pybind11::function valueCaster, - bool replace) { - pybind11::object &found = valueCasterMap[mlirTypeID]; + nb::callable valueCaster, bool replace) { + nb::object &found = valueCasterMap[mlirTypeID]; if (found && !replace) throw std::runtime_error("Value caster is already registered: " + - py::repr(found).cast()); + nb::cast(nb::repr(found))); found = std::move(valueCaster); } void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, - py::object pyClass) { - py::object &found = dialectClassMap[dialectNamespace]; + nb::object pyClass) { + nb::object &found = dialectClassMap[dialectNamespace]; if (found) { throw std::runtime_error((llvm::Twine("Dialect namespace '") + dialectNamespace + "' is already registered.") @@ -110,8 +111,8 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, } void PyGlobals::registerOperationImpl(const std::string &operationName, - py::object pyClass, bool replace) { - py::object &found = operationClassMap[operationName]; + nb::object pyClass, bool replace) { + nb::object &found = operationClassMap[operationName]; if (found && !replace) { throw std::runtime_error((llvm::Twine("Operation '") + operationName + "' is already registered.") @@ -120,7 +121,7 @@ void PyGlobals::registerOperationImpl(const std::string &operationName, found = std::move(pyClass); } -std::optional +std::optional PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { const auto foundIt = attributeBuilderMap.find(attributeKind); if (foundIt != attributeBuilderMap.end()) { @@ -130,7 +131,7 @@ PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { return std::nullopt; } -std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, +std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { // Try to load dialect module. 
(void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); @@ -142,7 +143,7 @@ std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, return std::nullopt; } -std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, +std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { // Try to load dialect module. (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); @@ -154,7 +155,7 @@ std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, return std::nullopt; } -std::optional +std::optional PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { // Make sure dialect module is loaded. if (!loadDialectModule(dialectNamespace)) @@ -168,7 +169,7 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { return std::nullopt; } -std::optional +std::optional PyGlobals::lookupOperationClass(llvm::StringRef operationName) { // Make sure dialect module is loaded. auto split = operationName.split('.'); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 172898cfd..a242ff26b 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -10,20 +10,22 @@ #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H #define MLIR_BINDINGS_PYTHON_IRMODULES_H +#include +#include + #include #include #include #include "Globals.h" -#include "PybindUtils.h" - +#include "NanobindUtils.h" #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" #include "mlir-c/Transforms.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "llvm/ADT/DenseMap.h" namespace mlir { @@ -49,7 +51,7 @@ class PyValue; template class PyObjectRef { public: - PyObjectRef(T *referrent, pybind11::object object) + PyObjectRef(T *referrent, nanobind::object object) : referrent(referrent), object(std::move(object)) { 
assert(this->referrent && "cannot construct PyObjectRef with null referrent"); @@ -67,13 +69,13 @@ class PyObjectRef { int getRefCount() { if (!object) return 0; - return object.ref_count(); + return Py_REFCNT(object.ptr()); } /// Releases the object held by this instance, returning it. /// This is the proper thing to return from a function that wants to return /// the reference. Note that this does not work from initializers. - pybind11::object releaseObject() { + nanobind::object releaseObject() { assert(referrent && object); referrent = nullptr; auto stolen = std::move(object); @@ -85,7 +87,7 @@ class PyObjectRef { assert(referrent && object); return referrent; } - pybind11::object getObject() { + nanobind::object getObject() { assert(referrent && object); return object; } @@ -93,7 +95,7 @@ class PyObjectRef { private: T *referrent; - pybind11::object object; + nanobind::object object; }; /// Tracks an entry in the thread context stack. New entries are pushed onto @@ -112,9 +114,9 @@ class PyThreadContextEntry { Location, }; - PyThreadContextEntry(FrameKind frameKind, pybind11::object context, - pybind11::object insertionPoint, - pybind11::object location) + PyThreadContextEntry(FrameKind frameKind, nanobind::object context, + nanobind::object insertionPoint, + nanobind::object location) : context(std::move(context)), insertionPoint(std::move(insertionPoint)), location(std::move(location)), frameKind(frameKind) {} @@ -134,26 +136,26 @@ class PyThreadContextEntry { /// Stack management. 
static PyThreadContextEntry *getTopOfStack(); - static pybind11::object pushContext(PyMlirContext &context); + static nanobind::object pushContext(nanobind::object context); static void popContext(PyMlirContext &context); - static pybind11::object pushInsertionPoint(PyInsertionPoint &insertionPoint); + static nanobind::object pushInsertionPoint(nanobind::object insertionPoint); static void popInsertionPoint(PyInsertionPoint &insertionPoint); - static pybind11::object pushLocation(PyLocation &location); + static nanobind::object pushLocation(nanobind::object location); static void popLocation(PyLocation &location); /// Gets the thread local stack. static std::vector &getStack(); private: - static void push(FrameKind frameKind, pybind11::object context, - pybind11::object insertionPoint, pybind11::object location); + static void push(FrameKind frameKind, nanobind::object context, + nanobind::object insertionPoint, nanobind::object location); /// An object reference to the PyContext. - pybind11::object context; + nanobind::object context; /// An object reference to the current insertion point. - pybind11::object insertionPoint; + nanobind::object insertionPoint; /// An object reference to the current location. - pybind11::object location; + nanobind::object location; // The kind of push that was performed. FrameKind frameKind; }; @@ -163,14 +165,15 @@ using PyMlirContextRef = PyObjectRef; class PyMlirContext { public: PyMlirContext() = delete; + PyMlirContext(MlirContext context); PyMlirContext(const PyMlirContext &) = delete; PyMlirContext(PyMlirContext &&) = delete; - /// For the case of a python __init__ (py::init) method, pybind11 is quite - /// strict about needing to return a pointer that is not yet associated to - /// an py::object. Since the forContext() method acts like a pool, possibly - /// returning a recycled context, it does not satisfy this need. 
The usual - /// way in python to accomplish such a thing is to override __new__, but + /// For the case of a python __init__ (nanobind::init) method, pybind11 is + /// quite strict about needing to return a pointer that is not yet associated + /// to an nanobind::object. Since the forContext() method acts like a pool, + /// possibly returning a recycled context, it does not satisfy this need. The + /// usual way in python to accomplish such a thing is to override __new__, but /// that is also not supported by pybind11. Instead, we use this entry /// point which always constructs a fresh context (which cannot alias an /// existing one because it is fresh). @@ -187,17 +190,17 @@ class PyMlirContext { /// Gets a strong reference to this context, which will ensure it is kept /// alive for the life of the reference. PyMlirContextRef getRef() { - return PyMlirContextRef(this, pybind11::cast(this)); + return PyMlirContextRef(this, nanobind::cast(this)); } /// Gets a capsule wrapping the void* within the MlirContext. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyMlirContext from the MlirContext wrapped by a capsule. /// Note that PyMlirContext instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirContext /// is taken by calling this function. - static pybind11::object createFromCapsule(pybind11::object capsule); + static nanobind::object createFromCapsule(nanobind::object capsule); /// Gets the count of live context objects. Used for testing. static size_t getLiveCount(); @@ -237,14 +240,14 @@ class PyMlirContext { size_t getLiveModuleCount(); /// Enter and exit the context manager. 
- pybind11::object contextEnter(); - void contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb); + static nanobind::object contextEnter(nanobind::object context); + void contextExit(const nanobind::object &excType, + const nanobind::object &excVal, + const nanobind::object &excTb); /// Attaches a Python callback as a diagnostic handler, returning a /// registration object (internally a PyDiagnosticHandler). - pybind11::object attachDiagnosticHandler(pybind11::object callback); + nanobind::object attachDiagnosticHandler(nanobind::object callback); /// Controls whether error diagnostics should be propagated to diagnostic /// handlers, instead of being captured by `ErrorCapture`. @@ -252,8 +255,6 @@ class PyMlirContext { struct ErrorCapture; private: - PyMlirContext(MlirContext context); - // Interns the mapping of live MlirContext::ptr to PyMlirContext instances, // preserving the relationship that an MlirContext maps to a single // PyMlirContext wrapper. This could be replaced in the future with an @@ -268,7 +269,7 @@ class PyMlirContext { // from this map, and while it still exists as an instance, any // attempt to access it will raise an error. using LiveModuleMap = - llvm::DenseMap>; + llvm::DenseMap>; LiveModuleMap liveModules; // Interns all live operations associated with this context. Operations @@ -276,7 +277,7 @@ class PyMlirContext { // removed from this map, and while it still exists as an instance, any // attempt to access it will raise an error. using LiveOperationMap = - llvm::DenseMap>; + llvm::DenseMap>; LiveOperationMap liveOperations; bool emitErrorDiagnostics = false; @@ -324,19 +325,19 @@ class PyLocation : public BaseContextObject { MlirLocation get() const { return loc; } /// Enter and exit the context manager. 
- pybind11::object contextEnter(); - void contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb); + static nanobind::object contextEnter(nanobind::object location); + void contextExit(const nanobind::object &excType, + const nanobind::object &excVal, + const nanobind::object &excTb); /// Gets a capsule wrapping the void* within the MlirLocation. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyLocation from the MlirLocation wrapped by a capsule. /// Note that PyLocation instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirLocation /// is taken by calling this function. - static PyLocation createFromCapsule(pybind11::object capsule); + static PyLocation createFromCapsule(nanobind::object capsule); private: MlirLocation loc; @@ -353,8 +354,8 @@ class PyDiagnostic { bool isValid() { return valid; } MlirDiagnosticSeverity getSeverity(); PyLocation getLocation(); - pybind11::str getMessage(); - pybind11::tuple getNotes(); + nanobind::str getMessage(); + nanobind::tuple getNotes(); /// Materialized diagnostic information. This is safe to access outside the /// diagnostic callback. @@ -373,7 +374,7 @@ class PyDiagnostic { /// If notes have been materialized from the diagnostic, then this will /// be populated with the corresponding objects (all castable to /// PyDiagnostic). - std::optional materializedNotes; + std::optional materializedNotes; bool valid = true; }; @@ -398,7 +399,7 @@ class PyDiagnostic { /// is no way to attach an existing handler object). class PyDiagnosticHandler { public: - PyDiagnosticHandler(MlirContext context, pybind11::object callback); + PyDiagnosticHandler(MlirContext context, nanobind::object callback); ~PyDiagnosticHandler(); bool isAttached() { return registeredID.has_value(); } @@ -407,16 +408,16 @@ class PyDiagnosticHandler { /// Detaches the handler. Does nothing if not attached. 
void detach(); - pybind11::object contextEnter() { return pybind11::cast(this); } - void contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb) { + nanobind::object contextEnter() { return nanobind::cast(this); } + void contextExit(const nanobind::object &excType, + const nanobind::object &excVal, + const nanobind::object &excTb) { detach(); } private: MlirContext context; - pybind11::object callback; + nanobind::object callback; std::optional registeredID; bool hadError = false; friend class PyMlirContext; @@ -477,12 +478,12 @@ class PyDialects : public BaseContextObject { /// objects of this type will be returned directly. class PyDialect { public: - PyDialect(pybind11::object descriptor) : descriptor(std::move(descriptor)) {} + PyDialect(nanobind::object descriptor) : descriptor(std::move(descriptor)) {} - pybind11::object getDescriptor() { return descriptor; } + nanobind::object getDescriptor() { return descriptor; } private: - pybind11::object descriptor; + nanobind::object descriptor; }; /// Wrapper around an MlirDialectRegistry. @@ -505,8 +506,8 @@ class PyDialectRegistry { operator MlirDialectRegistry() const { return registry; } MlirDialectRegistry get() const { return registry; } - pybind11::object getCapsule(); - static PyDialectRegistry createFromCapsule(pybind11::object capsule); + nanobind::object getCapsule(); + static PyDialectRegistry createFromCapsule(nanobind::object capsule); private: MlirDialectRegistry registry; @@ -542,26 +543,25 @@ class PyModule : public BaseContextObject { /// Gets a strong reference to this module. PyModuleRef getRef() { - return PyModuleRef(this, - pybind11::reinterpret_borrow(handle)); + return PyModuleRef(this, nanobind::borrow(handle)); } /// Gets a capsule wrapping the void* within the MlirModule. 
/// Note that the module does not (yet) provide a corresponding factory for /// constructing from a capsule as that would require uniquing PyModule /// instances, which is not currently done. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyModule from the MlirModule wrapped by a capsule. /// Note that PyModule instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirModule /// is taken by calling this function. - static pybind11::object createFromCapsule(pybind11::object capsule); + static nanobind::object createFromCapsule(nanobind::object capsule); private: PyModule(PyMlirContextRef contextRef, MlirModule module); MlirModule module; - pybind11::handle handle; + nanobind::handle handle; }; class PyAsmState; @@ -574,18 +574,18 @@ class PyOperationBase { /// Implements the bound 'print' method and helps with others. void print(std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified, py::object fileObject, bool binary, + bool assumeVerified, nanobind::object fileObject, bool binary, bool skipRegions); - void print(PyAsmState &state, py::object fileObject, bool binary); + void print(PyAsmState &state, nanobind::object fileObject, bool binary); - pybind11::object getAsm(bool binary, + nanobind::object getAsm(bool binary, std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool assumeVerified, bool skipRegions); // Implement the bound 'writeBytecode' method. - void writeBytecode(const pybind11::object &fileObject, + void writeBytecode(const nanobind::object &fileObject, std::optional bytecodeVersion); // Implement the walk method. @@ -621,13 +621,13 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// it with a parentKeepAlive. 
static PyOperationRef forOperation(PyMlirContextRef contextRef, MlirOperation operation, - pybind11::object parentKeepAlive = pybind11::object()); + nanobind::object parentKeepAlive = nanobind::object()); /// Creates a detached operation. The operation must not be associated with /// any existing live operation. static PyOperationRef createDetached(PyMlirContextRef contextRef, MlirOperation operation, - pybind11::object parentKeepAlive = pybind11::object()); + nanobind::object parentKeepAlive = nanobind::object()); /// Parses a source string (either text assembly or bytecode), creating a /// detached operation. @@ -640,7 +640,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject { void detachFromParent() { mlirOperationRemoveFromParent(getOperation()); setDetached(); - parentKeepAlive = pybind11::object(); + parentKeepAlive = nanobind::object(); } /// Gets the backing operation. @@ -651,12 +651,11 @@ class PyOperation : public PyOperationBase, public BaseContextObject { } PyOperationRef getRef() { - return PyOperationRef( - this, pybind11::reinterpret_borrow(handle)); + return PyOperationRef(this, nanobind::borrow(handle)); } bool isAttached() { return attached; } - void setAttached(const pybind11::object &parent = pybind11::object()) { + void setAttached(const nanobind::object &parent = nanobind::object()) { assert(!attached && "operation already attached"); attached = true; } @@ -675,24 +674,24 @@ class PyOperation : public PyOperationBase, public BaseContextObject { std::optional getParentOperation(); /// Gets a capsule wrapping the void* within the MlirOperation. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyOperation from the MlirOperation wrapped by a capsule. /// Ownership of the underlying MlirOperation is taken by calling this /// function. 
- static pybind11::object createFromCapsule(pybind11::object capsule); + static nanobind::object createFromCapsule(nanobind::object capsule); /// Creates an operation. See corresponding python docstring. - static pybind11::object + static nanobind::object create(const std::string &name, std::optional> results, std::optional> operands, - std::optional attributes, + std::optional attributes, std::optional> successors, int regions, - DefaultingPyLocation location, const pybind11::object &ip, + DefaultingPyLocation location, const nanobind::object &ip, bool inferType); /// Creates an OpView suitable for this operation. - pybind11::object createOpView(); + nanobind::object createOpView(); /// Erases the underlying MlirOperation, removes its pointer from the /// parent context's live operations map, and sets the valid bit false. @@ -702,23 +701,23 @@ class PyOperation : public PyOperationBase, public BaseContextObject { void setInvalid() { valid = false; } /// Clones this operation. - pybind11::object clone(const pybind11::object &ip); + nanobind::object clone(const nanobind::object &ip); private: PyOperation(PyMlirContextRef contextRef, MlirOperation operation); static PyOperationRef createInstance(PyMlirContextRef contextRef, MlirOperation operation, - pybind11::object parentKeepAlive); + nanobind::object parentKeepAlive); MlirOperation operation; - pybind11::handle handle; + nanobind::handle handle; // Keeps the parent alive, regardless of whether it is an Operation or // Module. // TODO: As implemented, this facility is only sufficient for modeling the // trivial module parent back-reference. Generalize this to also account for // transitions from detached to attached and address TODOs in the // ir_operation.py regarding testing corresponding lifetime guarantees. 
- pybind11::object parentKeepAlive; + nanobind::object parentKeepAlive; bool attached = true; bool valid = true; @@ -733,17 +732,17 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// python types. class PyOpView : public PyOperationBase { public: - PyOpView(const pybind11::object &operationObject); + PyOpView(const nanobind::object &operationObject); PyOperation &getOperation() override { return operation; } - pybind11::object getOperationObject() { return operationObject; } + nanobind::object getOperationObject() { return operationObject; } - static pybind11::object buildGeneric( - const pybind11::object &cls, std::optional resultTypeList, - pybind11::list operandList, std::optional attributes, + static nanobind::object buildGeneric( + const nanobind::object &cls, std::optional resultTypeList, + nanobind::list operandList, std::optional attributes, std::optional> successors, std::optional regions, DefaultingPyLocation location, - const pybind11::object &maybeIp); + const nanobind::object &maybeIp); /// Construct an instance of a class deriving from OpView, bypassing its /// `__init__` method. The derived class will typically define a constructor @@ -752,12 +751,12 @@ class PyOpView : public PyOperationBase { /// /// The caller is responsible for verifying that `operation` is a valid /// operation to construct `cls` with. - static pybind11::object constructDerived(const pybind11::object &cls, - const PyOperation &operation); + static nanobind::object constructDerived(const nanobind::object &cls, + const nanobind::object &operation); private: PyOperation &operation; // For efficient, cast-free access from C++ - pybind11::object operationObject; // Holds the reference. + nanobind::object operationObject; // Holds the reference. }; /// Wrapper around an MlirRegion. @@ -830,7 +829,7 @@ class PyBlock { void checkValid() { return parentOperation->checkValid(); } /// Gets a capsule wrapping the void* within the MlirBlock. 
- pybind11::object getCapsule(); + nanobind::object getCapsule(); private: PyOperationRef parentOperation; @@ -858,10 +857,10 @@ class PyInsertionPoint { void insert(PyOperationBase &operationBase); /// Enter and exit the context manager. - pybind11::object contextEnter(); - void contextExit(const pybind11::object &excType, - const pybind11::object &excVal, - const pybind11::object &excTb); + static nanobind::object contextEnter(nanobind::object insertionPoint); + void contextExit(const nanobind::object &excType, + const nanobind::object &excVal, + const nanobind::object &excTb); PyBlock &getBlock() { return block; } std::optional &getRefOperation() { return refOperation; } @@ -886,13 +885,13 @@ class PyType : public BaseContextObject { MlirType get() const { return type; } /// Gets a capsule wrapping the void* within the MlirType. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyType from the MlirType wrapped by a capsule. /// Note that PyType instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirType /// is taken by calling this function. - static PyType createFromCapsule(pybind11::object capsule); + static PyType createFromCapsule(nanobind::object capsule); private: MlirType type; @@ -912,10 +911,10 @@ class PyTypeID { MlirTypeID get() { return typeID; } /// Gets a capsule wrapping the void* within the MlirTypeID. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyTypeID from the MlirTypeID wrapped by a capsule. 
- static PyTypeID createFromCapsule(pybind11::object capsule); + static PyTypeID createFromCapsule(nanobind::object capsule); private: MlirTypeID typeID; @@ -932,7 +931,7 @@ class PyConcreteType : public BaseTy { // Derived classes must define statics for: // IsAFunctionTy isaFunction // const char *pyClassName - using ClassTy = pybind11::class_; + using ClassTy = nanobind::class_; using IsAFunctionTy = bool (*)(MlirType); using GetTypeIDFunctionTy = MlirTypeID (*)(); static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; @@ -945,34 +944,38 @@ class PyConcreteType : public BaseTy { static MlirType castFrom(PyType &orig) { if (!DerivedTy::isaFunction(orig)) { - auto origRepr = pybind11::repr(pybind11::cast(orig)).cast(); - throw py::value_error((llvm::Twine("Cannot cast type to ") + - DerivedTy::pyClassName + " (from " + origRepr + - ")") - .str()); + auto origRepr = + nanobind::cast(nanobind::repr(nanobind::cast(orig))); + throw nanobind::value_error((llvm::Twine("Cannot cast type to ") + + DerivedTy::pyClassName + " (from " + + origRepr + ")") + .str() + .c_str()); } return orig; } - static void bind(pybind11::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::module_local()); - cls.def(pybind11::init(), pybind11::keep_alive<0, 1>(), - pybind11::arg("cast_from_type")); + static void bind(nanobind::module_ &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName); + cls.def(nanobind::init(), nanobind::keep_alive<0, 1>(), + nanobind::arg("cast_from_type")); cls.def_static( "isinstance", [](PyType &otherType) -> bool { return DerivedTy::isaFunction(otherType); }, - pybind11::arg("other")); - cls.def_property_readonly_static( - "static_typeid", [](py::object & /*class*/) -> MlirTypeID { + nanobind::arg("other")); + cls.def_prop_ro_static( + "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID { if (DerivedTy::getTypeIdFunction) return DerivedTy::getTypeIdFunction(); - throw py::attribute_error( - (DerivedTy::pyClassName + 
llvm::Twine(" has no typeid.")).str()); + throw nanobind::attribute_error( + (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")) + .str() + .c_str()); }); - cls.def_property_readonly("typeid", [](PyType &self) { - return py::cast(self).attr("typeid").cast(); + cls.def_prop_ro("typeid", [](PyType &self) { + return nanobind::cast(nanobind::cast(self).attr("typeid")); }); cls.def("__repr__", [](DerivedTy &self) { PyPrintAccumulator printAccum; @@ -986,8 +989,8 @@ class PyConcreteType : public BaseTy { if (DerivedTy::getTypeIdFunction) { PyGlobals::get().registerTypeCaster( DerivedTy::getTypeIdFunction(), - pybind11::cpp_function( - [](PyType pyType) -> DerivedTy { return pyType; })); + nanobind::cast(nanobind::cpp_function( + [](PyType pyType) -> DerivedTy { return pyType; }))); } DerivedTy::bindDerived(cls); @@ -1008,13 +1011,13 @@ class PyAttribute : public BaseContextObject { MlirAttribute get() const { return attr; } /// Gets a capsule wrapping the void* within the MlirAttribute. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyAttribute from the MlirAttribute wrapped by a capsule. /// Note that PyAttribute instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirAttribute /// is taken by calling this function. 
- static PyAttribute createFromCapsule(pybind11::object capsule); + static PyAttribute createFromCapsule(nanobind::object capsule); private: MlirAttribute attr; @@ -1054,7 +1057,7 @@ class PyConcreteAttribute : public BaseTy { // Derived classes must define statics for: // IsAFunctionTy isaFunction // const char *pyClassName - using ClassTy = pybind11::class_; + using ClassTy = nanobind::class_; using IsAFunctionTy = bool (*)(MlirAttribute); using GetTypeIDFunctionTy = MlirTypeID (*)(); static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; @@ -1067,37 +1070,45 @@ class PyConcreteAttribute : public BaseTy { static MlirAttribute castFrom(PyAttribute &orig) { if (!DerivedTy::isaFunction(orig)) { - auto origRepr = pybind11::repr(pybind11::cast(orig)).cast(); - throw py::value_error((llvm::Twine("Cannot cast attribute to ") + - DerivedTy::pyClassName + " (from " + origRepr + - ")") - .str()); + auto origRepr = + nanobind::cast(nanobind::repr(nanobind::cast(orig))); + throw nanobind::value_error((llvm::Twine("Cannot cast attribute to ") + + DerivedTy::pyClassName + " (from " + + origRepr + ")") + .str() + .c_str()); } return orig; } - static void bind(pybind11::module &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName, pybind11::buffer_protocol(), - pybind11::module_local()); - cls.def(pybind11::init(), pybind11::keep_alive<0, 1>(), - pybind11::arg("cast_from_attr")); + static void bind(nanobind::module_ &m, PyType_Slot *slots = nullptr) { + ClassTy cls; + if (slots) { + cls = ClassTy(m, DerivedTy::pyClassName, nanobind::type_slots(slots)); + } else { + cls = ClassTy(m, DerivedTy::pyClassName); + } + cls.def(nanobind::init(), nanobind::keep_alive<0, 1>(), + nanobind::arg("cast_from_attr")); cls.def_static( "isinstance", [](PyAttribute &otherAttr) -> bool { return DerivedTy::isaFunction(otherAttr); }, - pybind11::arg("other")); - cls.def_property_readonly( + nanobind::arg("other")); + cls.def_prop_ro( "type", [](PyAttribute &attr) { return 
mlirAttributeGetType(attr); }); - cls.def_property_readonly_static( - "static_typeid", [](py::object & /*class*/) -> MlirTypeID { + cls.def_prop_ro_static( + "static_typeid", [](nanobind::object & /*class*/) -> MlirTypeID { if (DerivedTy::getTypeIdFunction) return DerivedTy::getTypeIdFunction(); - throw py::attribute_error( - (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")).str()); + throw nanobind::attribute_error( + (DerivedTy::pyClassName + llvm::Twine(" has no typeid.")) + .str() + .c_str()); }); - cls.def_property_readonly("typeid", [](PyAttribute &self) { - return py::cast(self).attr("typeid").cast(); + cls.def_prop_ro("typeid", [](PyAttribute &self) { + return nanobind::cast(nanobind::cast(self).attr("typeid")); }); cls.def("__repr__", [](DerivedTy &self) { PyPrintAccumulator printAccum; @@ -1112,9 +1123,10 @@ class PyConcreteAttribute : public BaseTy { if (DerivedTy::getTypeIdFunction) { PyGlobals::get().registerTypeCaster( DerivedTy::getTypeIdFunction(), - pybind11::cpp_function([](PyAttribute pyAttribute) -> DerivedTy { - return pyAttribute; - })); + nanobind::cast( + nanobind::cpp_function([](PyAttribute pyAttribute) -> DerivedTy { + return pyAttribute; + }))); } DerivedTy::bindDerived(cls); @@ -1146,13 +1158,13 @@ class PyValue { void checkValid() { return parentOperation->checkValid(); } /// Gets a capsule wrapping the void* within the MlirValue. - pybind11::object getCapsule(); + nanobind::object getCapsule(); - pybind11::object maybeDownCast(); + nanobind::object maybeDownCast(); /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of /// the underlying MlirValue is still tied to the owning operation. 
- static PyValue createFromCapsule(pybind11::object capsule); + static PyValue createFromCapsule(nanobind::object capsule); private: PyOperationRef parentOperation; @@ -1169,13 +1181,13 @@ class PyAffineExpr : public BaseContextObject { MlirAffineExpr get() const { return affineExpr; } /// Gets a capsule wrapping the void* within the MlirAffineExpr. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyAffineExpr from the MlirAffineExpr wrapped by a capsule. /// Note that PyAffineExpr instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirAffineExpr /// is taken by calling this function. - static PyAffineExpr createFromCapsule(pybind11::object capsule); + static PyAffineExpr createFromCapsule(nanobind::object capsule); PyAffineExpr add(const PyAffineExpr &other) const; PyAffineExpr mul(const PyAffineExpr &other) const; @@ -1196,13 +1208,13 @@ class PyAffineMap : public BaseContextObject { MlirAffineMap get() const { return affineMap; } /// Gets a capsule wrapping the void* within the MlirAffineMap. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyAffineMap from the MlirAffineMap wrapped by a capsule. /// Note that PyAffineMap instances are uniqued, so the returned object /// may be a pre-existing object. Ownership of the underlying MlirAffineMap /// is taken by calling this function. - static PyAffineMap createFromCapsule(pybind11::object capsule); + static PyAffineMap createFromCapsule(nanobind::object capsule); private: MlirAffineMap affineMap; @@ -1217,12 +1229,12 @@ class PyIntegerSet : public BaseContextObject { MlirIntegerSet get() const { return integerSet; } /// Gets a capsule wrapping the void* within the MlirIntegerSet. - pybind11::object getCapsule(); + nanobind::object getCapsule(); /// Creates a PyIntegerSet from the MlirAffineMap wrapped by a capsule. 
/// Note that PyIntegerSet instances may be uniqued, so the returned object /// may be a pre-existing object. Integer sets are owned by the context. - static PyIntegerSet createFromCapsule(pybind11::object capsule); + static PyIntegerSet createFromCapsule(nanobind::object capsule); private: MlirIntegerSet integerSet; @@ -1239,7 +1251,7 @@ class PySymbolTable { /// Returns the symbol (opview) with the given name, throws if there is no /// such symbol in the table. - pybind11::object dunderGetItem(const std::string &name); + nanobind::object dunderGetItem(const std::string &name); /// Removes the given operation from the symbol table and erases it. void erase(PyOperationBase &symbol); @@ -1269,7 +1281,7 @@ class PySymbolTable { /// Walks all symbol tables under and including 'from'. static void walkSymbolTables(PyOperationBase &from, bool allSymUsesVisible, - pybind11::object callback); + nanobind::object callback); /// Casts the bindings class into the C API structure. operator MlirSymbolTable() { return symbolTable; } @@ -1289,16 +1301,16 @@ struct MLIRError { std::vector errorDiagnostics; }; -void populateIRAffine(pybind11::module &m); -void populateIRAttributes(pybind11::module &m); -void populateIRCore(pybind11::module &m); -void populateIRInterfaces(pybind11::module &m); -void populateIRTypes(pybind11::module &m); +void populateIRAffine(nanobind::module_ &m); +void populateIRAttributes(nanobind::module_ &m); +void populateIRCore(nanobind::module_ &m); +void populateIRInterfaces(nanobind::module_ &m); +void populateIRTypes(nanobind::module_ &m); } // namespace python } // namespace mlir -namespace pybind11 { +namespace nanobind { namespace detail { template <> @@ -1309,6 +1321,6 @@ struct type_caster : MlirDefaultingCaster {}; } // namespace detail -} // namespace pybind11 +} // namespace nanobind #endif // MLIR_BINDINGS_PYTHON_IRMODULES_H diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 6f192bc4b..5cfa51142 100644 
--- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -6,19 +6,26 @@ // //===----------------------------------------------------------------------===// +// clang-format off #include "IRModule.h" +#include "mlir/Bindings/Python/IRTypes.h" +// clang-format on -#include "PybindUtils.h" +#include +#include +#include +#include +#include -#include "mlir/Bindings/Python/IRTypes.h" +#include +#include "IRModule.h" +#include "NanobindUtils.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir-c/Support.h" -#include - -namespace py = pybind11; +namespace nb = nanobind; using namespace mlir; using namespace mlir::python; @@ -48,7 +55,7 @@ class PyIntegerType : public PyConcreteType { MlirType t = mlirIntegerTypeGet(context->get(), width); return PyIntegerType(context->getRef(), t); }, - py::arg("width"), py::arg("context") = py::none(), + nb::arg("width"), nb::arg("context").none() = nb::none(), "Create a signless integer type"); c.def_static( "get_signed", @@ -56,7 +63,7 @@ class PyIntegerType : public PyConcreteType { MlirType t = mlirIntegerTypeSignedGet(context->get(), width); return PyIntegerType(context->getRef(), t); }, - py::arg("width"), py::arg("context") = py::none(), + nb::arg("width"), nb::arg("context").none() = nb::none(), "Create a signed integer type"); c.def_static( "get_unsigned", @@ -64,25 +71,25 @@ class PyIntegerType : public PyConcreteType { MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); return PyIntegerType(context->getRef(), t); }, - py::arg("width"), py::arg("context") = py::none(), + nb::arg("width"), nb::arg("context").none() = nb::none(), "Create an unsigned integer type"); - c.def_property_readonly( + c.def_prop_ro( "width", [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, "Returns the width of the integer type"); - c.def_property_readonly( + c.def_prop_ro( "is_signless", [](PyIntegerType &self) -> bool { return 
mlirIntegerTypeIsSignless(self); }, "Returns whether this is a signless integer"); - c.def_property_readonly( + c.def_prop_ro( "is_signed", [](PyIntegerType &self) -> bool { return mlirIntegerTypeIsSigned(self); }, "Returns whether this is a signed integer"); - c.def_property_readonly( + c.def_prop_ro( "is_unsigned", [](PyIntegerType &self) -> bool { return mlirIntegerTypeIsUnsigned(self); @@ -107,7 +114,7 @@ class PyIndexType : public PyConcreteType { MlirType t = mlirIndexTypeGet(context->get()); return PyIndexType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a index type."); + nb::arg("context").none() = nb::none(), "Create a index type."); } }; @@ -118,7 +125,7 @@ class PyFloatType : public PyConcreteType { using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { - c.def_property_readonly( + c.def_prop_ro( "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); }, "Returns the width of the floating-point type"); } @@ -141,7 +148,7 @@ class PyFloat4E2M1FNType MlirType t = mlirFloat4E2M1FNTypeGet(context->get()); return PyFloat4E2M1FNType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float4_e2m1fn type."); + nb::arg("context").none() = nb::none(), "Create a float4_e2m1fn type."); } }; @@ -162,7 +169,7 @@ class PyFloat6E2M3FNType MlirType t = mlirFloat6E2M3FNTypeGet(context->get()); return PyFloat6E2M3FNType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float6_e2m3fn type."); + nb::arg("context").none() = nb::none(), "Create a float6_e2m3fn type."); } }; @@ -183,7 +190,7 @@ class PyFloat6E3M2FNType MlirType t = mlirFloat6E3M2FNTypeGet(context->get()); return PyFloat6E3M2FNType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float6_e3m2fn type."); + nb::arg("context").none() = nb::none(), "Create a float6_e3m2fn type."); } }; @@ -204,7 +211,7 @@ class PyFloat8E4M3FNType MlirType t = mlirFloat8E4M3FNTypeGet(context->get()); return 
PyFloat8E4M3FNType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e4m3fn type."); + nb::arg("context").none() = nb::none(), "Create a float8_e4m3fn type."); } }; @@ -224,7 +231,7 @@ class PyFloat8E5M2Type : public PyConcreteType { MlirType t = mlirFloat8E5M2TypeGet(context->get()); return PyFloat8E5M2Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e5m2 type."); + nb::arg("context").none() = nb::none(), "Create a float8_e5m2 type."); } }; @@ -244,7 +251,7 @@ class PyFloat8E4M3Type : public PyConcreteType { MlirType t = mlirFloat8E4M3TypeGet(context->get()); return PyFloat8E4M3Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e4m3 type."); + nb::arg("context").none() = nb::none(), "Create a float8_e4m3 type."); } }; @@ -265,7 +272,8 @@ class PyFloat8E4M3FNUZType MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); return PyFloat8E4M3FNUZType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e4m3fnuz type."); + nb::arg("context").none() = nb::none(), + "Create a float8_e4m3fnuz type."); } }; @@ -286,7 +294,8 @@ class PyFloat8E4M3B11FNUZType MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); return PyFloat8E4M3B11FNUZType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e4m3b11fnuz type."); + nb::arg("context").none() = nb::none(), + "Create a float8_e4m3b11fnuz type."); } }; @@ -307,7 +316,8 @@ class PyFloat8E5M2FNUZType MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); return PyFloat8E5M2FNUZType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e5m2fnuz type."); + nb::arg("context").none() = nb::none(), + "Create a float8_e5m2fnuz type."); } }; @@ -327,7 +337,7 @@ class PyFloat8E3M4Type : public PyConcreteType { MlirType t = mlirFloat8E3M4TypeGet(context->get()); return PyFloat8E3M4Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e3m4 
type."); + nb::arg("context").none() = nb::none(), "Create a float8_e3m4 type."); } }; @@ -348,7 +358,8 @@ class PyFloat8E8M0FNUType MlirType t = mlirFloat8E8M0FNUTypeGet(context->get()); return PyFloat8E8M0FNUType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a float8_e8m0fnu type."); + nb::arg("context").none() = nb::none(), + "Create a float8_e8m0fnu type."); } }; @@ -368,7 +379,7 @@ class PyBF16Type : public PyConcreteType { MlirType t = mlirBF16TypeGet(context->get()); return PyBF16Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a bf16 type."); + nb::arg("context").none() = nb::none(), "Create a bf16 type."); } }; @@ -388,7 +399,7 @@ class PyF16Type : public PyConcreteType { MlirType t = mlirF16TypeGet(context->get()); return PyF16Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a f16 type."); + nb::arg("context").none() = nb::none(), "Create a f16 type."); } }; @@ -408,7 +419,7 @@ class PyTF32Type : public PyConcreteType { MlirType t = mlirTF32TypeGet(context->get()); return PyTF32Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a tf32 type."); + nb::arg("context").none() = nb::none(), "Create a tf32 type."); } }; @@ -428,7 +439,7 @@ class PyF32Type : public PyConcreteType { MlirType t = mlirF32TypeGet(context->get()); return PyF32Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a f32 type."); + nb::arg("context").none() = nb::none(), "Create a f32 type."); } }; @@ -448,7 +459,7 @@ class PyF64Type : public PyConcreteType { MlirType t = mlirF64TypeGet(context->get()); return PyF64Type(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a f64 type."); + nb::arg("context").none() = nb::none(), "Create a f64 type."); } }; @@ -468,7 +479,7 @@ class PyNoneType : public PyConcreteType { MlirType t = mlirNoneTypeGet(context->get()); return PyNoneType(context->getRef(), t); }, - py::arg("context") = py::none(), "Create a none 
type."); + nb::arg("context").none() = nb::none(), "Create a none type."); } }; @@ -490,14 +501,15 @@ class PyComplexType : public PyConcreteType { MlirType t = mlirComplexTypeGet(elementType); return PyComplexType(elementType.getContext(), t); } - throw py::value_error( + throw nb::value_error( (Twine("invalid '") + - py::repr(py::cast(elementType)).cast() + + nb::cast(nb::repr(nb::cast(elementType))) + "' and expected floating point or integer type.") - .str()); + .str() + .c_str()); }, "Create a complex type"); - c.def_property_readonly( + c.def_prop_ro( "element_type", [](PyComplexType &self) { return mlirComplexTypeGetElementType(self); }, "Returns element type."); @@ -508,22 +520,22 @@ class PyComplexType : public PyConcreteType { // Shaped Type Interface - ShapedType void mlir::PyShapedType::bindDerived(ClassTy &c) { - c.def_property_readonly( + c.def_prop_ro( "element_type", [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); }, "Returns the element type of the shaped type."); - c.def_property_readonly( + c.def_prop_ro( "has_rank", [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); }, "Returns whether the given shaped type is ranked."); - c.def_property_readonly( + c.def_prop_ro( "rank", [](PyShapedType &self) { self.requireHasRank(); return mlirShapedTypeGetRank(self); }, "Returns the rank of the given ranked shaped type."); - c.def_property_readonly( + c.def_prop_ro( "has_static_shape", [](PyShapedType &self) -> bool { return mlirShapedTypeHasStaticShape(self); @@ -535,7 +547,7 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { self.requireHasRank(); return mlirShapedTypeIsDynamicDim(self, dim); }, - py::arg("dim"), + nb::arg("dim"), "Returns whether the dim-th dimension of the given shaped type is " "dynamic."); c.def( @@ -544,12 +556,12 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { self.requireHasRank(); return mlirShapedTypeGetDimSize(self, dim); }, - py::arg("dim"), + nb::arg("dim"), "Returns the dim-th 
dimension of the given ranked shaped type."); c.def_static( "is_dynamic_size", [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); }, - py::arg("dim_size"), + nb::arg("dim_size"), "Returns whether the given dimension size indicates a dynamic " "dimension."); c.def( @@ -558,10 +570,10 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { self.requireHasRank(); return mlirShapedTypeIsDynamicStrideOrOffset(val); }, - py::arg("dim_size"), + nb::arg("dim_size"), "Returns whether the given value is used as a placeholder for dynamic " "strides and offsets in shaped types."); - c.def_property_readonly( + c.def_prop_ro( "shape", [](PyShapedType &self) { self.requireHasRank(); @@ -587,7 +599,7 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { void mlir::PyShapedType::requireHasRank() { if (!mlirShapedTypeHasRank(*this)) { - throw py::value_error( + throw nb::value_error( "calling this method requires that the type has a rank."); } } @@ -607,15 +619,15 @@ class PyVectorType : public PyConcreteType { using PyConcreteType::PyConcreteType; static void bindDerived(ClassTy &c) { - c.def_static("get", &PyVectorType::get, py::arg("shape"), - py::arg("element_type"), py::kw_only(), - py::arg("scalable") = py::none(), - py::arg("scalable_dims") = py::none(), - py::arg("loc") = py::none(), "Create a vector type") - .def_property_readonly( + c.def_static("get", &PyVectorType::get, nb::arg("shape"), + nb::arg("element_type"), nb::kw_only(), + nb::arg("scalable").none() = nb::none(), + nb::arg("scalable_dims").none() = nb::none(), + nb::arg("loc").none() = nb::none(), "Create a vector type") + .def_prop_ro( "scalable", [](MlirType self) { return mlirVectorTypeIsScalable(self); }) - .def_property_readonly("scalable_dims", [](MlirType self) { + .def_prop_ro("scalable_dims", [](MlirType self) { std::vector scalableDims; size_t rank = static_cast(mlirShapedTypeGetRank(self)); scalableDims.reserve(rank); @@ -627,11 +639,11 @@ class PyVectorType : public PyConcreteType { 
private: static PyVectorType get(std::vector shape, PyType &elementType, - std::optional scalable, + std::optional scalable, std::optional> scalableDims, DefaultingPyLocation loc) { if (scalable && scalableDims) { - throw py::value_error("'scalable' and 'scalable_dims' kwargs " + throw nb::value_error("'scalable' and 'scalable_dims' kwargs " "are mutually exclusive."); } @@ -639,10 +651,10 @@ class PyVectorType : public PyConcreteType { MlirType type; if (scalable) { if (scalable->size() != shape.size()) - throw py::value_error("Expected len(scalable) == len(shape)."); + throw nb::value_error("Expected len(scalable) == len(shape)."); SmallVector scalableDimFlags = llvm::to_vector(llvm::map_range( - *scalable, [](const py::handle &h) { return h.cast(); })); + *scalable, [](const nb::handle &h) { return nb::cast(h); })); type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), scalableDimFlags.data(), elementType); @@ -650,7 +662,7 @@ class PyVectorType : public PyConcreteType { SmallVector scalableDimFlags(shape.size(), false); for (int64_t dim : *scalableDims) { if (static_cast(dim) >= scalableDimFlags.size() || dim < 0) - throw py::value_error("Scalable dimension index out of bounds."); + throw nb::value_error("Scalable dimension index out of bounds."); scalableDimFlags[dim] = true; } type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), @@ -689,17 +701,17 @@ class PyRankedTensorType throw MLIRError("Invalid type", errors.take()); return PyRankedTensorType(elementType.getContext(), t); }, - py::arg("shape"), py::arg("element_type"), - py::arg("encoding") = py::none(), py::arg("loc") = py::none(), - "Create a ranked tensor type"); - c.def_property_readonly( - "encoding", - [](PyRankedTensorType &self) -> std::optional { - MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get()); - if (mlirAttributeIsNull(encoding)) - return std::nullopt; - return encoding; - }); + nb::arg("shape"), nb::arg("element_type"), + 
nb::arg("encoding").none() = nb::none(), + nb::arg("loc").none() = nb::none(), "Create a ranked tensor type"); + c.def_prop_ro("encoding", + [](PyRankedTensorType &self) -> std::optional { + MlirAttribute encoding = + mlirRankedTensorTypeGetEncoding(self.get()); + if (mlirAttributeIsNull(encoding)) + return std::nullopt; + return encoding; + }); } }; @@ -723,7 +735,7 @@ class PyUnrankedTensorType throw MLIRError("Invalid type", errors.take()); return PyUnrankedTensorType(elementType.getContext(), t); }, - py::arg("element_type"), py::arg("loc") = py::none(), + nb::arg("element_type"), nb::arg("loc").none() = nb::none(), "Create a unranked tensor type"); } }; @@ -754,10 +766,11 @@ class PyMemRefType : public PyConcreteType { throw MLIRError("Invalid type", errors.take()); return PyMemRefType(elementType.getContext(), t); }, - py::arg("shape"), py::arg("element_type"), - py::arg("layout") = py::none(), py::arg("memory_space") = py::none(), - py::arg("loc") = py::none(), "Create a memref type") - .def_property_readonly( + nb::arg("shape"), nb::arg("element_type"), + nb::arg("layout").none() = nb::none(), + nb::arg("memory_space").none() = nb::none(), + nb::arg("loc").none() = nb::none(), "Create a memref type") + .def_prop_ro( "layout", [](PyMemRefType &self) -> MlirAttribute { return mlirMemRefTypeGetLayout(self); @@ -775,14 +788,14 @@ class PyMemRefType : public PyConcreteType { return {strides, offset}; }, "The strides and offset of the MemRef type.") - .def_property_readonly( + .def_prop_ro( "affine_map", [](PyMemRefType &self) -> PyAffineMap { MlirAffineMap map = mlirMemRefTypeGetAffineMap(self); return PyAffineMap(self.getContext(), map); }, "The layout of the MemRef type as an affine map.") - .def_property_readonly( + .def_prop_ro( "memory_space", [](PyMemRefType &self) -> std::optional { MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); @@ -820,9 +833,9 @@ class PyUnrankedMemRefType throw MLIRError("Invalid type", errors.take()); return 
PyUnrankedMemRefType(elementType.getContext(), t); }, - py::arg("element_type"), py::arg("memory_space"), - py::arg("loc") = py::none(), "Create a unranked memref type") - .def_property_readonly( + nb::arg("element_type"), nb::arg("memory_space").none(), + nb::arg("loc").none() = nb::none(), "Create a unranked memref type") + .def_prop_ro( "memory_space", [](PyUnrankedMemRefType &self) -> std::optional { MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self); @@ -851,15 +864,15 @@ class PyTupleType : public PyConcreteType { elements.data()); return PyTupleType(context->getRef(), t); }, - py::arg("elements"), py::arg("context") = py::none(), + nb::arg("elements"), nb::arg("context").none() = nb::none(), "Create a tuple type"); c.def( "get_type", [](PyTupleType &self, intptr_t pos) { return mlirTupleTypeGetType(self, pos); }, - py::arg("pos"), "Returns the pos-th type in the tuple type."); - c.def_property_readonly( + nb::arg("pos"), "Returns the pos-th type in the tuple type."); + c.def_prop_ro( "num_types", [](PyTupleType &self) -> intptr_t { return mlirTupleTypeGetNumTypes(self); @@ -887,13 +900,14 @@ class PyFunctionType : public PyConcreteType { results.size(), results.data()); return PyFunctionType(context->getRef(), t); }, - py::arg("inputs"), py::arg("results"), py::arg("context") = py::none(), + nb::arg("inputs"), nb::arg("results"), + nb::arg("context").none() = nb::none(), "Gets a FunctionType from a list of input and result types"); - c.def_property_readonly( + c.def_prop_ro( "inputs", [](PyFunctionType &self) { MlirType t = self; - py::list types; + nb::list types; for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; ++i) { types.append(mlirFunctionTypeGetInput(t, i)); @@ -901,10 +915,10 @@ class PyFunctionType : public PyConcreteType { return types; }, "Returns the list of input types in the FunctionType."); - c.def_property_readonly( + c.def_prop_ro( "results", [](PyFunctionType &self) { - py::list types; + nb::list types; for 
(intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; ++i) { types.append(mlirFunctionTypeGetResult(self, i)); @@ -938,21 +952,21 @@ class PyOpaqueType : public PyConcreteType { toMlirStringRef(typeData)); return PyOpaqueType(context->getRef(), type); }, - py::arg("dialect_namespace"), py::arg("buffer"), - py::arg("context") = py::none(), + nb::arg("dialect_namespace"), nb::arg("buffer"), + nb::arg("context").none() = nb::none(), "Create an unregistered (opaque) dialect type."); - c.def_property_readonly( + c.def_prop_ro( "dialect_namespace", [](PyOpaqueType &self) { MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self); - return py::str(stringRef.data, stringRef.length); + return nb::str(stringRef.data, stringRef.length); }, "Returns the dialect namespace for the Opaque type as a string."); - c.def_property_readonly( + c.def_prop_ro( "data", [](PyOpaqueType &self) { MlirStringRef stringRef = mlirOpaqueTypeGetData(self); - return py::str(stringRef.data, stringRef.length); + return nb::str(stringRef.data, stringRef.length); }, "Returns the data for the Opaque type as a string."); } @@ -960,7 +974,7 @@ class PyOpaqueType : public PyConcreteType { } // namespace -void mlir::python::populateIRTypes(py::module &m) { +void mlir::python::populateIRTypes(nb::module_ &m) { PyIntegerType::bind(m); PyFloatType::bind(m); PyIndexType::bind(m); diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 7c2702190..e5e64a921 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -6,29 +6,31 @@ // //===----------------------------------------------------------------------===// -#include "PybindUtils.h" +#include +#include #include "Globals.h" #include "IRModule.h" +#include "NanobindUtils.h" #include "Pass.h" #include "Rewrite.h" -namespace py = pybind11; +namespace nb = nanobind; using namespace mlir; -using namespace py::literals; +using namespace nb::literals; using 
namespace mlir::python; // ----------------------------------------------------------------------------- // Module initialization. // ----------------------------------------------------------------------------- -PYBIND11_MODULE(_mlir, m) { +NB_MODULE(_mlir, m) { m.doc() = "MLIR Python Native Extension"; - py::class_(m, "_Globals", py::module_local()) - .def_property("dialect_search_modules", - &PyGlobals::getDialectSearchPrefixes, - &PyGlobals::setDialectSearchPrefixes) + nb::class_(m, "_Globals") + .def_prop_rw("dialect_search_modules", + &PyGlobals::getDialectSearchPrefixes, + &PyGlobals::setDialectSearchPrefixes) .def( "append_dialect_search_prefix", [](PyGlobals &self, std::string moduleName) { @@ -45,22 +47,21 @@ PYBIND11_MODULE(_mlir, m) { "dialect_namespace"_a, "dialect_class"_a, "Testing hook for directly registering a dialect") .def("_register_operation_impl", &PyGlobals::registerOperationImpl, - "operation_name"_a, "operation_class"_a, py::kw_only(), + "operation_name"_a, "operation_class"_a, nb::kw_only(), "replace"_a = false, "Testing hook for directly registering an operation"); // Aside from making the globals accessible to python, having python manage // it is necessary to make sure it is destroyed (and releases its python // resources) properly. - m.attr("globals") = - py::cast(new PyGlobals, py::return_value_policy::take_ownership); + m.attr("globals") = nb::cast(new PyGlobals, nb::rv_policy::take_ownership); // Registration decorators. 
m.def( "register_dialect", - [](py::type pyClass) { + [](nb::type_object pyClass) { std::string dialectNamespace = - pyClass.attr("DIALECT_NAMESPACE").cast(); + nanobind::cast(pyClass.attr("DIALECT_NAMESPACE")); PyGlobals::get().registerDialectImpl(dialectNamespace, pyClass); return pyClass; }, @@ -68,45 +69,46 @@ PYBIND11_MODULE(_mlir, m) { "Class decorator for registering a custom Dialect wrapper"); m.def( "register_operation", - [](const py::type &dialectClass, bool replace) -> py::cpp_function { - return py::cpp_function( - [dialectClass, replace](py::type opClass) -> py::type { + [](const nb::type_object &dialectClass, bool replace) -> nb::object { + return nb::cpp_function( + [dialectClass, + replace](nb::type_object opClass) -> nb::type_object { std::string operationName = - opClass.attr("OPERATION_NAME").cast(); + nanobind::cast(opClass.attr("OPERATION_NAME")); PyGlobals::get().registerOperationImpl(operationName, opClass, replace); // Dict-stuff the new opClass by name onto the dialect class. 
- py::object opClassName = opClass.attr("__name__"); + nb::object opClassName = opClass.attr("__name__"); dialectClass.attr(opClassName) = opClass; return opClass; }); }, - "dialect_class"_a, py::kw_only(), "replace"_a = false, + "dialect_class"_a, nb::kw_only(), "replace"_a = false, "Produce a class decorator for registering an Operation class as part of " "a dialect"); m.def( MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR, - [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function { - return py::cpp_function([mlirTypeID, - replace](py::object typeCaster) -> py::object { + [](MlirTypeID mlirTypeID, bool replace) -> nb::object { + return nb::cpp_function([mlirTypeID, replace]( + nb::callable typeCaster) -> nb::object { PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace); return typeCaster; }); }, - "typeid"_a, py::kw_only(), "replace"_a = false, + "typeid"_a, nb::kw_only(), "replace"_a = false, "Register a type caster for casting MLIR types to custom user types."); m.def( MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR, - [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function { - return py::cpp_function( - [mlirTypeID, replace](py::object valueCaster) -> py::object { + [](MlirTypeID mlirTypeID, bool replace) -> nb::object { + return nb::cpp_function( + [mlirTypeID, replace](nb::callable valueCaster) -> nb::object { PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster, replace); return valueCaster; }); }, - "typeid"_a, py::kw_only(), "replace"_a = false, + "typeid"_a, nb::kw_only(), "replace"_a = false, "Register a value caster for casting MLIR values to custom user values."); // Define and populate IR submodule. 
diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h similarity index 85% rename from mlir/lib/Bindings/Python/PybindUtils.h rename to mlir/lib/Bindings/Python/NanobindUtils.h index 38462ac8b..3b0f7f698 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/NanobindUtils.h @@ -1,4 +1,5 @@ -//===- PybindUtils.h - Utilities for interop with pybind11 ------*- C++ -*-===// +//===- NanobindUtils.h - Utilities for interop with nanobind ------*- C++ +//-*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -9,13 +10,21 @@ #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H +#include + #include "mlir-c/Support.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/DataTypes.h" -#include -#include +template <> +struct std::iterator_traits { + using value_type = nanobind::handle; + using reference = const value_type; + using pointer = void; + using difference_type = std::ptrdiff_t; + using iterator_category = std::forward_iterator_tag; +}; namespace mlir { namespace python { @@ -54,14 +63,14 @@ class Defaulting { } // namespace python } // namespace mlir -namespace pybind11 { +namespace nanobind { namespace detail { template struct MlirDefaultingCaster { - PYBIND11_TYPE_CASTER(DefaultingTy, _(DefaultingTy::kTypeDescription)); + NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription)); - bool load(pybind11::handle src, bool) { + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { if (src.is_none()) { // Note that we do want an exception to propagate from here as it will be // the most informative. @@ -76,20 +85,20 @@ struct MlirDefaultingCaster { // code to produce nice error messages (other than "Cannot cast..."). 
try { value = DefaultingTy{ - pybind11::cast(src)}; + nanobind::cast(src)}; return true; } catch (std::exception &) { return false; } } - static handle cast(DefaultingTy src, return_value_policy policy, - handle parent) { - return pybind11::cast(src, policy); + static handle from_cpp(DefaultingTy src, rv_policy policy, + cleanup_list *cleanup) noexcept { + return nanobind::cast(src, policy); } }; } // namespace detail -} // namespace pybind11 +} // namespace nanobind //------------------------------------------------------------------------------ // Conversion utilities. @@ -100,7 +109,7 @@ namespace mlir { /// Accumulates into a python string from a method that accepts an /// MlirStringCallback. struct PyPrintAccumulator { - pybind11::list parts; + nanobind::list parts; void *getUserData() { return this; } @@ -108,15 +117,15 @@ struct PyPrintAccumulator { return [](MlirStringRef part, void *userData) { PyPrintAccumulator *printAccum = static_cast(userData); - pybind11::str pyPart(part.data, + nanobind::str pyPart(part.data, part.length); // Decodes as UTF-8 by default. printAccum->parts.append(std::move(pyPart)); }; } - pybind11::str join() { - pybind11::str delim("", 0); - return delim.attr("join")(parts); + nanobind::str join() { + nanobind::str delim("", 0); + return nanobind::cast(delim.attr("join")(parts)); } }; @@ -124,21 +133,21 @@ struct PyPrintAccumulator { /// or binary. class PyFileAccumulator { public: - PyFileAccumulator(const pybind11::object &fileObject, bool binary) + PyFileAccumulator(const nanobind::object &fileObject, bool binary) : pyWriteFunction(fileObject.attr("write")), binary(binary) {} void *getUserData() { return this; } MlirStringCallback getCallback() { return [](MlirStringRef part, void *userData) { - pybind11::gil_scoped_acquire acquire; + nanobind::gil_scoped_acquire acquire; PyFileAccumulator *accum = static_cast(userData); if (accum->binary) { // Note: Still has to copy and not avoidable with this API. 
- pybind11::bytes pyBytes(part.data, part.length); + nanobind::bytes pyBytes(part.data, part.length); accum->pyWriteFunction(pyBytes); } else { - pybind11::str pyStr(part.data, + nanobind::str pyStr(part.data, part.length); // Decodes as UTF-8 by default. accum->pyWriteFunction(pyStr); } @@ -146,7 +155,7 @@ class PyFileAccumulator { } private: - pybind11::object pyWriteFunction; + nanobind::object pyWriteFunction; bool binary; }; @@ -163,17 +172,17 @@ struct PySinglePartStringAccumulator { assert(!accum->invoked && "PySinglePartStringAccumulator called back multiple times"); accum->invoked = true; - accum->value = pybind11::str(part.data, part.length); + accum->value = nanobind::str(part.data, part.length); }; } - pybind11::str takeValue() { + nanobind::str takeValue() { assert(invoked && "PySinglePartStringAccumulator not called back"); return std::move(value); } private: - pybind11::str value; + nanobind::str value; bool invoked = false; }; @@ -208,7 +217,7 @@ struct PySinglePartStringAccumulator { template class Sliceable { protected: - using ClassTy = pybind11::class_; + using ClassTy = nanobind::class_; /// Transforms `index` into a legal value to access the underlying sequence. /// Returns <0 on failure. @@ -237,7 +246,7 @@ class Sliceable { /// Returns the element at the given slice index. Supports negative indices /// by taking elements in inverse order. Returns a nullptr object if out /// of bounds. - pybind11::object getItem(intptr_t index) { + nanobind::object getItem(intptr_t index) { // Negative indices mean we count from the end. index = wrapIndex(index); if (index < 0) { @@ -250,20 +259,20 @@ class Sliceable { ->getRawElement(linearizeIndex(index)) .maybeDownCast(); else - return pybind11::cast( + return nanobind::cast( static_cast(this)->getRawElement(linearizeIndex(index))); } /// Returns a new instance of the pseudo-container restricted to the given /// slice. Returns a nullptr object on failure. 
- pybind11::object getItemSlice(PyObject *slice) { + nanobind::object getItemSlice(PyObject *slice) { ssize_t start, stop, extraStep, sliceLength; if (PySlice_GetIndicesEx(slice, length, &start, &stop, &extraStep, &sliceLength) != 0) { PyErr_SetString(PyExc_IndexError, "index out of range"); return {}; } - return pybind11::cast(static_cast(this)->slice( + return nanobind::cast(static_cast(this)->slice( startIndex + start * step, sliceLength, step * extraStep)); } @@ -279,7 +288,7 @@ class Sliceable { // Negative indices mean we count from the end. index = wrapIndex(index); if (index < 0) { - throw pybind11::index_error("index out of range"); + throw nanobind::index_error("index out of range"); } return static_cast(this)->getRawElement(linearizeIndex(index)); @@ -304,39 +313,38 @@ class Sliceable { } /// Binds the indexing and length methods in the Python class. - static void bind(pybind11::module &m) { - auto clazz = pybind11::class_(m, Derived::pyClassName, - pybind11::module_local()) + static void bind(nanobind::module_ &m) { + auto clazz = nanobind::class_(m, Derived::pyClassName) .def("__add__", &Sliceable::dunderAdd); Derived::bindDerived(clazz); // Manually implement the sequence protocol via the C API. We do this - // because it is approx 4x faster than via pybind11, largely because that + // because it is approx 4x faster than via nanobind, largely because that // formulation requires a C++ exception to be thrown to detect end of // sequence. // Since we are in a C-context, any C++ exception that happens here // will terminate the program. There is nothing in this implementation // that should throw in a non-terminal way, so we forgo further // exception marshalling. 
- // See: https://github.com/pybind/pybind11/issues/2842 + // See: https://github.com/pybind/pybind11/issues/2842 auto heap_type = reinterpret_cast(clazz.ptr()); assert(heap_type->ht_type.tp_flags & Py_TPFLAGS_HEAPTYPE && "must be heap type"); heap_type->as_sequence.sq_length = +[](PyObject *rawSelf) -> Py_ssize_t { - auto self = pybind11::cast(rawSelf); + auto self = nanobind::cast(nanobind::handle(rawSelf)); return self->length; }; // sq_item is called as part of the sequence protocol for iteration, // list construction, etc. heap_type->as_sequence.sq_item = +[](PyObject *rawSelf, Py_ssize_t index) -> PyObject * { - auto self = pybind11::cast(rawSelf); + auto self = nanobind::cast(nanobind::handle(rawSelf)); return self->getItem(index).release().ptr(); }; // mp_subscript is used for both slices and integer lookups. heap_type->as_mapping.mp_subscript = +[](PyObject *rawSelf, PyObject *rawSubscript) -> PyObject * { - auto self = pybind11::cast(rawSelf); + auto self = nanobind::cast(nanobind::handle(rawSelf)); Py_ssize_t index = PyNumber_AsSsize_t(rawSubscript, PyExc_IndexError); if (!PyErr_Occurred()) { // Integer indexing. 
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index e991deaae..b5dce4fe4 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -8,12 +8,16 @@ #include "Pass.h" +#include +#include +#include + #include "IRModule.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Pass.h" -namespace py = pybind11; -using namespace py::literals; +namespace nb = nanobind; +using namespace nb::literals; using namespace mlir; using namespace mlir::python; @@ -34,16 +38,15 @@ class PyPassManager { MlirPassManager get() { return passManager; } void release() { passManager.ptr = nullptr; } - pybind11::object getCapsule() { - return py::reinterpret_steal( - mlirPythonPassManagerToCapsule(get())); + nb::object getCapsule() { + return nb::steal(mlirPythonPassManagerToCapsule(get())); } - static pybind11::object createFromCapsule(pybind11::object capsule) { + static nb::object createFromCapsule(nb::object capsule) { MlirPassManager rawPm = mlirPythonCapsuleToPassManager(capsule.ptr()); if (mlirPassManagerIsNull(rawPm)) - throw py::error_already_set(); - return py::cast(PyPassManager(rawPm), py::return_value_policy::move); + throw nb::python_error(); + return nb::cast(PyPassManager(rawPm), nb::rv_policy::move); } private: @@ -53,22 +56,23 @@ class PyPassManager { } // namespace /// Create the `mlir.passmanager` here. 
-void mlir::python::populatePassManagerSubmodule(py::module &m) { +void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of the top-level PassManager //---------------------------------------------------------------------------- - py::class_(m, "PassManager", py::module_local()) - .def(py::init<>([](const std::string &anchorOp, - DefaultingPyMlirContext context) { - MlirPassManager passManager = mlirPassManagerCreateOnOperation( - context->get(), - mlirStringRefCreate(anchorOp.data(), anchorOp.size())); - return new PyPassManager(passManager); - }), - "anchor_op"_a = py::str("any"), "context"_a = py::none(), - "Create a new PassManager for the current (or provided) Context.") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyPassManager::getCapsule) + nb::class_(m, "PassManager") + .def( + "__init__", + [](PyPassManager &self, const std::string &anchorOp, + DefaultingPyMlirContext context) { + MlirPassManager passManager = mlirPassManagerCreateOnOperation( + context->get(), + mlirStringRefCreate(anchorOp.data(), anchorOp.size())); + new (&self) PyPassManager(passManager); + }, + "anchor_op"_a = nb::str("any"), "context"_a.none() = nb::none(), + "Create a new PassManager for the current (or provided) Context.") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyPassManager::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyPassManager::createFromCapsule) .def("_testing_release", &PyPassManager::release, "Releases (leaks) the backing pass manager (testing)") @@ -101,9 +105,9 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { "print_before_all"_a = false, "print_after_all"_a = true, "print_module_scope"_a = false, "print_after_change"_a = false, "print_after_failure"_a = false, - "large_elements_limit"_a = py::none(), "enable_debug_info"_a = false, - "print_generic_op_form"_a = false, - "tree_printing_dir_path"_a = py::none(), + 
"large_elements_limit"_a.none() = nb::none(), + "enable_debug_info"_a = false, "print_generic_op_form"_a = false, + "tree_printing_dir_path"_a.none() = nb::none(), "Enable IR printing, default as mlir-print-ir-after-all.") .def( "enable_verifier", @@ -121,10 +125,10 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { mlirStringRefCreate(pipeline.data(), pipeline.size()), errorMsg.getCallback(), errorMsg.getUserData()); if (mlirLogicalResultIsFailure(status)) - throw py::value_error(std::string(errorMsg.join())); + throw nb::value_error(errorMsg.join().c_str()); return new PyPassManager(passManager); }, - "pipeline"_a, "context"_a = py::none(), + "pipeline"_a, "context"_a.none() = nb::none(), "Parse a textual pass-pipeline and return a top-level PassManager " "that can be applied on a Module. Throw a ValueError if the pipeline " "can't be parsed") @@ -137,7 +141,7 @@ void mlir::python::populatePassManagerSubmodule(py::module &m) { mlirStringRefCreate(pipeline.data(), pipeline.size()), errorMsg.getCallback(), errorMsg.getUserData()); if (mlirLogicalResultIsFailure(status)) - throw py::value_error(std::string(errorMsg.join())); + throw nb::value_error(errorMsg.join().c_str()); }, "pipeline"_a, "Add textual pipeline elements to the pass manager. 
Throws a " diff --git a/mlir/lib/Bindings/Python/Pass.h b/mlir/lib/Bindings/Python/Pass.h index 3a500d5e8..bc4094352 100644 --- a/mlir/lib/Bindings/Python/Pass.h +++ b/mlir/lib/Bindings/Python/Pass.h @@ -9,12 +9,12 @@ #ifndef MLIR_BINDINGS_PYTHON_PASS_H #define MLIR_BINDINGS_PYTHON_PASS_H -#include "PybindUtils.h" +#include "NanobindUtils.h" namespace mlir { namespace python { -void populatePassManagerSubmodule(pybind11::module &m); +void populatePassManagerSubmodule(nanobind::module_ &m); } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index 1d8128be9..b2c1de4be 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -8,14 +8,16 @@ #include "Rewrite.h" +#include + #include "IRModule.h" #include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Rewrite.h" #include "mlir/Config/mlir-config.h" -namespace py = pybind11; +namespace nb = nanobind; using namespace mlir; -using namespace py::literals; +using namespace nb::literals; using namespace mlir::python; namespace { @@ -54,18 +56,17 @@ class PyFrozenRewritePatternSet { } MlirFrozenRewritePatternSet get() { return set; } - pybind11::object getCapsule() { - return py::reinterpret_steal( + nb::object getCapsule() { + return nb::steal( mlirPythonFrozenRewritePatternSetToCapsule(get())); } - static pybind11::object createFromCapsule(pybind11::object capsule) { + static nb::object createFromCapsule(nb::object capsule) { MlirFrozenRewritePatternSet rawPm = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); if (rawPm.ptr == nullptr) - throw py::error_already_set(); - return py::cast(PyFrozenRewritePatternSet(rawPm), - py::return_value_policy::move); + throw nb::python_error(); + return nb::cast(PyFrozenRewritePatternSet(rawPm), nb::rv_policy::move); } private: @@ -75,25 +76,27 @@ class PyFrozenRewritePatternSet { } // namespace /// Create the `mlir.rewrite` here. 
-void mlir::python::populateRewriteSubmodule(py::module &m) { +void mlir::python::populateRewriteSubmodule(nb::module_ &m) { //---------------------------------------------------------------------------- // Mapping of the top-level PassManager //---------------------------------------------------------------------------- #if MLIR_ENABLE_PDL_IN_PATTERNMATCH - py::class_(m, "PDLModule", py::module_local()) - .def(py::init<>([](MlirModule module) { - return mlirPDLPatternModuleFromModule(module); - }), - "module"_a, "Create a PDL module from the given module.") + nb::class_(m, "PDLModule") + .def( + "__init__", + [](PyPDLPatternModule &self, MlirModule module) { + new (&self) + PyPDLPatternModule(mlirPDLPatternModuleFromModule(module)); + }, + "module"_a, "Create a PDL module from the given module.") .def("freeze", [](PyPDLPatternModule &self) { return new PyFrozenRewritePatternSet(mlirFreezeRewritePattern( mlirRewritePatternSetFromPDLPatternModule(self.get()))); }); -#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCg - py::class_(m, "FrozenRewritePatternSet", - py::module_local()) - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyFrozenRewritePatternSet::getCapsule) +#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH + nb::class_(m, "FrozenRewritePatternSet") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, + &PyFrozenRewritePatternSet::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyFrozenRewritePatternSet::createFromCapsule); m.def( @@ -102,7 +105,7 @@ void mlir::python::populateRewriteSubmodule(py::module &m) { auto status = mlirApplyPatternsAndFoldGreedily(module, set, {}); if (mlirLogicalResultIsFailure(status)) // FIXME: Not sure this is the right error to throw here. 
- throw py::value_error("pattern application failed to converge"); + throw nb::value_error("pattern application failed to converge"); }, "module"_a, "set"_a, "Applys the given patterns to the given module greedily while folding " diff --git a/mlir/lib/Bindings/Python/Rewrite.h b/mlir/lib/Bindings/Python/Rewrite.h index 997b80add..ae89e2b95 100644 --- a/mlir/lib/Bindings/Python/Rewrite.h +++ b/mlir/lib/Bindings/Python/Rewrite.h @@ -9,12 +9,12 @@ #ifndef MLIR_BINDINGS_PYTHON_REWRITE_H #define MLIR_BINDINGS_PYTHON_REWRITE_H -#include "PybindUtils.h" +#include "NanobindUtils.h" namespace mlir { namespace python { -void populateRewriteSubmodule(pybind11::module &m); +void populateRewriteSubmodule(nanobind::module_ &m); } // namespace python } // namespace mlir diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 10866c11b..6d6b98312 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -448,6 +448,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core MODULE_NAME _mlir ADD_TO_PARENT MLIRPythonSources.Core ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES MainModule.cpp IRAffine.cpp @@ -463,7 +464,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Core Globals.h IRModule.h Pass.h - PybindUtils.h + NanobindUtils.h Rewrite.h PRIVATE_LINK_LIBS LLVMSupport @@ -475,6 +476,9 @@ declare_mlir_python_extension(MLIRPythonExtension.Core # Dialects MLIRCAPIFunc ) +if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL) +set_target_properties(MLIRPythonExtension.Core PROPERTIES INTERFACE_COMPILE_OPTIONS "-Wno-cast-qual;-Wno-zero-length-array;-Wno-extra-semi;-Wno-nested-anon-types;-Wno-pedantic") +endif() # This extension exposes an API to register all dialects, extensions, and passes # packaged in upstream MLIR and it is used for the upstream "mlir" Python @@ -731,6 +735,9 @@ if(MLIR_INCLUDE_TESTS) EMBED_CAPI_LINK_LIBS MLIRCAPIPythonTestDialect ) + if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL) + 
set_target_properties(MLIRPythonTestSources.PythonTestExtensionNanobind PROPERTIES INTERFACE_COMPILE_OPTIONS "-Wno-cast-qual;-Wno-zero-length-array;-Wno-extra-semi;-Wno-nested-anon-types;-Wno-pedantic") + endif() endif() ################################################################################ diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index ab8a91229..f240d6ef9 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -1,4 +1,4 @@ -nanobind>=2.0, <3.0 +nanobind>=2.4, <3.0 numpy>=1.19.5, <=2.1.2 pybind11>=2.10.0, <=2.13.6 PyYAML>=5.4.0, <=6.0.1 From 6ec34773872c772a1342aa7713c594a2fe9e4b36 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Fri, 20 Dec 2024 08:15:48 -0800 Subject: [PATCH 811/915] [mlir] Enable decoupling two kinds of greedy behavior. (#104649) The greedy rewriter is used in many different flows and it has a lot of convenience (work list management, debugging actions, tracing, etc). But it combines two kinds of greedy behavior 1) how ops are matched, 2) folding wherever it can. These are independent forms of greedy and lead to inefficiency. E.g., cases where one needs to create different phases in lowering and is required to apply patterns in specific order split across different passes. Using the driver one ends up needlessly retrying folding/having multiple rounds of folding attempts, where one final run would have sufficed. Of course folks can locally avoid this behavior by just building their own, but this is also a commonly requested feature that folks keep on working around locally in suboptimal ways. For downstream users, there should be no behavioral change. Updating from the deprecated API should just be a find and replace (e.g., `find ./ -type f -exec sed -i 's|applyPatternsAndFoldGreedily|applyPatternsGreedily|g' {} \;` variety) as the API arguments haven't changed between the two. 
--- mlir/lib/CAPI/Transforms/Rewrite.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index 379f09cf5..c4717ca61 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -289,8 +289,7 @@ MlirLogicalResult mlirApplyPatternsAndFoldGreedily(MlirModule op, MlirFrozenRewritePatternSet patterns, MlirGreedyRewriteDriverConfig) { - return wrap( - mlir::applyPatternsAndFoldGreedily(unwrap(op), *unwrap(patterns))); + return wrap(mlir::applyPatternsGreedily(unwrap(op), *unwrap(patterns))); } //===----------------------------------------------------------------------===// From 9909fb22e14dd347aae3f30cb0b747db58fe32bb Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 20 Dec 2024 23:32:32 -0500 Subject: [PATCH 812/915] [mlir python] Port in-tree dialects to nanobind. (#119924) This is a companion to #118583, although it can be landed independently because since #117922 dialects do not have to use the same Python binding framework as the Python core code. This PR ports all of the in-tree dialect and pass extensions to nanobind, with the exception of those that remain for testing pybind11 support. This PR also: * removes CollectDiagnosticsToStringScope from NanobindAdaptors.h. This was overlooked in a previous PR and it is duplicated in Diagnostics.h. 
--------- Co-authored-by: Jacques Pienaar --- mlir/include/mlir/Bindings/Python/Nanobind.h | 37 ++++++++ .../mlir/Bindings/Python/NanobindAdaptors.h | 40 +-------- mlir/lib/Bindings/Python/AsyncPasses.cpp | 5 +- mlir/lib/Bindings/Python/DialectGPU.cpp | 44 +++++----- mlir/lib/Bindings/Python/DialectLLVM.cpp | 54 ++++++------ mlir/lib/Bindings/Python/DialectLinalg.cpp | 11 +-- mlir/lib/Bindings/Python/DialectNVGPU.cpp | 20 ++--- mlir/lib/Bindings/Python/DialectPDL.cpp | 43 +++++----- mlir/lib/Bindings/Python/DialectQuant.cpp | 79 +++++++++-------- .../Bindings/Python/DialectSparseTensor.cpp | 45 +++++----- mlir/lib/Bindings/Python/DialectTransform.cpp | 48 +++++------ .../Bindings/Python/ExecutionEngineModule.cpp | 85 ++++++++++--------- mlir/lib/Bindings/Python/GPUPasses.cpp | 5 +- mlir/lib/Bindings/Python/IRAffine.cpp | 7 +- mlir/lib/Bindings/Python/IRAttributes.cpp | 8 +- mlir/lib/Bindings/Python/IRCore.cpp | 10 +-- mlir/lib/Bindings/Python/IRInterfaces.cpp | 5 +- mlir/lib/Bindings/Python/IRModule.cpp | 6 +- mlir/lib/Bindings/Python/IRModule.h | 4 +- mlir/lib/Bindings/Python/IRTypes.cpp | 6 -- mlir/lib/Bindings/Python/LinalgPasses.cpp | 4 +- mlir/lib/Bindings/Python/MainModule.cpp | 3 +- mlir/lib/Bindings/Python/NanobindUtils.h | 5 +- mlir/lib/Bindings/Python/Pass.cpp | 7 +- .../Bindings/Python/RegisterEverything.cpp | 5 +- mlir/lib/Bindings/Python/Rewrite.cpp | 5 +- .../Bindings/Python/SparseTensorPasses.cpp | 4 +- .../Bindings/Python/TransformInterpreter.cpp | 44 +++++----- mlir/python/CMakeLists.txt | 22 +++-- 29 files changed, 318 insertions(+), 343 deletions(-) create mode 100644 mlir/include/mlir/Bindings/Python/Nanobind.h diff --git a/mlir/include/mlir/Bindings/Python/Nanobind.h b/mlir/include/mlir/Bindings/Python/Nanobind.h new file mode 100644 index 000000000..ca942c83d --- /dev/null +++ b/mlir/include/mlir/Bindings/Python/Nanobind.h @@ -0,0 +1,37 @@ +//===- Nanobind.h - Trampoline header with ignored warnings ---------------===// +// +// Part of the 
LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This file is a trampoline for the nanobind headers while disabling warnings +// reported by the LLVM/MLIR build. This file avoids adding complexity build +// system side. +//===----------------------------------------------------------------------===// + +#ifndef MLIR_BINDINGS_PYTHON_NANOBIND_H +#define MLIR_BINDINGS_PYTHON_NANOBIND_H + +#if defined(__clang__) || defined(__GNUC__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wzero-length-array" +#pragma GCC diagnostic ignored "-Wcast-qual" +#pragma GCC diagnostic ignored "-Wnested-anon-types" +#pragma GCC diagnostic ignored "-Wc++98-compat-extra-semi" +#pragma GCC diagnostic ignored "-Wcovered-switch-default" +#endif +#include +#include +#include +#include +#include +#include +#include +#include +#include +#if defined(__clang__) || defined(__GNUC__) +#pragma GCC diagnostic pop +#endif + +#endif // MLIR_BINDINGS_PYTHON_NANOBIND_H diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h index 943981b1f..517351cac 100644 --- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h @@ -19,14 +19,12 @@ #ifndef MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H #define MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H -#include -#include - #include -#include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. 
#include "llvm/ADT/Twine.h" // Raw CAPI type casters need to be declared before use, so always include them @@ -631,40 +629,6 @@ class mlir_value_subclass : public pure_subclass { } // namespace nanobind_adaptors -/// RAII scope intercepting all diagnostics into a string. The message must be -/// checked before this goes out of scope. -class CollectDiagnosticsToStringScope { -public: - explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) { - handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage, - /*deleteUserData=*/nullptr); - } - ~CollectDiagnosticsToStringScope() { - assert(errorMessage.empty() && "unchecked error message"); - mlirContextDetachDiagnosticHandler(context, handlerID); - } - - [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); } - -private: - static MlirLogicalResult handler(MlirDiagnostic diag, void *data) { - auto printer = +[](MlirStringRef message, void *data) { - *static_cast(data) += - llvm::StringRef(message.data, message.length); - }; - MlirLocation loc = mlirDiagnosticGetLocation(diag); - *static_cast(data) += "at "; - mlirLocationPrint(loc, printer, data); - *static_cast(data) += ": "; - mlirDiagnosticPrint(diag, printer, data); - return mlirLogicalResultSuccess(); - } - - MlirContext context; - MlirDiagnosticHandlerID handlerID; - std::string errorMessage = ""; -}; - } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/AsyncPasses.cpp b/mlir/lib/Bindings/Python/AsyncPasses.cpp index b611a758d..cfb8dcaaa 100644 --- a/mlir/lib/Bindings/Python/AsyncPasses.cpp +++ b/mlir/lib/Bindings/Python/AsyncPasses.cpp @@ -8,14 +8,13 @@ #include "mlir-c/Dialect/Async.h" -#include -#include +#include "mlir/Bindings/Python/Nanobind.h" // ----------------------------------------------------------------------------- // Module initialization. 
// ----------------------------------------------------------------------------- -PYBIND11_MODULE(_mlirAsyncPasses, m) { +NB_MODULE(_mlirAsyncPasses, m) { m.doc() = "MLIR Async Dialect Passes"; // Register all Async passes on load. diff --git a/mlir/lib/Bindings/Python/DialectGPU.cpp b/mlir/lib/Bindings/Python/DialectGPU.cpp index 560a54bcd..e5045cf0b 100644 --- a/mlir/lib/Bindings/Python/DialectGPU.cpp +++ b/mlir/lib/Bindings/Python/DialectGPU.cpp @@ -9,21 +9,21 @@ #include "mlir-c/Dialect/GPU.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" -#include -#include +namespace nb = nanobind; +using namespace nanobind::literals; -namespace py = pybind11; using namespace mlir; using namespace mlir::python; -using namespace mlir::python::adaptors; +using namespace mlir::python::nanobind_adaptors; // ----------------------------------------------------------------------------- // Module initialization. 
// ----------------------------------------------------------------------------- -PYBIND11_MODULE(_mlirDialectsGPU, m) { +NB_MODULE(_mlirDialectsGPU, m) { m.doc() = "MLIR GPU Dialect"; //===-------------------------------------------------------------------===// // AsyncTokenType @@ -34,11 +34,11 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) { mlirGPUAsyncTokenType.def_classmethod( "get", - [](py::object cls, MlirContext ctx) { + [](nb::object cls, MlirContext ctx) { return cls(mlirGPUAsyncTokenTypeGet(ctx)); }, - "Gets an instance of AsyncTokenType in the same context", py::arg("cls"), - py::arg("ctx") = py::none()); + "Gets an instance of AsyncTokenType in the same context", nb::arg("cls"), + nb::arg("ctx").none() = nb::none()); //===-------------------------------------------------------------------===// // ObjectAttr @@ -47,12 +47,12 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) { mlir_attribute_subclass(m, "ObjectAttr", mlirAttributeIsAGPUObjectAttr) .def_classmethod( "get", - [](py::object cls, MlirAttribute target, uint32_t format, - py::bytes object, std::optional mlirObjectProps, + [](nb::object cls, MlirAttribute target, uint32_t format, + nb::bytes object, std::optional mlirObjectProps, std::optional mlirKernelsAttr) { - py::buffer_info info(py::buffer(object).request()); - MlirStringRef objectStrRef = - mlirStringRefCreate(static_cast(info.ptr), info.size); + MlirStringRef objectStrRef = mlirStringRefCreate( + static_cast(const_cast(object.data())), + object.size()); return cls(mlirGPUObjectAttrGetWithKernels( mlirAttributeGetContext(target), target, format, objectStrRef, mlirObjectProps.has_value() ? 
*mlirObjectProps @@ -61,7 +61,7 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) { : MlirAttribute{nullptr})); }, "cls"_a, "target"_a, "format"_a, "object"_a, - "properties"_a = py::none(), "kernels"_a = py::none(), + "properties"_a.none() = nb::none(), "kernels"_a.none() = nb::none(), "Gets a gpu.object from parameters.") .def_property_readonly( "target", @@ -73,18 +73,18 @@ PYBIND11_MODULE(_mlirDialectsGPU, m) { "object", [](MlirAttribute self) { MlirStringRef stringRef = mlirGPUObjectAttrGetObject(self); - return py::bytes(stringRef.data, stringRef.length); + return nb::bytes(stringRef.data, stringRef.length); }) .def_property_readonly("properties", - [](MlirAttribute self) { + [](MlirAttribute self) -> nb::object { if (mlirGPUObjectAttrHasProperties(self)) - return py::cast( + return nb::cast( mlirGPUObjectAttrGetProperties(self)); - return py::none().cast(); + return nb::none(); }) - .def_property_readonly("kernels", [](MlirAttribute self) { + .def_property_readonly("kernels", [](MlirAttribute self) -> nb::object { if (mlirGPUObjectAttrHasKernels(self)) - return py::cast(mlirGPUObjectAttrGetKernels(self)); - return py::none().cast(); + return nb::cast(mlirGPUObjectAttrGetKernels(self)); + return nb::none(); }); } diff --git a/mlir/lib/Bindings/Python/DialectLLVM.cpp b/mlir/lib/Bindings/Python/DialectLLVM.cpp index cccf1370b..f211e769d 100644 --- a/mlir/lib/Bindings/Python/DialectLLVM.cpp +++ b/mlir/lib/Bindings/Python/DialectLLVM.cpp @@ -12,15 +12,19 @@ #include "mlir-c/IR.h" #include "mlir-c/Support.h" #include "mlir/Bindings/Python/Diagnostics.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" + +namespace nb = nanobind; + +using namespace nanobind::literals; -namespace py = pybind11; using namespace llvm; using namespace mlir; using namespace mlir::python; -using namespace mlir::python::adaptors; +using namespace mlir::python::nanobind_adaptors; -void 
populateDialectLLVMSubmodule(const pybind11::module &m) { +void populateDialectLLVMSubmodule(const nanobind::module_ &m) { //===--------------------------------------------------------------------===// // StructType @@ -31,35 +35,35 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) { llvmStructType.def_classmethod( "get_literal", - [](py::object cls, const std::vector &elements, bool packed, + [](nb::object cls, const std::vector &elements, bool packed, MlirLocation loc) { CollectDiagnosticsToStringScope scope(mlirLocationGetContext(loc)); MlirType type = mlirLLVMStructTypeLiteralGetChecked( loc, elements.size(), elements.data(), packed); if (mlirTypeIsNull(type)) { - throw py::value_error(scope.takeMessage()); + throw nb::value_error(scope.takeMessage().c_str()); } return cls(type); }, - "cls"_a, "elements"_a, py::kw_only(), "packed"_a = false, - "loc"_a = py::none()); + "cls"_a, "elements"_a, nb::kw_only(), "packed"_a = false, + "loc"_a.none() = nb::none()); llvmStructType.def_classmethod( "get_identified", - [](py::object cls, const std::string &name, MlirContext context) { + [](nb::object cls, const std::string &name, MlirContext context) { return cls(mlirLLVMStructTypeIdentifiedGet( context, mlirStringRefCreate(name.data(), name.size()))); }, - "cls"_a, "name"_a, py::kw_only(), "context"_a = py::none()); + "cls"_a, "name"_a, nb::kw_only(), "context"_a.none() = nb::none()); llvmStructType.def_classmethod( "get_opaque", - [](py::object cls, const std::string &name, MlirContext context) { + [](nb::object cls, const std::string &name, MlirContext context) { return cls(mlirLLVMStructTypeOpaqueGet( context, mlirStringRefCreate(name.data(), name.size()))); }, - "cls"_a, "name"_a, "context"_a = py::none()); + "cls"_a, "name"_a, "context"_a.none() = nb::none()); llvmStructType.def( "set_body", @@ -67,22 +71,22 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) { MlirLogicalResult result = mlirLLVMStructTypeSetBody( self, elements.size(), 
elements.data(), packed); if (!mlirLogicalResultIsSuccess(result)) { - throw py::value_error( + throw nb::value_error( "Struct body already set to different content."); } }, - "elements"_a, py::kw_only(), "packed"_a = false); + "elements"_a, nb::kw_only(), "packed"_a = false); llvmStructType.def_classmethod( "new_identified", - [](py::object cls, const std::string &name, + [](nb::object cls, const std::string &name, const std::vector &elements, bool packed, MlirContext ctx) { return cls(mlirLLVMStructTypeIdentifiedNewGet( ctx, mlirStringRefCreate(name.data(), name.length()), elements.size(), elements.data(), packed)); }, - "cls"_a, "name"_a, "elements"_a, py::kw_only(), "packed"_a = false, - "context"_a = py::none()); + "cls"_a, "name"_a, "elements"_a, nb::kw_only(), "packed"_a = false, + "context"_a.none() = nb::none()); llvmStructType.def_property_readonly( "name", [](MlirType type) -> std::optional { @@ -93,12 +97,12 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) { return StringRef(stringRef.data, stringRef.length).str(); }); - llvmStructType.def_property_readonly("body", [](MlirType type) -> py::object { + llvmStructType.def_property_readonly("body", [](MlirType type) -> nb::object { // Don't crash in absence of a body. if (mlirLLVMStructTypeIsOpaque(type)) - return py::none(); + return nb::none(); - py::list body; + nb::list body; for (intptr_t i = 0, e = mlirLLVMStructTypeGetNumElementTypes(type); i < e; ++i) { body.append(mlirLLVMStructTypeGetElementType(type, i)); @@ -119,24 +123,24 @@ void populateDialectLLVMSubmodule(const pybind11::module &m) { mlir_type_subclass(m, "PointerType", mlirTypeIsALLVMPointerType) .def_classmethod( "get", - [](py::object cls, std::optional addressSpace, + [](nb::object cls, std::optional addressSpace, MlirContext context) { CollectDiagnosticsToStringScope scope(context); MlirType type = mlirLLVMPointerTypeGet( context, addressSpace.has_value() ? 
*addressSpace : 0); if (mlirTypeIsNull(type)) { - throw py::value_error(scope.takeMessage()); + throw nb::value_error(scope.takeMessage().c_str()); } return cls(type); }, - "cls"_a, "address_space"_a = py::none(), py::kw_only(), - "context"_a = py::none()) + "cls"_a, "address_space"_a.none() = nb::none(), nb::kw_only(), + "context"_a.none() = nb::none()) .def_property_readonly("address_space", [](MlirType type) { return mlirLLVMPointerTypeGetAddressSpace(type); }); } -PYBIND11_MODULE(_mlirDialectsLLVM, m) { +NB_MODULE(_mlirDialectsLLVM, m) { m.doc() = "MLIR LLVM Dialect"; populateDialectLLVMSubmodule(m); diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp index 2e54ebeb6..548df4ee1 100644 --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -8,20 +8,21 @@ #include "mlir-c/Dialect/Linalg.h" #include "mlir-c/IR.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" -namespace py = pybind11; +namespace nb = nanobind; -static void populateDialectLinalgSubmodule(py::module m) { +static void populateDialectLinalgSubmodule(nb::module_ m) { m.def( "fill_builtin_region", [](MlirOperation op) { mlirLinalgFillBuiltinNamedOpRegion(op); }, - py::arg("op"), + nb::arg("op"), "Fill the region for `op`, which is assumed to be a builtin named Linalg " "op."); } -PYBIND11_MODULE(_mlirDialectsLinalg, m) { +NB_MODULE(_mlirDialectsLinalg, m) { m.doc() = "MLIR Linalg dialect."; populateDialectLinalgSubmodule(m); diff --git a/mlir/lib/Bindings/Python/DialectNVGPU.cpp b/mlir/lib/Bindings/Python/DialectNVGPU.cpp index 754e0a75b..a0d6a4b4c 100644 --- a/mlir/lib/Bindings/Python/DialectNVGPU.cpp +++ b/mlir/lib/Bindings/Python/DialectNVGPU.cpp @@ -8,33 +8,33 @@ #include "mlir-c/Dialect/NVGPU.h" #include "mlir-c/IR.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" -#include +#include 
"mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" -namespace py = pybind11; +namespace nb = nanobind; using namespace llvm; using namespace mlir; using namespace mlir::python; -using namespace mlir::python::adaptors; +using namespace mlir::python::nanobind_adaptors; -static void populateDialectNVGPUSubmodule(const pybind11::module &m) { +static void populateDialectNVGPUSubmodule(const nb::module_ &m) { auto nvgpuTensorMapDescriptorType = mlir_type_subclass( m, "TensorMapDescriptorType", mlirTypeIsANVGPUTensorMapDescriptorType); nvgpuTensorMapDescriptorType.def_classmethod( "get", - [](py::object cls, MlirType tensorMemrefType, int swizzle, int l2promo, + [](nb::object cls, MlirType tensorMemrefType, int swizzle, int l2promo, int oobFill, int interleave, MlirContext ctx) { return cls(mlirNVGPUTensorMapDescriptorTypeGet( ctx, tensorMemrefType, swizzle, l2promo, oobFill, interleave)); }, "Gets an instance of TensorMapDescriptorType in the same context", - py::arg("cls"), py::arg("tensor_type"), py::arg("swizzle"), - py::arg("l2promo"), py::arg("oob_fill"), py::arg("interleave"), - py::arg("ctx") = py::none()); + nb::arg("cls"), nb::arg("tensor_type"), nb::arg("swizzle"), + nb::arg("l2promo"), nb::arg("oob_fill"), nb::arg("interleave"), + nb::arg("ctx").none() = nb::none()); } -PYBIND11_MODULE(_mlirDialectsNVGPU, m) { +NB_MODULE(_mlirDialectsNVGPU, m) { m.doc() = "MLIR NVGPU dialect."; populateDialectNVGPUSubmodule(m); diff --git a/mlir/lib/Bindings/Python/DialectPDL.cpp b/mlir/lib/Bindings/Python/DialectPDL.cpp index 8d3f9a7ab..bcc6ff406 100644 --- a/mlir/lib/Bindings/Python/DialectPDL.cpp +++ b/mlir/lib/Bindings/Python/DialectPDL.cpp @@ -8,19 +8,16 @@ #include "mlir-c/Dialect/PDL.h" #include "mlir-c/IR.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" -#include -#include -#include -#include +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" -namespace py = pybind11; +namespace nb = 
nanobind; using namespace llvm; using namespace mlir; using namespace mlir::python; -using namespace mlir::python::adaptors; +using namespace mlir::python::nanobind_adaptors; -void populateDialectPDLSubmodule(const pybind11::module &m) { +void populateDialectPDLSubmodule(const nanobind::module_ &m) { //===-------------------------------------------------------------------===// // PDLType //===-------------------------------------------------------------------===// @@ -35,11 +32,11 @@ void populateDialectPDLSubmodule(const pybind11::module &m) { mlir_type_subclass(m, "AttributeType", mlirTypeIsAPDLAttributeType); attributeType.def_classmethod( "get", - [](py::object cls, MlirContext ctx) { + [](nb::object cls, MlirContext ctx) { return cls(mlirPDLAttributeTypeGet(ctx)); }, - "Get an instance of AttributeType in given context.", py::arg("cls"), - py::arg("context") = py::none()); + "Get an instance of AttributeType in given context.", nb::arg("cls"), + nb::arg("context").none() = nb::none()); //===-------------------------------------------------------------------===// // OperationType @@ -49,11 +46,11 @@ void populateDialectPDLSubmodule(const pybind11::module &m) { mlir_type_subclass(m, "OperationType", mlirTypeIsAPDLOperationType); operationType.def_classmethod( "get", - [](py::object cls, MlirContext ctx) { + [](nb::object cls, MlirContext ctx) { return cls(mlirPDLOperationTypeGet(ctx)); }, - "Get an instance of OperationType in given context.", py::arg("cls"), - py::arg("context") = py::none()); + "Get an instance of OperationType in given context.", nb::arg("cls"), + nb::arg("context").none() = nb::none()); //===-------------------------------------------------------------------===// // RangeType @@ -62,12 +59,12 @@ void populateDialectPDLSubmodule(const pybind11::module &m) { auto rangeType = mlir_type_subclass(m, "RangeType", mlirTypeIsAPDLRangeType); rangeType.def_classmethod( "get", - [](py::object cls, MlirType elementType) { + [](nb::object cls, MlirType 
elementType) { return cls(mlirPDLRangeTypeGet(elementType)); }, "Gets an instance of RangeType in the same context as the provided " "element type.", - py::arg("cls"), py::arg("element_type")); + nb::arg("cls"), nb::arg("element_type")); rangeType.def_property_readonly( "element_type", [](MlirType type) { return mlirPDLRangeTypeGetElementType(type); }, @@ -80,11 +77,11 @@ void populateDialectPDLSubmodule(const pybind11::module &m) { auto typeType = mlir_type_subclass(m, "TypeType", mlirTypeIsAPDLTypeType); typeType.def_classmethod( "get", - [](py::object cls, MlirContext ctx) { + [](nb::object cls, MlirContext ctx) { return cls(mlirPDLTypeTypeGet(ctx)); }, - "Get an instance of TypeType in given context.", py::arg("cls"), - py::arg("context") = py::none()); + "Get an instance of TypeType in given context.", nb::arg("cls"), + nb::arg("context").none() = nb::none()); //===-------------------------------------------------------------------===// // ValueType @@ -93,14 +90,14 @@ void populateDialectPDLSubmodule(const pybind11::module &m) { auto valueType = mlir_type_subclass(m, "ValueType", mlirTypeIsAPDLValueType); valueType.def_classmethod( "get", - [](py::object cls, MlirContext ctx) { + [](nb::object cls, MlirContext ctx) { return cls(mlirPDLValueTypeGet(ctx)); }, - "Get an instance of TypeType in given context.", py::arg("cls"), - py::arg("context") = py::none()); + "Get an instance of TypeType in given context.", nb::arg("cls"), + nb::arg("context").none() = nb::none()); } -PYBIND11_MODULE(_mlirDialectsPDL, m) { +NB_MODULE(_mlirDialectsPDL, m) { m.doc() = "MLIR PDL dialect."; populateDialectPDLSubmodule(m); } diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp index 9a871f2c1..29f19c9c5 100644 --- a/mlir/lib/Bindings/Python/DialectQuant.cpp +++ b/mlir/lib/Bindings/Python/DialectQuant.cpp @@ -6,21 +6,20 @@ // //===----------------------------------------------------------------------===// -#include 
"mlir-c/Dialect/Quant.h" -#include "mlir-c/IR.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" #include -#include -#include -#include #include -namespace py = pybind11; +#include "mlir-c/Dialect/Quant.h" +#include "mlir-c/IR.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" + +namespace nb = nanobind; using namespace llvm; using namespace mlir; -using namespace mlir::python::adaptors; +using namespace mlir::python::nanobind_adaptors; -static void populateDialectQuantSubmodule(const py::module &m) { +static void populateDialectQuantSubmodule(const nb::module_ &m) { //===-------------------------------------------------------------------===// // QuantizedType //===-------------------------------------------------------------------===// @@ -35,7 +34,7 @@ static void populateDialectQuantSubmodule(const py::module &m) { }, "Default minimum value for the integer with the specified signedness and " "bit width.", - py::arg("is_signed"), py::arg("integral_width")); + nb::arg("is_signed"), nb::arg("integral_width")); quantizedType.def_staticmethod( "default_maximum_for_integer", [](bool isSigned, unsigned integralWidth) { @@ -44,7 +43,7 @@ static void populateDialectQuantSubmodule(const py::module &m) { }, "Default maximum value for the integer with the specified signedness and " "bit width.", - py::arg("is_signed"), py::arg("integral_width")); + nb::arg("is_signed"), nb::arg("integral_width")); quantizedType.def_property_readonly( "expressed_type", [](MlirType type) { return mlirQuantizedTypeGetExpressedType(type); }, @@ -82,7 +81,7 @@ static void populateDialectQuantSubmodule(const py::module &m) { }, "Checks whether the candidate type can be expressed by this quantized " "type.", - py::arg("candidate")); + nb::arg("candidate")); quantizedType.def_property_readonly( "quantized_element_type", [](MlirType type) { @@ -96,24 +95,24 @@ static void populateDialectQuantSubmodule(const py::module &m) { 
mlirQuantizedTypeCastFromStorageType(type, candidate); if (!mlirTypeIsNull(castResult)) return castResult; - throw py::type_error("Invalid cast."); + throw nb::type_error("Invalid cast."); }, "Casts from a type based on the storage type of this quantized type to a " "corresponding type based on the quantized type. Raises TypeError if the " "cast is not valid.", - py::arg("candidate")); + nb::arg("candidate")); quantizedType.def_staticmethod( "cast_to_storage_type", [](MlirType type) { MlirType castResult = mlirQuantizedTypeCastToStorageType(type); if (!mlirTypeIsNull(castResult)) return castResult; - throw py::type_error("Invalid cast."); + throw nb::type_error("Invalid cast."); }, "Casts from a type based on a quantized type to a corresponding type " "based on the storage type of this quantized type. Raises TypeError if " "the cast is not valid.", - py::arg("type")); + nb::arg("type")); quantizedType.def( "cast_from_expressed_type", [](MlirType type, MlirType candidate) { @@ -121,24 +120,24 @@ static void populateDialectQuantSubmodule(const py::module &m) { mlirQuantizedTypeCastFromExpressedType(type, candidate); if (!mlirTypeIsNull(castResult)) return castResult; - throw py::type_error("Invalid cast."); + throw nb::type_error("Invalid cast."); }, "Casts from a type based on the expressed type of this quantized type to " "a corresponding type based on the quantized type. Raises TypeError if " "the cast is not valid.", - py::arg("candidate")); + nb::arg("candidate")); quantizedType.def_staticmethod( "cast_to_expressed_type", [](MlirType type) { MlirType castResult = mlirQuantizedTypeCastToExpressedType(type); if (!mlirTypeIsNull(castResult)) return castResult; - throw py::type_error("Invalid cast."); + throw nb::type_error("Invalid cast."); }, "Casts from a type based on a quantized type to a corresponding type " "based on the expressed type of this quantized type. 
Raises TypeError if " "the cast is not valid.", - py::arg("type")); + nb::arg("type")); quantizedType.def( "cast_expressed_to_storage_type", [](MlirType type, MlirType candidate) { @@ -146,12 +145,12 @@ static void populateDialectQuantSubmodule(const py::module &m) { mlirQuantizedTypeCastExpressedToStorageType(type, candidate); if (!mlirTypeIsNull(castResult)) return castResult; - throw py::type_error("Invalid cast."); + throw nb::type_error("Invalid cast."); }, "Casts from a type based on the expressed type of this quantized type to " "a corresponding type based on the storage type. Raises TypeError if the " "cast is not valid.", - py::arg("candidate")); + nb::arg("candidate")); quantizedType.get_class().attr("FLAG_SIGNED") = mlirQuantizedTypeGetSignedFlag(); @@ -165,7 +164,7 @@ static void populateDialectQuantSubmodule(const py::module &m) { quantizedType.get_class()); anyQuantizedType.def_classmethod( "get", - [](py::object cls, unsigned flags, MlirType storageType, + [](nb::object cls, unsigned flags, MlirType storageType, MlirType expressedType, int64_t storageTypeMin, int64_t storageTypeMax) { return cls(mlirAnyQuantizedTypeGet(flags, storageType, expressedType, @@ -173,9 +172,9 @@ static void populateDialectQuantSubmodule(const py::module &m) { }, "Gets an instance of AnyQuantizedType in the same context as the " "provided storage type.", - py::arg("cls"), py::arg("flags"), py::arg("storage_type"), - py::arg("expressed_type"), py::arg("storage_type_min"), - py::arg("storage_type_max")); + nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"), + nb::arg("expressed_type"), nb::arg("storage_type_min"), + nb::arg("storage_type_max")); //===-------------------------------------------------------------------===// // UniformQuantizedType @@ -186,7 +185,7 @@ static void populateDialectQuantSubmodule(const py::module &m) { quantizedType.get_class()); uniformQuantizedType.def_classmethod( "get", - [](py::object cls, unsigned flags, MlirType storageType, + 
[](nb::object cls, unsigned flags, MlirType storageType, MlirType expressedType, double scale, int64_t zeroPoint, int64_t storageTypeMin, int64_t storageTypeMax) { return cls(mlirUniformQuantizedTypeGet(flags, storageType, @@ -195,9 +194,9 @@ static void populateDialectQuantSubmodule(const py::module &m) { }, "Gets an instance of UniformQuantizedType in the same context as the " "provided storage type.", - py::arg("cls"), py::arg("flags"), py::arg("storage_type"), - py::arg("expressed_type"), py::arg("scale"), py::arg("zero_point"), - py::arg("storage_type_min"), py::arg("storage_type_max")); + nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"), + nb::arg("expressed_type"), nb::arg("scale"), nb::arg("zero_point"), + nb::arg("storage_type_min"), nb::arg("storage_type_max")); uniformQuantizedType.def_property_readonly( "scale", [](MlirType type) { return mlirUniformQuantizedTypeGetScale(type); }, @@ -221,12 +220,12 @@ static void populateDialectQuantSubmodule(const py::module &m) { quantizedType.get_class()); uniformQuantizedPerAxisType.def_classmethod( "get", - [](py::object cls, unsigned flags, MlirType storageType, + [](nb::object cls, unsigned flags, MlirType storageType, MlirType expressedType, std::vector scales, std::vector zeroPoints, int32_t quantizedDimension, int64_t storageTypeMin, int64_t storageTypeMax) { if (scales.size() != zeroPoints.size()) - throw py::value_error( + throw nb::value_error( "Mismatching number of scales and zero points."); auto nDims = static_cast(scales.size()); return cls(mlirUniformQuantizedPerAxisTypeGet( @@ -236,10 +235,10 @@ static void populateDialectQuantSubmodule(const py::module &m) { }, "Gets an instance of UniformQuantizedPerAxisType in the same context as " "the provided storage type.", - py::arg("cls"), py::arg("flags"), py::arg("storage_type"), - py::arg("expressed_type"), py::arg("scales"), py::arg("zero_points"), - py::arg("quantized_dimension"), py::arg("storage_type_min"), - py::arg("storage_type_max")); + 
nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"), + nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"), + nb::arg("quantized_dimension"), nb::arg("storage_type_min"), + nb::arg("storage_type_max")); uniformQuantizedPerAxisType.def_property_readonly( "scales", [](MlirType type) { @@ -294,13 +293,13 @@ static void populateDialectQuantSubmodule(const py::module &m) { quantizedType.get_class()); calibratedQuantizedType.def_classmethod( "get", - [](py::object cls, MlirType expressedType, double min, double max) { + [](nb::object cls, MlirType expressedType, double min, double max) { return cls(mlirCalibratedQuantizedTypeGet(expressedType, min, max)); }, "Gets an instance of CalibratedQuantizedType in the same context as the " "provided expressed type.", - py::arg("cls"), py::arg("expressed_type"), py::arg("min"), - py::arg("max")); + nb::arg("cls"), nb::arg("expressed_type"), nb::arg("min"), + nb::arg("max")); calibratedQuantizedType.def_property_readonly("min", [](MlirType type) { return mlirCalibratedQuantizedTypeGetMin(type); }); @@ -309,7 +308,7 @@ static void populateDialectQuantSubmodule(const py::module &m) { }); } -PYBIND11_MODULE(_mlirDialectsQuant, m) { +NB_MODULE(_mlirDialectsQuant, m) { m.doc() = "MLIR Quantization dialect"; populateDialectQuantSubmodule(m); diff --git a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp index a730bf500..97cebccee 100644 --- a/mlir/lib/Bindings/Python/DialectSparseTensor.cpp +++ b/mlir/lib/Bindings/Python/DialectSparseTensor.cpp @@ -6,32 +6,30 @@ // //===----------------------------------------------------------------------===// +#include +#include + #include "mlir-c/AffineMap.h" #include "mlir-c/Dialect/SparseTensor.h" #include "mlir-c/IR.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" -#include -#include -#include -#include -#include -#include +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" 
-namespace py = pybind11; +namespace nb = nanobind; using namespace llvm; using namespace mlir; -using namespace mlir::python::adaptors; +using namespace mlir::python::nanobind_adaptors; -static void populateDialectSparseTensorSubmodule(const py::module &m) { - py::enum_(m, "LevelFormat", py::module_local()) +static void populateDialectSparseTensorSubmodule(const nb::module_ &m) { + nb::enum_(m, "LevelFormat", nb::is_arithmetic(), + nb::is_flag()) .value("dense", MLIR_SPARSE_TENSOR_LEVEL_DENSE) .value("n_out_of_m", MLIR_SPARSE_TENSOR_LEVEL_N_OUT_OF_M) .value("compressed", MLIR_SPARSE_TENSOR_LEVEL_COMPRESSED) .value("singleton", MLIR_SPARSE_TENSOR_LEVEL_SINGLETON) .value("loose_compressed", MLIR_SPARSE_TENSOR_LEVEL_LOOSE_COMPRESSED); - py::enum_(m, "LevelProperty", - py::module_local()) + nb::enum_(m, "LevelProperty") .value("non_ordered", MLIR_SPARSE_PROPERTY_NON_ORDERED) .value("non_unique", MLIR_SPARSE_PROPERTY_NON_UNIQUE) .value("soa", MLIR_SPARSE_PROPERTY_SOA); @@ -40,7 +38,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { mlirAttributeIsASparseTensorEncodingAttr) .def_classmethod( "get", - [](py::object cls, std::vector lvlTypes, + [](nb::object cls, std::vector lvlTypes, std::optional dimToLvl, std::optional lvlToDim, int posWidth, int crdWidth, std::optional explicitVal, @@ -52,24 +50,25 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { crdWidth, explicitVal ? *explicitVal : MlirAttribute{nullptr}, implicitVal ? 
*implicitVal : MlirAttribute{nullptr})); }, - py::arg("cls"), py::arg("lvl_types"), py::arg("dim_to_lvl"), - py::arg("lvl_to_dim"), py::arg("pos_width"), py::arg("crd_width"), - py::arg("explicit_val") = py::none(), - py::arg("implicit_val") = py::none(), py::arg("context") = py::none(), + nb::arg("cls"), nb::arg("lvl_types"), nb::arg("dim_to_lvl").none(), + nb::arg("lvl_to_dim").none(), nb::arg("pos_width"), + nb::arg("crd_width"), nb::arg("explicit_val").none() = nb::none(), + nb::arg("implicit_val").none() = nb::none(), + nb::arg("context").none() = nb::none(), "Gets a sparse_tensor.encoding from parameters.") .def_classmethod( "build_level_type", - [](py::object cls, MlirSparseTensorLevelFormat lvlFmt, + [](nb::object cls, MlirSparseTensorLevelFormat lvlFmt, const std::vector &properties, unsigned n, unsigned m) { return mlirSparseTensorEncodingAttrBuildLvlType( lvlFmt, properties.data(), properties.size(), n, m); }, - py::arg("cls"), py::arg("lvl_fmt"), - py::arg("properties") = + nb::arg("cls"), nb::arg("lvl_fmt"), + nb::arg("properties") = std::vector(), - py::arg("n") = 0, py::arg("m") = 0, + nb::arg("n") = 0, nb::arg("m") = 0, "Builds a sparse_tensor.encoding.level_type from parameters.") .def_property_readonly( "lvl_types", @@ -143,7 +142,7 @@ static void populateDialectSparseTensorSubmodule(const py::module &m) { }); } -PYBIND11_MODULE(_mlirDialectsSparseTensor, m) { +NB_MODULE(_mlirDialectsSparseTensor, m) { m.doc() = "MLIR SparseTensor dialect."; populateDialectSparseTensorSubmodule(m); } diff --git a/mlir/lib/Bindings/Python/DialectTransform.cpp b/mlir/lib/Bindings/Python/DialectTransform.cpp index 6b57e652a..59a030ac6 100644 --- a/mlir/lib/Bindings/Python/DialectTransform.cpp +++ b/mlir/lib/Bindings/Python/DialectTransform.cpp @@ -6,22 +6,20 @@ // //===----------------------------------------------------------------------===// +#include + #include "mlir-c/Dialect/Transform.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" -#include 
"mlir/Bindings/Python/PybindAdaptors.h" -#include -#include -#include -#include -#include +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" -namespace py = pybind11; +namespace nb = nanobind; using namespace mlir; using namespace mlir::python; -using namespace mlir::python::adaptors; +using namespace mlir::python::nanobind_adaptors; -void populateDialectTransformSubmodule(const pybind11::module &m) { +void populateDialectTransformSubmodule(const nb::module_ &m) { //===-------------------------------------------------------------------===// // AnyOpType //===-------------------------------------------------------------------===// @@ -31,11 +29,11 @@ void populateDialectTransformSubmodule(const pybind11::module &m) { mlirTransformAnyOpTypeGetTypeID); anyOpType.def_classmethod( "get", - [](py::object cls, MlirContext ctx) { + [](nb::object cls, MlirContext ctx) { return cls(mlirTransformAnyOpTypeGet(ctx)); }, - "Get an instance of AnyOpType in the given context.", py::arg("cls"), - py::arg("context") = py::none()); + "Get an instance of AnyOpType in the given context.", nb::arg("cls"), + nb::arg("context").none() = nb::none()); //===-------------------------------------------------------------------===// // AnyParamType @@ -46,11 +44,11 @@ void populateDialectTransformSubmodule(const pybind11::module &m) { mlirTransformAnyParamTypeGetTypeID); anyParamType.def_classmethod( "get", - [](py::object cls, MlirContext ctx) { + [](nb::object cls, MlirContext ctx) { return cls(mlirTransformAnyParamTypeGet(ctx)); }, - "Get an instance of AnyParamType in the given context.", py::arg("cls"), - py::arg("context") = py::none()); + "Get an instance of AnyParamType in the given context.", nb::arg("cls"), + nb::arg("context").none() = nb::none()); //===-------------------------------------------------------------------===// // AnyValueType @@ -61,11 +59,11 @@ void populateDialectTransformSubmodule(const pybind11::module &m) { 
mlirTransformAnyValueTypeGetTypeID); anyValueType.def_classmethod( "get", - [](py::object cls, MlirContext ctx) { + [](nb::object cls, MlirContext ctx) { return cls(mlirTransformAnyValueTypeGet(ctx)); }, - "Get an instance of AnyValueType in the given context.", py::arg("cls"), - py::arg("context") = py::none()); + "Get an instance of AnyValueType in the given context.", nb::arg("cls"), + nb::arg("context").none() = nb::none()); //===-------------------------------------------------------------------===// // OperationType @@ -76,21 +74,21 @@ void populateDialectTransformSubmodule(const pybind11::module &m) { mlirTransformOperationTypeGetTypeID); operationType.def_classmethod( "get", - [](py::object cls, const std::string &operationName, MlirContext ctx) { + [](nb::object cls, const std::string &operationName, MlirContext ctx) { MlirStringRef cOperationName = mlirStringRefCreate(operationName.data(), operationName.size()); return cls(mlirTransformOperationTypeGet(ctx, cOperationName)); }, "Get an instance of OperationType for the given kind in the given " "context", - py::arg("cls"), py::arg("operation_name"), - py::arg("context") = py::none()); + nb::arg("cls"), nb::arg("operation_name"), + nb::arg("context").none() = nb::none()); operationType.def_property_readonly( "operation_name", [](MlirType type) { MlirStringRef operationName = mlirTransformOperationTypeGetOperationName(type); - return py::str(operationName.data, operationName.length); + return nb::str(operationName.data, operationName.length); }, "Get the name of the payload operation accepted by the handle."); @@ -103,11 +101,11 @@ void populateDialectTransformSubmodule(const pybind11::module &m) { mlirTransformParamTypeGetTypeID); paramType.def_classmethod( "get", - [](py::object cls, MlirType type, MlirContext ctx) { + [](nb::object cls, MlirType type, MlirContext ctx) { return cls(mlirTransformParamTypeGet(ctx, type)); }, "Get an instance of ParamType for the given type in the given context.", - 
py::arg("cls"), py::arg("type"), py::arg("context") = py::none()); + nb::arg("cls"), nb::arg("type"), nb::arg("context").none() = nb::none()); paramType.def_property_readonly( "type", [](MlirType type) { @@ -117,7 +115,7 @@ void populateDialectTransformSubmodule(const pybind11::module &m) { "Get the type this ParamType is associated with."); } -PYBIND11_MODULE(_mlirDialectsTransform, m) { +NB_MODULE(_mlirDialectsTransform, m) { m.doc() = "MLIR Transform dialect."; populateDialectTransformSubmodule(m); } diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp index b3df30583..81dada355 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -7,9 +7,10 @@ //===----------------------------------------------------------------------===// #include "mlir-c/ExecutionEngine.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" -namespace py = pybind11; +namespace nb = nanobind; using namespace mlir; using namespace mlir::python; @@ -34,23 +35,22 @@ class PyExecutionEngine { executionEngine.ptr = nullptr; referencedObjects.clear(); } - pybind11::object getCapsule() { - return py::reinterpret_steal( - mlirPythonExecutionEngineToCapsule(get())); + nb::object getCapsule() { + return nb::steal(mlirPythonExecutionEngineToCapsule(get())); } // Add an object to the list of referenced objects whose lifetime must exceed // those of the ExecutionEngine. 
- void addReferencedObject(const pybind11::object &obj) { + void addReferencedObject(const nb::object &obj) { referencedObjects.push_back(obj); } - static pybind11::object createFromCapsule(pybind11::object capsule) { + static nb::object createFromCapsule(nb::object capsule) { MlirExecutionEngine rawPm = mlirPythonCapsuleToExecutionEngine(capsule.ptr()); if (mlirExecutionEngineIsNull(rawPm)) - throw py::error_already_set(); - return py::cast(PyExecutionEngine(rawPm), py::return_value_policy::move); + throw nb::python_error(); + return nb::cast(PyExecutionEngine(rawPm), nb::rv_policy::move); } private: @@ -58,44 +58,45 @@ class PyExecutionEngine { // We support Python ctypes closures as callbacks. Keep a list of the objects // so that they don't get garbage collected. (The ExecutionEngine itself // just holds raw pointers with no lifetime semantics). - std::vector referencedObjects; + std::vector referencedObjects; }; } // namespace /// Create the `mlir.execution_engine` module here. -PYBIND11_MODULE(_mlirExecutionEngine, m) { +NB_MODULE(_mlirExecutionEngine, m) { m.doc() = "MLIR Execution Engine"; //---------------------------------------------------------------------------- // Mapping of the top-level PassManager //---------------------------------------------------------------------------- - py::class_(m, "ExecutionEngine", py::module_local()) - .def(py::init<>([](MlirModule module, int optLevel, - const std::vector &sharedLibPaths, - bool enableObjectDump) { - llvm::SmallVector libPaths; - for (const std::string &path : sharedLibPaths) - libPaths.push_back({path.c_str(), path.length()}); - MlirExecutionEngine executionEngine = - mlirExecutionEngineCreate(module, optLevel, libPaths.size(), - libPaths.data(), enableObjectDump); - if (mlirExecutionEngineIsNull(executionEngine)) - throw std::runtime_error( - "Failure while creating the ExecutionEngine."); - return new PyExecutionEngine(executionEngine); - }), - py::arg("module"), py::arg("opt_level") = 2, - 
py::arg("shared_libs") = py::list(), - py::arg("enable_object_dump") = true, - "Create a new ExecutionEngine instance for the given Module. The " - "module must contain only dialects that can be translated to LLVM. " - "Perform transformations and code generation at the optimization " - "level `opt_level` if specified, or otherwise at the default " - "level of two (-O2). Load a list of libraries specified in " - "`shared_libs`.") - .def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR, - &PyExecutionEngine::getCapsule) + nb::class_(m, "ExecutionEngine") + .def( + "__init__", + [](PyExecutionEngine &self, MlirModule module, int optLevel, + const std::vector &sharedLibPaths, + bool enableObjectDump) { + llvm::SmallVector libPaths; + for (const std::string &path : sharedLibPaths) + libPaths.push_back({path.c_str(), path.length()}); + MlirExecutionEngine executionEngine = + mlirExecutionEngineCreate(module, optLevel, libPaths.size(), + libPaths.data(), enableObjectDump); + if (mlirExecutionEngineIsNull(executionEngine)) + throw std::runtime_error( + "Failure while creating the ExecutionEngine."); + new (&self) PyExecutionEngine(executionEngine); + }, + nb::arg("module"), nb::arg("opt_level") = 2, + nb::arg("shared_libs") = nb::list(), + nb::arg("enable_object_dump") = true, + "Create a new ExecutionEngine instance for the given Module. The " + "module must contain only dialects that can be translated to LLVM. " + "Perform transformations and code generation at the optimization " + "level `opt_level` if specified, or otherwise at the default " + "level of two (-O2). 
Load a list of libraries specified in " + "`shared_libs`.") + .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyExecutionEngine::getCapsule) .def("_testing_release", &PyExecutionEngine::release, "Releases (leaks) the backing ExecutionEngine (for testing purpose)") .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyExecutionEngine::createFromCapsule) @@ -107,21 +108,21 @@ PYBIND11_MODULE(_mlirExecutionEngine, m) { mlirStringRefCreate(func.c_str(), func.size())); return reinterpret_cast(res); }, - py::arg("func_name"), + nb::arg("func_name"), "Lookup function `func` in the ExecutionEngine.") .def( "raw_register_runtime", [](PyExecutionEngine &executionEngine, const std::string &name, - py::object callbackObj) { + nb::object callbackObj) { executionEngine.addReferencedObject(callbackObj); uintptr_t rawSym = - py::cast(py::getattr(callbackObj, "value")); + nb::cast(nb::getattr(callbackObj, "value")); mlirExecutionEngineRegisterSymbol( executionEngine.get(), mlirStringRefCreate(name.c_str(), name.size()), reinterpret_cast(rawSym)); }, - py::arg("name"), py::arg("callback"), + nb::arg("name"), nb::arg("callback"), "Register `callback` as the runtime symbol `name`.") .def( "dump_to_object_file", @@ -130,5 +131,5 @@ PYBIND11_MODULE(_mlirExecutionEngine, m) { executionEngine.get(), mlirStringRefCreate(fileName.c_str(), fileName.size())); }, - py::arg("file_name"), "Dump ExecutionEngine to an object file."); + nb::arg("file_name"), "Dump ExecutionEngine to an object file."); } diff --git a/mlir/lib/Bindings/Python/GPUPasses.cpp b/mlir/lib/Bindings/Python/GPUPasses.cpp index e276a3ce3..be474edbe 100644 --- a/mlir/lib/Bindings/Python/GPUPasses.cpp +++ b/mlir/lib/Bindings/Python/GPUPasses.cpp @@ -8,14 +8,13 @@ #include "mlir-c/Dialect/GPU.h" -#include -#include +#include "mlir/Bindings/Python/Nanobind.h" // ----------------------------------------------------------------------------- // Module initialization. 
// ----------------------------------------------------------------------------- -PYBIND11_MODULE(_mlirGPUPasses, m) { +NB_MODULE(_mlirGPUPasses, m) { m.doc() = "MLIR GPU Dialect Passes"; // Register all GPU passes on load. diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index 2db690309..a2df824f5 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -6,10 +6,6 @@ // //===----------------------------------------------------------------------===// -#include -#include -#include - #include #include #include @@ -21,8 +17,9 @@ #include "NanobindUtils.h" #include "mlir-c/AffineExpr.h" #include "mlir-c/AffineMap.h" -#include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/IntegerSet.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. #include "mlir/Support/LLVM.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/SmallVector.h" diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 779af0950..08f7d4881 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -6,13 +6,6 @@ // //===----------------------------------------------------------------------===// -#include -#include -#include -#include -#include -#include - #include #include #include @@ -24,6 +17,7 @@ #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/raw_ostream.h" diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index e1c56a398..86afa9563 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -6,26 +6,20 @@ // //===----------------------------------------------------------------------===// -#include -#include 
-#include -#include -#include -#include - #include #include #include "Globals.h" #include "IRModule.h" #include "NanobindUtils.h" -#include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/Debug.h" #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" diff --git a/mlir/lib/Bindings/Python/IRInterfaces.cpp b/mlir/lib/Bindings/Python/IRInterfaces.cpp index c339a93e3..9e1fedaab 100644 --- a/mlir/lib/Bindings/Python/IRInterfaces.cpp +++ b/mlir/lib/Bindings/Python/IRInterfaces.cpp @@ -6,10 +6,6 @@ // //===----------------------------------------------------------------------===// -#include -#include -#include - #include #include #include @@ -21,6 +17,7 @@ #include "mlir-c/IR.h" #include "mlir-c/Interfaces.h" #include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Nanobind.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index 416a14218..f7bf77e5a 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -8,16 +8,14 @@ #include "IRModule.h" -#include -#include - #include #include #include "Globals.h" #include "NanobindUtils.h" -#include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. 
namespace nb = nanobind; using namespace mlir; diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index a242ff26b..8fb32a225 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -10,9 +10,6 @@ #ifndef MLIR_BINDINGS_PYTHON_IRMODULES_H #define MLIR_BINDINGS_PYTHON_IRMODULES_H -#include -#include - #include #include #include @@ -26,6 +23,7 @@ #include "mlir-c/IntegerSet.h" #include "mlir-c/Transforms.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" #include "llvm/ADT/DenseMap.h" namespace mlir { diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 5cfa51142..0f2719c10 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -11,12 +11,6 @@ #include "mlir/Bindings/Python/IRTypes.h" // clang-format on -#include -#include -#include -#include -#include - #include #include "IRModule.h" diff --git a/mlir/lib/Bindings/Python/LinalgPasses.cpp b/mlir/lib/Bindings/Python/LinalgPasses.cpp index 3f230207a..49f2ea941 100644 --- a/mlir/lib/Bindings/Python/LinalgPasses.cpp +++ b/mlir/lib/Bindings/Python/LinalgPasses.cpp @@ -8,13 +8,13 @@ #include "mlir-c/Dialect/Linalg.h" -#include +#include "mlir/Bindings/Python/Nanobind.h" // ----------------------------------------------------------------------------- // Module initialization. // ----------------------------------------------------------------------------- -PYBIND11_MODULE(_mlirLinalgPasses, m) { +NB_MODULE(_mlirLinalgPasses, m) { m.doc() = "MLIR Linalg Dialect Passes"; // Register all Linalg passes on load. 
diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index e5e64a921..7c4064262 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -6,14 +6,13 @@ // //===----------------------------------------------------------------------===// -#include -#include #include "Globals.h" #include "IRModule.h" #include "NanobindUtils.h" #include "Pass.h" #include "Rewrite.h" +#include "mlir/Bindings/Python/Nanobind.h" namespace nb = nanobind; using namespace mlir; diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h index 3b0f7f698..ee193cf9f 100644 --- a/mlir/lib/Bindings/Python/NanobindUtils.h +++ b/mlir/lib/Bindings/Python/NanobindUtils.h @@ -10,9 +10,8 @@ #ifndef MLIR_BINDINGS_PYTHON_PYBINDUTILS_H #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H -#include - #include "mlir-c/Support.h" +#include "mlir/Bindings/Python/Nanobind.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/DataTypes.h" @@ -68,7 +67,7 @@ namespace detail { template struct MlirDefaultingCaster { - NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription)); + NB_TYPE_CASTER(DefaultingTy, const_name(DefaultingTy::kTypeDescription)) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { if (src.is_none()) { diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index b5dce4fe4..858c3bd57 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -8,13 +8,10 @@ #include "Pass.h" -#include -#include -#include - #include "IRModule.h" -#include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Pass.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. 
namespace nb = nanobind; using namespace nb::literals; diff --git a/mlir/lib/Bindings/Python/RegisterEverything.cpp b/mlir/lib/Bindings/Python/RegisterEverything.cpp index 6b2f6b0a6..3ba42bec5 100644 --- a/mlir/lib/Bindings/Python/RegisterEverything.cpp +++ b/mlir/lib/Bindings/Python/RegisterEverything.cpp @@ -7,9 +7,10 @@ //===----------------------------------------------------------------------===// #include "mlir-c/RegisterEverything.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" -PYBIND11_MODULE(_mlirRegisterEverything, m) { +NB_MODULE(_mlirRegisterEverything, m) { m.doc() = "MLIR All Upstream Dialects, Translations and Passes Registration"; m.def("register_dialects", [](MlirDialectRegistry registry) { diff --git a/mlir/lib/Bindings/Python/Rewrite.cpp b/mlir/lib/Bindings/Python/Rewrite.cpp index b2c1de4be..0373f9c7a 100644 --- a/mlir/lib/Bindings/Python/Rewrite.cpp +++ b/mlir/lib/Bindings/Python/Rewrite.cpp @@ -8,11 +8,10 @@ #include "Rewrite.h" -#include - #include "IRModule.h" -#include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Rewrite.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. #include "mlir/Config/mlir-config.h" namespace nb = nanobind; diff --git a/mlir/lib/Bindings/Python/SparseTensorPasses.cpp b/mlir/lib/Bindings/Python/SparseTensorPasses.cpp index 2a8e2b802..8242f0973 100644 --- a/mlir/lib/Bindings/Python/SparseTensorPasses.cpp +++ b/mlir/lib/Bindings/Python/SparseTensorPasses.cpp @@ -8,13 +8,13 @@ #include "mlir-c/Dialect/SparseTensor.h" -#include +#include "mlir/Bindings/Python/Nanobind.h" // ----------------------------------------------------------------------------- // Module initialization. 
// ----------------------------------------------------------------------------- -PYBIND11_MODULE(_mlirSparseTensorPasses, m) { +NB_MODULE(_mlirSparseTensorPasses, m) { m.doc() = "MLIR SparseTensor Dialect Passes"; // Register all SparseTensor passes on load. diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp index 0c8c0e0a9..f9b0fed62 100644 --- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp +++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp @@ -10,16 +10,14 @@ // //===----------------------------------------------------------------------===// -#include -#include - #include "mlir-c/Dialect/Transform/Interpreter.h" #include "mlir-c/IR.h" #include "mlir-c/Support.h" #include "mlir/Bindings/Python/Diagnostics.h" -#include "mlir/Bindings/Python/PybindAdaptors.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" -namespace py = pybind11; +namespace nb = nanobind; namespace { struct PyMlirTransformOptions { @@ -36,10 +34,10 @@ struct PyMlirTransformOptions { }; } // namespace -static void populateTransformInterpreterSubmodule(py::module &m) { - py::class_(m, "TransformOptions", py::module_local()) - .def(py::init()) - .def_property( +static void populateTransformInterpreterSubmodule(nb::module_ &m) { + nb::class_(m, "TransformOptions") + .def(nb::init<>()) + .def_prop_rw( "expensive_checks", [](const PyMlirTransformOptions &self) { return mlirTransformOptionsGetExpensiveChecksEnabled(self.options); @@ -47,7 +45,7 @@ static void populateTransformInterpreterSubmodule(py::module &m) { [](PyMlirTransformOptions &self, bool value) { mlirTransformOptionsEnableExpensiveChecks(self.options, value); }) - .def_property( + .def_prop_rw( "enforce_single_top_level_transform_op", [](const PyMlirTransformOptions &self) { return mlirTransformOptionsGetEnforceSingleTopLevelTransformOp( @@ -68,7 +66,7 @@ static void populateTransformInterpreterSubmodule(py::module &m) { 
// Calling back into Python to invalidate everything under the payload // root. This is awkward, but we don't have access to PyMlirContext // object here otherwise. - py::object obj = py::cast(payloadRoot); + nb::object obj = nb::cast(payloadRoot); obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot); MlirLogicalResult result = mlirTransformApplyNamedSequence( @@ -76,13 +74,14 @@ static void populateTransformInterpreterSubmodule(py::module &m) { if (mlirLogicalResultIsSuccess(result)) return; - throw py::value_error( - "Failed to apply named transform sequence.\nDiagnostic message " + - scope.takeMessage()); + throw nb::value_error( + ("Failed to apply named transform sequence.\nDiagnostic message " + + scope.takeMessage()) + .c_str()); }, - py::arg("payload_root"), py::arg("transform_root"), - py::arg("transform_module"), - py::arg("transform_options") = PyMlirTransformOptions()); + nb::arg("payload_root"), nb::arg("transform_root"), + nb::arg("transform_module"), + nb::arg("transform_options") = PyMlirTransformOptions()); m.def( "copy_symbols_and_merge_into", @@ -92,15 +91,16 @@ static void populateTransformInterpreterSubmodule(py::module &m) { MlirLogicalResult result = mlirMergeSymbolsIntoFromClone(target, other); if (mlirLogicalResultIsFailure(result)) { - throw py::value_error( - "Failed to merge symbols.\nDiagnostic message " + - scope.takeMessage()); + throw nb::value_error( + ("Failed to merge symbols.\nDiagnostic message " + + scope.takeMessage()) + .c_str()); } }, - py::arg("target"), py::arg("other")); + nb::arg("target"), nb::arg("other")); } -PYBIND11_MODULE(_mlirTransformInterpreter, m) { +NB_MODULE(_mlirTransformInterpreter, m) { m.doc() = "MLIR Transform dialect interpreter functionality."; populateTransformInterpreterSubmodule(m); } diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 6d6b98312..fb115a5f4 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -476,9 +476,6 @@ 
declare_mlir_python_extension(MLIRPythonExtension.Core # Dialects MLIRCAPIFunc ) -if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL) -set_target_properties(MLIRPythonExtension.Core PROPERTIES INTERFACE_COMPILE_OPTIONS "-Wno-cast-qual;-Wno-zero-length-array;-Wno-extra-semi;-Wno-nested-anon-types;-Wno-pedantic") -endif() # This extension exposes an API to register all dialects, extensions, and passes # packaged in upstream MLIR and it is used for the upstream "mlir" Python @@ -490,6 +487,7 @@ endif() declare_mlir_python_extension(MLIRPythonExtension.RegisterEverything MODULE_NAME _mlirRegisterEverything ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES RegisterEverything.cpp PRIVATE_LINK_LIBS @@ -504,6 +502,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Linalg.Pybind MODULE_NAME _mlirDialectsLinalg ADD_TO_PARENT MLIRPythonSources.Dialects.linalg ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES DialectLinalg.cpp PRIVATE_LINK_LIBS @@ -517,6 +516,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.GPU.Pybind MODULE_NAME _mlirDialectsGPU ADD_TO_PARENT MLIRPythonSources.Dialects.gpu ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES DialectGPU.cpp PRIVATE_LINK_LIBS @@ -530,6 +530,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.LLVM.Pybind MODULE_NAME _mlirDialectsLLVM ADD_TO_PARENT MLIRPythonSources.Dialects.llvm ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES DialectLLVM.cpp PRIVATE_LINK_LIBS @@ -543,6 +544,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Quant.Pybind MODULE_NAME _mlirDialectsQuant ADD_TO_PARENT MLIRPythonSources.Dialects.quant ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES DialectQuant.cpp PRIVATE_LINK_LIBS @@ -556,6 +558,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.NVGPU.Pybind MODULE_NAME _mlirDialectsNVGPU ADD_TO_PARENT 
MLIRPythonSources.Dialects.nvgpu ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES DialectNVGPU.cpp PRIVATE_LINK_LIBS @@ -569,6 +572,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.PDL.Pybind MODULE_NAME _mlirDialectsPDL ADD_TO_PARENT MLIRPythonSources.Dialects.pdl ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES DialectPDL.cpp PRIVATE_LINK_LIBS @@ -582,6 +586,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SparseTensor.Pybind MODULE_NAME _mlirDialectsSparseTensor ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES DialectSparseTensor.cpp PRIVATE_LINK_LIBS @@ -595,6 +600,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind MODULE_NAME _mlirDialectsTransform ADD_TO_PARENT MLIRPythonSources.Dialects.transform ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES DialectTransform.cpp PRIVATE_LINK_LIBS @@ -608,6 +614,7 @@ declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses MODULE_NAME _mlirAsyncPasses ADD_TO_PARENT MLIRPythonSources.Dialects.async ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES AsyncPasses.cpp PRIVATE_LINK_LIBS @@ -621,6 +628,7 @@ if(MLIR_ENABLE_EXECUTION_ENGINE) MODULE_NAME _mlirExecutionEngine ADD_TO_PARENT MLIRPythonSources.ExecutionEngine ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES ExecutionEngineModule.cpp PRIVATE_LINK_LIBS @@ -634,6 +642,7 @@ declare_mlir_python_extension(MLIRPythonExtension.GPUDialectPasses MODULE_NAME _mlirGPUPasses ADD_TO_PARENT MLIRPythonSources.Dialects.gpu ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES GPUPasses.cpp PRIVATE_LINK_LIBS @@ -646,6 +655,7 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses MODULE_NAME _mlirLinalgPasses ADD_TO_PARENT MLIRPythonSources.Dialects.linalg ROOT_DIR 
"${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES LinalgPasses.cpp PRIVATE_LINK_LIBS @@ -658,6 +668,7 @@ declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses MODULE_NAME _mlirSparseTensorPasses ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES SparseTensorPasses.cpp PRIVATE_LINK_LIBS @@ -670,6 +681,7 @@ declare_mlir_python_extension(MLIRPythonExtension.TransformInterpreter MODULE_NAME _mlirTransformInterpreter ADD_TO_PARENT MLIRPythonSources.Dialects.transform ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind SOURCES TransformInterpreter.cpp PRIVATE_LINK_LIBS @@ -735,9 +747,6 @@ if(MLIR_INCLUDE_TESTS) EMBED_CAPI_LINK_LIBS MLIRCAPIPythonTestDialect ) - if (LLVM_COMPILER_IS_GCC_COMPATIBLE OR CLANG_CL) - set_target_properties(MLIRPythonTestSources.PythonTestExtensionNanobind PROPERTIES INTERFACE_COMPILE_OPTIONS "-Wno-cast-qual;-Wno-zero-length-array;-Wno-extra-semi;-Wno-nested-anon-types;-Wno-pedantic") - endif() endif() ################################################################################ @@ -794,3 +803,4 @@ add_mlir_python_modules(MLIRPythonModules COMMON_CAPI_LINK_LIBS MLIRPythonCAPI ) + From e1c5e22466c022377fdcad5bdf2398d6072d9613 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Sun, 29 Dec 2024 09:13:10 -0800 Subject: [PATCH 813/915] [mlir][python] disable nanobind leak warnings (#121099) --- mlir/lib/Bindings/Python/IRCore.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 86afa9563..05c000bfd 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2587,6 +2587,8 @@ class PyOpAttributeMap { //------------------------------------------------------------------------------ void mlir::python::populateIRCore(nb::module_ &m) { + // disable leak warnings which tend to be false positives. 
+ nb::set_leak_warnings(false); //---------------------------------------------------------------------------- // Enums. //---------------------------------------------------------------------------- From c5e431b5099ff350ff9e7937257774ef030f876a Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Thu, 2 Jan 2025 14:40:15 -0800 Subject: [PATCH 814/915] [mlir][py] Enable loading only specified dialects during creation. (#121421) Gives option post as global list as well as arg to control which dialects are loaded during context creation. This enables setting either a good base set or skipping in individual cases. --- mlir/python/mlir/_mlir_libs/__init__.py | 42 +++++++++++++++++++++++-- mlir/python/mlir/ir.py | 6 +++- 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index c5cb22c6d..d021dde05 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -58,6 +58,7 @@ def get_include_dirs() -> Sequence[str]: # needs. 
_dialect_registry = None +_load_on_create_dialects = None def get_dialect_registry(): @@ -71,6 +72,21 @@ def get_dialect_registry(): return _dialect_registry +def append_load_on_create_dialect(dialect: str): + global _load_on_create_dialects + if _load_on_create_dialects is None: + _load_on_create_dialects = [dialect] + else: + _load_on_create_dialects.append(dialect) + + +def get_load_on_create_dialects(): + global _load_on_create_dialects + if _load_on_create_dialects is None: + _load_on_create_dialects = [] + return _load_on_create_dialects + + def _site_initialize(): import importlib import itertools @@ -132,15 +148,35 @@ def process_initializer_module(module_name): break class Context(ir._BaseContext): - def __init__(self, *args, **kwargs): + def __init__(self, load_on_create_dialects=None, *args, **kwargs): super().__init__(*args, **kwargs) self.append_dialect_registry(get_dialect_registry()) for hook in post_init_hooks: hook(self) if not disable_multithreading: self.enable_multithreading(True) - if not disable_load_all_available_dialects: - self.load_all_available_dialects() + if load_on_create_dialects is not None: + logger.debug( + "Loading all dialects from load_on_create_dialects arg %r", + load_on_create_dialects, + ) + for dialect in load_on_create_dialects: + # This triggers loading the dialect into the context. + _ = self.dialects[dialect] + else: + if disable_load_all_available_dialects: + dialects = get_load_on_create_dialects() + if dialects: + logger.debug( + "Loading all dialects from global load_on_create_dialects %r", + dialects, + ) + for dialect in dialects: + # This triggers loading the dialect into the context. 
+ _ = self.dialects[dialect] + else: + logger.debug("Loading all available dialects") + self.load_all_available_dialects() if init_module: logger.debug( "Registering translations from initializer %r", init_module diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 9a6ce4620..6f37266d5 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -5,7 +5,11 @@ from ._mlir_libs._mlir.ir import * from ._mlir_libs._mlir.ir import _GlobalDebug from ._mlir_libs._mlir import register_type_caster, register_value_caster -from ._mlir_libs import get_dialect_registry +from ._mlir_libs import ( + get_dialect_registry, + append_load_on_create_dialect, + get_load_on_create_dialects, +) # Convenience decorator for registering user-friendly Attribute builders. From f0c6850b0f2a4264615a1c6ae01a606108f55577 Mon Sep 17 00:00:00 2001 From: Hugo Trachino Date: Fri, 3 Jan 2025 11:21:59 +0000 Subject: [PATCH 815/915] [MLIR][Python] Add structured.fuseop to python interpreter (#120601) Implements a python interface for structured.fuseOp allowing more freedom with inputs. --- .../mlir/dialects/transform/structured.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/mlir/python/mlir/dialects/transform/structured.py b/mlir/python/mlir/dialects/transform/structured.py index 9121aa8e4..bf40cc532 100644 --- a/mlir/python/mlir/dialects/transform/structured.py +++ b/mlir/python/mlir/dialects/transform/structured.py @@ -140,6 +140,77 @@ def __init__( ) +@_ods_cext.register_operation(_Dialect, replace=True) +class FuseOp(FuseOp): + """Specialization for FuseOp class.""" + + @overload + def __init__( + self, + loop_types: Union[Type, Sequence[Type]], + target: Union[Operation, Value, OpView], + *, + tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, + tile_interchange: OptionalIntList = None, + apply_cleanup: Optional[bool] = False, + loc=None, + ip=None, + ): + ... 
+ + @overload + def __init__( + self, + target: Union[Operation, Value, OpView], + *, + tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, + tile_interchange: OptionalIntList = None, + apply_cleanup: Optional[bool] = False, + loc=None, + ip=None, + ): + ... + + def __init__( + self, + loop_types_or_target: Union[Type, Sequence[Type], Operation, OpView, Value], + target_or_none: Optional[Union[Operation, Value, OpView]] = None, + *, + tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None, + tile_interchange: OptionalIntList = None, + apply_cleanup: Optional[bool] = False, + loc=None, + ip=None, + ): + tile_sizes = tile_sizes if tile_sizes else [] + tile_interchange = tile_interchange if tile_interchange else [] + _, tile_sizes, _ = _dispatch_dynamic_index_list(tile_sizes) + _, tile_interchange, _ = _dispatch_dynamic_index_list(tile_interchange) + num_loops = sum(0 if v == 0 else 1 for v in tile_sizes) + + if isinstance(loop_types_or_target, (Operation, Value, OpView)): + loop_types = [transform.AnyOpType.get()] * num_loops + target = loop_types_or_target + assert target_or_none is None, "Cannot construct FuseOp with two targets." 
+ else: + loop_types = ( + ([loop_types_or_target] * num_loops) + if isinstance(loop_types_or_target, Type) + else loop_types_or_target + ) + target = target_or_none + super().__init__( + target.type, + loop_types, + target, + tile_sizes=tile_sizes, + tile_interchange=tile_interchange, + apply_cleanup=apply_cleanup, + loc=loc, + ip=ip, + ) + + @_ods_cext.register_operation(_Dialect, replace=True) class GeneralizeOp(GeneralizeOp): """Specialization for GeneralizeOp class.""" From 3f59753ab34855d02021fccdcc98bddca896d04a Mon Sep 17 00:00:00 2001 From: William Moses Date: Tue, 7 Jan 2025 02:39:44 -0500 Subject: [PATCH 816/915] [MLIR][CAPI] export LLVMFunctionType param getter and setters (#121888) --- mlir/include/mlir-c/Dialect/LLVM.h | 7 +++++++ mlir/lib/CAPI/Dialect/LLVM.cpp | 10 ++++++++++ 2 files changed, 17 insertions(+) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index 0992285f9..26c414075 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -45,6 +45,13 @@ MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes, MlirType const *argumentTypes, bool isVarArg); +/// Returns the number of input types. +MLIR_CAPI_EXPORTED intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type); + +/// Returns the pos-th input type. +MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetInput(MlirType type, + intptr_t pos); + /// Returns `true` if the type is an LLVM dialect struct type. 
MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type); diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index 6ed82ba1a..da450dd3f 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -55,6 +55,16 @@ MlirType mlirLLVMFunctionTypeGet(MlirType resultType, intptr_t nArgumentTypes, unwrapList(nArgumentTypes, argumentTypes, argumentStorage), isVarArg)); } +intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type) { + return llvm::cast(unwrap(type)).getNumParams(); +} + +MlirType mlirLLVMFunctionTypeGetInput(MlirType type, intptr_t pos) { + assert(pos >= 0 && "pos in array must be positive"); + return wrap(llvm::cast(unwrap(type)) + .getParamType(static_cast(pos))); +} + bool mlirTypeIsALLVMStructType(MlirType type) { return isa(unwrap(type)); } From d4c5ae7b434261e8cb99b9a4d1addc6e5206d5fa Mon Sep 17 00:00:00 2001 From: vfdev Date: Tue, 7 Jan 2025 16:33:01 +0100 Subject: [PATCH 817/915] Fixed typo in dunder get/set methods in PyAttrBuilderMap (#121794) Description: - fixed a typo in the method name: dunde -> dunder --- mlir/lib/Bindings/Python/IRCore.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 05c000bfd..453d4f7c7 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -272,13 +272,13 @@ struct PyAttrBuilderMap { static bool dunderContains(const std::string &attributeKind) { return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value(); } - static nb::callable dundeGetItemNamed(const std::string &attributeKind) { + static nb::callable dunderGetItemNamed(const std::string &attributeKind) { auto builder = PyGlobals::get().lookupAttributeBuilder(attributeKind); if (!builder) throw nb::key_error(attributeKind.c_str()); return *builder; } - static void dundeSetItemNamed(const std::string &attributeKind, + static void dunderSetItemNamed(const 
std::string &attributeKind, nb::callable func, bool replace) { PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), replace); @@ -287,8 +287,8 @@ struct PyAttrBuilderMap { static void bind(nb::module_ &m) { nb::class_(m, "AttrBuilder") .def_static("contains", &PyAttrBuilderMap::dunderContains) - .def_static("get", &PyAttrBuilderMap::dundeGetItemNamed) - .def_static("insert", &PyAttrBuilderMap::dundeSetItemNamed, + .def_static("get", &PyAttrBuilderMap::dunderGetItemNamed) + .def_static("insert", &PyAttrBuilderMap::dunderSetItemNamed, "attribute_kind"_a, "attr_builder"_a, "replace"_a = false, "Register an attribute builder for building MLIR " "attributes from python values."); From d4841a337ecf58bf584b5ae01338998a3434f906 Mon Sep 17 00:00:00 2001 From: vfdev Date: Sun, 12 Jan 2025 18:56:49 +0100 Subject: [PATCH 818/915] Added free-threading CPython mode support in MLIR Python bindings (#107103) Related to https://github.com/llvm/llvm-project/issues/105522 Description: This PR is a joint work with Peter Hawkins (@hawkinsp) originally done by myself for pybind11 and then reworked to nanobind based on Peter's branch: https://github.com/hawkinsp/llvm-project/tree/nbdev . 
- Added free-threading CPython mode support for MLIR Python bindings - Added a test which can reveal data races when cpython and LLVM/MLIR compiled with TSAN Context: - Related to https://github.com/google/jax/issues/23073 Co-authored-by: Peter Hawkins --- mlir/lib/Bindings/Python/Globals.h | 12 +++++++++- mlir/lib/Bindings/Python/IRCore.cpp | 31 +++++++++++++++++++++---- mlir/lib/Bindings/Python/IRModule.cpp | 18 ++++++++++++-- mlir/lib/Bindings/Python/IRModule.h | 1 + mlir/lib/Bindings/Python/MainModule.cpp | 9 ++----- mlir/python/requirements.txt | 2 +- 6 files changed, 58 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 0ec522d14..826a34a53 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -24,6 +24,7 @@ namespace mlir { namespace python { /// Globals that are always accessible once the extension has been initialized. +/// Methods of this class are thread-safe. class PyGlobals { public: PyGlobals(); @@ -37,12 +38,18 @@ class PyGlobals { /// Get and set the list of parent modules to search for dialect /// implementation classes. - std::vector &getDialectSearchPrefixes() { + std::vector getDialectSearchPrefixes() { + nanobind::ft_lock_guard lock(mutex); return dialectSearchPrefixes; } void setDialectSearchPrefixes(std::vector newValues) { + nanobind::ft_lock_guard lock(mutex); dialectSearchPrefixes.swap(newValues); } + void addDialectSearchPrefix(std::string value) { + nanobind::ft_lock_guard lock(mutex); + dialectSearchPrefixes.push_back(std::move(value)); + } /// Loads a python module corresponding to the given dialect namespace. /// No-ops if the module has already been loaded or is not found. Raises @@ -109,6 +116,9 @@ class PyGlobals { private: static PyGlobals *instance; + + nanobind::ft_mutex mutex; + /// Module name prefixes to search under for dialect implementation modules. 
std::vector dialectSearchPrefixes; /// Map of dialect namespace to external dialect class object. diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 453d4f7c7..463ebdebb 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -243,9 +243,15 @@ static MlirBlock createBlock(const nb::sequence &pyArgTypes, /// Wrapper for the global LLVM debugging flag. struct PyGlobalDebugFlag { - static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); } + static void set(nb::object &o, bool enable) { + nb::ft_lock_guard lock(mutex); + mlirEnableGlobalDebug(enable); + } - static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); } + static bool get(const nb::object &) { + nb::ft_lock_guard lock(mutex); + return mlirIsGlobalDebugEnabled(); + } static void bind(nb::module_ &m) { // Debug flags. @@ -255,6 +261,7 @@ struct PyGlobalDebugFlag { .def_static( "set_types", [](const std::string &type) { + nb::ft_lock_guard lock(mutex); mlirSetGlobalDebugType(type.c_str()); }, "types"_a, "Sets specific debug types to be produced by LLVM") @@ -263,11 +270,17 @@ struct PyGlobalDebugFlag { pointers.reserve(types.size()); for (const std::string &str : types) pointers.push_back(str.c_str()); + nb::ft_lock_guard lock(mutex); mlirSetGlobalDebugTypes(pointers.data(), pointers.size()); }); } + +private: + static nb::ft_mutex mutex; }; +nb::ft_mutex PyGlobalDebugFlag::mutex; + struct PyAttrBuilderMap { static bool dunderContains(const std::string &attributeKind) { return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value(); @@ -606,6 +619,7 @@ class PyOpOperandIterator { PyMlirContext::PyMlirContext(MlirContext context) : context(context) { nb::gil_scoped_acquire acquire; + nb::ft_lock_guard lock(live_contexts_mutex); auto &liveContexts = getLiveContexts(); liveContexts[context.ptr] = this; } @@ -615,7 +629,10 @@ PyMlirContext::~PyMlirContext() { // forContext method, which 
always puts the associated handle into // liveContexts. nb::gil_scoped_acquire acquire; - getLiveContexts().erase(context.ptr); + { + nb::ft_lock_guard lock(live_contexts_mutex); + getLiveContexts().erase(context.ptr); + } mlirContextDestroy(context); } @@ -632,6 +649,7 @@ nb::object PyMlirContext::createFromCapsule(nb::object capsule) { PyMlirContextRef PyMlirContext::forContext(MlirContext context) { nb::gil_scoped_acquire acquire; + nb::ft_lock_guard lock(live_contexts_mutex); auto &liveContexts = getLiveContexts(); auto it = liveContexts.find(context.ptr); if (it == liveContexts.end()) { @@ -647,12 +665,17 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) { return PyMlirContextRef(it->second, std::move(pyRef)); } +nb::ft_mutex PyMlirContext::live_contexts_mutex; + PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { static LiveContextMap liveContexts; return liveContexts; } -size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } +size_t PyMlirContext::getLiveCount() { + nb::ft_lock_guard lock(live_contexts_mutex); + return getLiveContexts().size(); +} size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index f7bf77e5a..e600f1bbd 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -38,8 +38,11 @@ PyGlobals::PyGlobals() { PyGlobals::~PyGlobals() { instance = nullptr; } bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { - if (loadedDialectModules.contains(dialectNamespace)) - return true; + { + nb::ft_lock_guard lock(mutex); + if (loadedDialectModules.contains(dialectNamespace)) + return true; + } // Since re-entrancy is possible, make a copy of the search prefixes. 
std::vector localSearchPrefixes = dialectSearchPrefixes; nb::object loaded = nb::none(); @@ -62,12 +65,14 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { return false; // Note: Iterator cannot be shared from prior to loading, since re-entrancy // may have occurred, which may do anything. + nb::ft_lock_guard lock(mutex); loadedDialectModules.insert(dialectNamespace); return true; } void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, nb::callable pyFunc, bool replace) { + nb::ft_lock_guard lock(mutex); nb::object &found = attributeBuilderMap[attributeKind]; if (found && !replace) { throw std::runtime_error((llvm::Twine("Attribute builder for '") + @@ -81,6 +86,7 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, nb::callable typeCaster, bool replace) { + nb::ft_lock_guard lock(mutex); nb::object &found = typeCasterMap[mlirTypeID]; if (found && !replace) throw std::runtime_error("Type caster is already registered with caster: " + @@ -90,6 +96,7 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, nb::callable valueCaster, bool replace) { + nb::ft_lock_guard lock(mutex); nb::object &found = valueCasterMap[mlirTypeID]; if (found && !replace) throw std::runtime_error("Value caster is already registered: " + @@ -99,6 +106,7 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, nb::object pyClass) { + nb::ft_lock_guard lock(mutex); nb::object &found = dialectClassMap[dialectNamespace]; if (found) { throw std::runtime_error((llvm::Twine("Dialect namespace '") + @@ -110,6 +118,7 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, void PyGlobals::registerOperationImpl(const std::string &operationName, nb::object pyClass, bool replace) { + nb::ft_lock_guard lock(mutex); 
nb::object &found = operationClassMap[operationName]; if (found && !replace) { throw std::runtime_error((llvm::Twine("Operation '") + operationName + @@ -121,6 +130,7 @@ void PyGlobals::registerOperationImpl(const std::string &operationName, std::optional PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { + nb::ft_lock_guard lock(mutex); const auto foundIt = attributeBuilderMap.find(attributeKind); if (foundIt != attributeBuilderMap.end()) { assert(foundIt->second && "attribute builder is defined"); @@ -133,6 +143,7 @@ std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { // Try to load dialect module. (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); + nb::ft_lock_guard lock(mutex); const auto foundIt = typeCasterMap.find(mlirTypeID); if (foundIt != typeCasterMap.end()) { assert(foundIt->second && "type caster is defined"); @@ -145,6 +156,7 @@ std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { // Try to load dialect module. (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); + nb::ft_lock_guard lock(mutex); const auto foundIt = valueCasterMap.find(mlirTypeID); if (foundIt != valueCasterMap.end()) { assert(foundIt->second && "value caster is defined"); @@ -158,6 +170,7 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { // Make sure dialect module is loaded. 
if (!loadDialectModule(dialectNamespace)) return std::nullopt; + nb::ft_lock_guard lock(mutex); const auto foundIt = dialectClassMap.find(dialectNamespace); if (foundIt != dialectClassMap.end()) { assert(foundIt->second && "dialect class is defined"); @@ -175,6 +188,7 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) { if (!loadDialectModule(dialectNamespace)) return std::nullopt; + nb::ft_lock_guard lock(mutex); auto foundIt = operationClassMap.find(operationName); if (foundIt != operationClassMap.end()) { assert(foundIt->second && "OpView is defined"); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 8fb32a225..f5fbb6c61 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -260,6 +260,7 @@ class PyMlirContext { // Note that this holds a handle, which does not imply ownership. // Mappings will be removed when the context is destructed. using LiveContextMap = llvm::DenseMap; + static nanobind::ft_mutex live_contexts_mutex; static LiveContextMap &getLiveContexts(); // Interns all live modules associated with this context. 
Modules tracked diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 7c4064262..6f4943100 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -30,12 +30,8 @@ NB_MODULE(_mlir, m) { .def_prop_rw("dialect_search_modules", &PyGlobals::getDialectSearchPrefixes, &PyGlobals::setDialectSearchPrefixes) - .def( - "append_dialect_search_prefix", - [](PyGlobals &self, std::string moduleName) { - self.getDialectSearchPrefixes().push_back(std::move(moduleName)); - }, - "module_name"_a) + .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix, + "module_name"_a) .def( "_check_dialect_module_loaded", [](PyGlobals &self, const std::string &dialectNamespace) { @@ -76,7 +72,6 @@ NB_MODULE(_mlir, m) { nanobind::cast(opClass.attr("OPERATION_NAME")); PyGlobals::get().registerOperationImpl(operationName, opClass, replace); - // Dict-stuff the new opClass by name onto the dialect class. nb::object opClassName = opClass.attr("__name__"); dialectClass.attr(opClassName) = opClass; diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index f240d6ef9..259e679f5 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -2,4 +2,4 @@ nanobind>=2.4, <3.0 numpy>=1.19.5, <=2.1.2 pybind11>=2.10.0, <=2.13.6 PyYAML>=5.4.0, <=6.0.1 -ml_dtypes>=0.1.0, <=0.5.0 # provides several NumPy dtype extensions, including the bf16 +ml_dtypes>=0.5.0, <=0.6.0 # provides several NumPy dtype extensions, including the bf16 From 7e25a9ae85123e9b7f1b54bee2ad29891322ee92 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Sun, 12 Jan 2025 18:30:42 +0000 Subject: [PATCH 819/915] Revert "Added free-threading CPython mode support in MLIR Python bindings (#107103)" Breaks on 3.8, rolling back to avoid breakage while fixing. This reverts commit d4841a337ecf58bf584b5ae01338998a3434f906. 
--- mlir/lib/Bindings/Python/Globals.h | 12 +--------- mlir/lib/Bindings/Python/IRCore.cpp | 31 ++++--------------------- mlir/lib/Bindings/Python/IRModule.cpp | 18 ++------------ mlir/lib/Bindings/Python/IRModule.h | 1 - mlir/lib/Bindings/Python/MainModule.cpp | 9 +++++-- mlir/python/requirements.txt | 2 +- 6 files changed, 15 insertions(+), 58 deletions(-) diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 826a34a53..0ec522d14 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -24,7 +24,6 @@ namespace mlir { namespace python { /// Globals that are always accessible once the extension has been initialized. -/// Methods of this class are thread-safe. class PyGlobals { public: PyGlobals(); @@ -38,18 +37,12 @@ class PyGlobals { /// Get and set the list of parent modules to search for dialect /// implementation classes. - std::vector getDialectSearchPrefixes() { - nanobind::ft_lock_guard lock(mutex); + std::vector &getDialectSearchPrefixes() { return dialectSearchPrefixes; } void setDialectSearchPrefixes(std::vector newValues) { - nanobind::ft_lock_guard lock(mutex); dialectSearchPrefixes.swap(newValues); } - void addDialectSearchPrefix(std::string value) { - nanobind::ft_lock_guard lock(mutex); - dialectSearchPrefixes.push_back(std::move(value)); - } /// Loads a python module corresponding to the given dialect namespace. /// No-ops if the module has already been loaded or is not found. Raises @@ -116,9 +109,6 @@ class PyGlobals { private: static PyGlobals *instance; - - nanobind::ft_mutex mutex; - /// Module name prefixes to search under for dialect implementation modules. std::vector dialectSearchPrefixes; /// Map of dialect namespace to external dialect class object. 
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 463ebdebb..453d4f7c7 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -243,15 +243,9 @@ static MlirBlock createBlock(const nb::sequence &pyArgTypes, /// Wrapper for the global LLVM debugging flag. struct PyGlobalDebugFlag { - static void set(nb::object &o, bool enable) { - nb::ft_lock_guard lock(mutex); - mlirEnableGlobalDebug(enable); - } + static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); } - static bool get(const nb::object &) { - nb::ft_lock_guard lock(mutex); - return mlirIsGlobalDebugEnabled(); - } + static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); } static void bind(nb::module_ &m) { // Debug flags. @@ -261,7 +255,6 @@ struct PyGlobalDebugFlag { .def_static( "set_types", [](const std::string &type) { - nb::ft_lock_guard lock(mutex); mlirSetGlobalDebugType(type.c_str()); }, "types"_a, "Sets specific debug types to be produced by LLVM") @@ -270,17 +263,11 @@ struct PyGlobalDebugFlag { pointers.reserve(types.size()); for (const std::string &str : types) pointers.push_back(str.c_str()); - nb::ft_lock_guard lock(mutex); mlirSetGlobalDebugTypes(pointers.data(), pointers.size()); }); } - -private: - static nb::ft_mutex mutex; }; -nb::ft_mutex PyGlobalDebugFlag::mutex; - struct PyAttrBuilderMap { static bool dunderContains(const std::string &attributeKind) { return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value(); @@ -619,7 +606,6 @@ class PyOpOperandIterator { PyMlirContext::PyMlirContext(MlirContext context) : context(context) { nb::gil_scoped_acquire acquire; - nb::ft_lock_guard lock(live_contexts_mutex); auto &liveContexts = getLiveContexts(); liveContexts[context.ptr] = this; } @@ -629,10 +615,7 @@ PyMlirContext::~PyMlirContext() { // forContext method, which always puts the associated handle into // liveContexts. 
nb::gil_scoped_acquire acquire; - { - nb::ft_lock_guard lock(live_contexts_mutex); - getLiveContexts().erase(context.ptr); - } + getLiveContexts().erase(context.ptr); mlirContextDestroy(context); } @@ -649,7 +632,6 @@ nb::object PyMlirContext::createFromCapsule(nb::object capsule) { PyMlirContextRef PyMlirContext::forContext(MlirContext context) { nb::gil_scoped_acquire acquire; - nb::ft_lock_guard lock(live_contexts_mutex); auto &liveContexts = getLiveContexts(); auto it = liveContexts.find(context.ptr); if (it == liveContexts.end()) { @@ -665,17 +647,12 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) { return PyMlirContextRef(it->second, std::move(pyRef)); } -nb::ft_mutex PyMlirContext::live_contexts_mutex; - PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { static LiveContextMap liveContexts; return liveContexts; } -size_t PyMlirContext::getLiveCount() { - nb::ft_lock_guard lock(live_contexts_mutex); - return getLiveContexts().size(); -} +size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index e600f1bbd..f7bf77e5a 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -38,11 +38,8 @@ PyGlobals::PyGlobals() { PyGlobals::~PyGlobals() { instance = nullptr; } bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { - { - nb::ft_lock_guard lock(mutex); - if (loadedDialectModules.contains(dialectNamespace)) - return true; - } + if (loadedDialectModules.contains(dialectNamespace)) + return true; // Since re-entrancy is possible, make a copy of the search prefixes. 
std::vector localSearchPrefixes = dialectSearchPrefixes; nb::object loaded = nb::none(); @@ -65,14 +62,12 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { return false; // Note: Iterator cannot be shared from prior to loading, since re-entrancy // may have occurred, which may do anything. - nb::ft_lock_guard lock(mutex); loadedDialectModules.insert(dialectNamespace); return true; } void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, nb::callable pyFunc, bool replace) { - nb::ft_lock_guard lock(mutex); nb::object &found = attributeBuilderMap[attributeKind]; if (found && !replace) { throw std::runtime_error((llvm::Twine("Attribute builder for '") + @@ -86,7 +81,6 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, nb::callable typeCaster, bool replace) { - nb::ft_lock_guard lock(mutex); nb::object &found = typeCasterMap[mlirTypeID]; if (found && !replace) throw std::runtime_error("Type caster is already registered with caster: " + @@ -96,7 +90,6 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, nb::callable valueCaster, bool replace) { - nb::ft_lock_guard lock(mutex); nb::object &found = valueCasterMap[mlirTypeID]; if (found && !replace) throw std::runtime_error("Value caster is already registered: " + @@ -106,7 +99,6 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, nb::object pyClass) { - nb::ft_lock_guard lock(mutex); nb::object &found = dialectClassMap[dialectNamespace]; if (found) { throw std::runtime_error((llvm::Twine("Dialect namespace '") + @@ -118,7 +110,6 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, void PyGlobals::registerOperationImpl(const std::string &operationName, nb::object pyClass, bool replace) { - nb::ft_lock_guard lock(mutex); 
nb::object &found = operationClassMap[operationName]; if (found && !replace) { throw std::runtime_error((llvm::Twine("Operation '") + operationName + @@ -130,7 +121,6 @@ void PyGlobals::registerOperationImpl(const std::string &operationName, std::optional PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { - nb::ft_lock_guard lock(mutex); const auto foundIt = attributeBuilderMap.find(attributeKind); if (foundIt != attributeBuilderMap.end()) { assert(foundIt->second && "attribute builder is defined"); @@ -143,7 +133,6 @@ std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { // Try to load dialect module. (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); - nb::ft_lock_guard lock(mutex); const auto foundIt = typeCasterMap.find(mlirTypeID); if (foundIt != typeCasterMap.end()) { assert(foundIt->second && "type caster is defined"); @@ -156,7 +145,6 @@ std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { // Try to load dialect module. (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); - nb::ft_lock_guard lock(mutex); const auto foundIt = valueCasterMap.find(mlirTypeID); if (foundIt != valueCasterMap.end()) { assert(foundIt->second && "value caster is defined"); @@ -170,7 +158,6 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { // Make sure dialect module is loaded. 
if (!loadDialectModule(dialectNamespace)) return std::nullopt; - nb::ft_lock_guard lock(mutex); const auto foundIt = dialectClassMap.find(dialectNamespace); if (foundIt != dialectClassMap.end()) { assert(foundIt->second && "dialect class is defined"); @@ -188,7 +175,6 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) { if (!loadDialectModule(dialectNamespace)) return std::nullopt; - nb::ft_lock_guard lock(mutex); auto foundIt = operationClassMap.find(operationName); if (foundIt != operationClassMap.end()) { assert(foundIt->second && "OpView is defined"); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index f5fbb6c61..8fb32a225 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -260,7 +260,6 @@ class PyMlirContext { // Note that this holds a handle, which does not imply ownership. // Mappings will be removed when the context is destructed. using LiveContextMap = llvm::DenseMap; - static nanobind::ft_mutex live_contexts_mutex; static LiveContextMap &getLiveContexts(); // Interns all live modules associated with this context. 
Modules tracked diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 6f4943100..7c4064262 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -30,8 +30,12 @@ NB_MODULE(_mlir, m) { .def_prop_rw("dialect_search_modules", &PyGlobals::getDialectSearchPrefixes, &PyGlobals::setDialectSearchPrefixes) - .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix, - "module_name"_a) + .def( + "append_dialect_search_prefix", + [](PyGlobals &self, std::string moduleName) { + self.getDialectSearchPrefixes().push_back(std::move(moduleName)); + }, + "module_name"_a) .def( "_check_dialect_module_loaded", [](PyGlobals &self, const std::string &dialectNamespace) { @@ -72,6 +76,7 @@ NB_MODULE(_mlir, m) { nanobind::cast(opClass.attr("OPERATION_NAME")); PyGlobals::get().registerOperationImpl(operationName, opClass, replace); + // Dict-stuff the new opClass by name onto the dialect class. nb::object opClassName = opClass.attr("__name__"); dialectClass.attr(opClassName) = opClass; diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index 259e679f5..f240d6ef9 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -2,4 +2,4 @@ nanobind>=2.4, <3.0 numpy>=1.19.5, <=2.1.2 pybind11>=2.10.0, <=2.13.6 PyYAML>=5.4.0, <=6.0.1 -ml_dtypes>=0.5.0, <=0.6.0 # provides several NumPy dtype extensions, including the bf16 +ml_dtypes>=0.1.0, <=0.5.0 # provides several NumPy dtype extensions, including the bf16 From 45f52911ac5fa052e3c1b84dce6854a3837f960e Mon Sep 17 00:00:00 2001 From: vfdev Date: Mon, 13 Jan 2025 12:00:31 +0100 Subject: [PATCH 820/915] Enabled freethreading support in MLIR python bindings (#122684) Reland reverted https://github.com/llvm/llvm-project/pull/107103 with the fixes for Python 3.8 cc @jpienaar Co-authored-by: Peter Hawkins --- mlir/lib/Bindings/Python/Globals.h | 12 +++++++++- mlir/lib/Bindings/Python/IRCore.cpp | 31 
+++++++++++++++++++++---- mlir/lib/Bindings/Python/IRModule.cpp | 18 ++++++++++++-- mlir/lib/Bindings/Python/IRModule.h | 1 + mlir/lib/Bindings/Python/MainModule.cpp | 9 ++----- mlir/python/requirements.txt | 3 ++- 6 files changed, 59 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 0ec522d14..826a34a53 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -24,6 +24,7 @@ namespace mlir { namespace python { /// Globals that are always accessible once the extension has been initialized. +/// Methods of this class are thread-safe. class PyGlobals { public: PyGlobals(); @@ -37,12 +38,18 @@ class PyGlobals { /// Get and set the list of parent modules to search for dialect /// implementation classes. - std::vector &getDialectSearchPrefixes() { + std::vector getDialectSearchPrefixes() { + nanobind::ft_lock_guard lock(mutex); return dialectSearchPrefixes; } void setDialectSearchPrefixes(std::vector newValues) { + nanobind::ft_lock_guard lock(mutex); dialectSearchPrefixes.swap(newValues); } + void addDialectSearchPrefix(std::string value) { + nanobind::ft_lock_guard lock(mutex); + dialectSearchPrefixes.push_back(std::move(value)); + } /// Loads a python module corresponding to the given dialect namespace. /// No-ops if the module has already been loaded or is not found. Raises @@ -109,6 +116,9 @@ class PyGlobals { private: static PyGlobals *instance; + + nanobind::ft_mutex mutex; + /// Module name prefixes to search under for dialect implementation modules. std::vector dialectSearchPrefixes; /// Map of dialect namespace to external dialect class object. 
diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 453d4f7c7..463ebdebb 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -243,9 +243,15 @@ static MlirBlock createBlock(const nb::sequence &pyArgTypes, /// Wrapper for the global LLVM debugging flag. struct PyGlobalDebugFlag { - static void set(nb::object &o, bool enable) { mlirEnableGlobalDebug(enable); } + static void set(nb::object &o, bool enable) { + nb::ft_lock_guard lock(mutex); + mlirEnableGlobalDebug(enable); + } - static bool get(const nb::object &) { return mlirIsGlobalDebugEnabled(); } + static bool get(const nb::object &) { + nb::ft_lock_guard lock(mutex); + return mlirIsGlobalDebugEnabled(); + } static void bind(nb::module_ &m) { // Debug flags. @@ -255,6 +261,7 @@ struct PyGlobalDebugFlag { .def_static( "set_types", [](const std::string &type) { + nb::ft_lock_guard lock(mutex); mlirSetGlobalDebugType(type.c_str()); }, "types"_a, "Sets specific debug types to be produced by LLVM") @@ -263,11 +270,17 @@ struct PyGlobalDebugFlag { pointers.reserve(types.size()); for (const std::string &str : types) pointers.push_back(str.c_str()); + nb::ft_lock_guard lock(mutex); mlirSetGlobalDebugTypes(pointers.data(), pointers.size()); }); } + +private: + static nb::ft_mutex mutex; }; +nb::ft_mutex PyGlobalDebugFlag::mutex; + struct PyAttrBuilderMap { static bool dunderContains(const std::string &attributeKind) { return PyGlobals::get().lookupAttributeBuilder(attributeKind).has_value(); @@ -606,6 +619,7 @@ class PyOpOperandIterator { PyMlirContext::PyMlirContext(MlirContext context) : context(context) { nb::gil_scoped_acquire acquire; + nb::ft_lock_guard lock(live_contexts_mutex); auto &liveContexts = getLiveContexts(); liveContexts[context.ptr] = this; } @@ -615,7 +629,10 @@ PyMlirContext::~PyMlirContext() { // forContext method, which always puts the associated handle into // liveContexts. 
nb::gil_scoped_acquire acquire; - getLiveContexts().erase(context.ptr); + { + nb::ft_lock_guard lock(live_contexts_mutex); + getLiveContexts().erase(context.ptr); + } mlirContextDestroy(context); } @@ -632,6 +649,7 @@ nb::object PyMlirContext::createFromCapsule(nb::object capsule) { PyMlirContextRef PyMlirContext::forContext(MlirContext context) { nb::gil_scoped_acquire acquire; + nb::ft_lock_guard lock(live_contexts_mutex); auto &liveContexts = getLiveContexts(); auto it = liveContexts.find(context.ptr); if (it == liveContexts.end()) { @@ -647,12 +665,17 @@ PyMlirContextRef PyMlirContext::forContext(MlirContext context) { return PyMlirContextRef(it->second, std::move(pyRef)); } +nb::ft_mutex PyMlirContext::live_contexts_mutex; + PyMlirContext::LiveContextMap &PyMlirContext::getLiveContexts() { static LiveContextMap liveContexts; return liveContexts; } -size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } +size_t PyMlirContext::getLiveCount() { + nb::ft_lock_guard lock(live_contexts_mutex); + return getLiveContexts().size(); +} size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index f7bf77e5a..e600f1bbd 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -38,8 +38,11 @@ PyGlobals::PyGlobals() { PyGlobals::~PyGlobals() { instance = nullptr; } bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { - if (loadedDialectModules.contains(dialectNamespace)) - return true; + { + nb::ft_lock_guard lock(mutex); + if (loadedDialectModules.contains(dialectNamespace)) + return true; + } // Since re-entrancy is possible, make a copy of the search prefixes. 
std::vector localSearchPrefixes = dialectSearchPrefixes; nb::object loaded = nb::none(); @@ -62,12 +65,14 @@ bool PyGlobals::loadDialectModule(llvm::StringRef dialectNamespace) { return false; // Note: Iterator cannot be shared from prior to loading, since re-entrancy // may have occurred, which may do anything. + nb::ft_lock_guard lock(mutex); loadedDialectModules.insert(dialectNamespace); return true; } void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, nb::callable pyFunc, bool replace) { + nb::ft_lock_guard lock(mutex); nb::object &found = attributeBuilderMap[attributeKind]; if (found && !replace) { throw std::runtime_error((llvm::Twine("Attribute builder for '") + @@ -81,6 +86,7 @@ void PyGlobals::registerAttributeBuilder(const std::string &attributeKind, void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, nb::callable typeCaster, bool replace) { + nb::ft_lock_guard lock(mutex); nb::object &found = typeCasterMap[mlirTypeID]; if (found && !replace) throw std::runtime_error("Type caster is already registered with caster: " + @@ -90,6 +96,7 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, nb::callable valueCaster, bool replace) { + nb::ft_lock_guard lock(mutex); nb::object &found = valueCasterMap[mlirTypeID]; if (found && !replace) throw std::runtime_error("Value caster is already registered: " + @@ -99,6 +106,7 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, nb::object pyClass) { + nb::ft_lock_guard lock(mutex); nb::object &found = dialectClassMap[dialectNamespace]; if (found) { throw std::runtime_error((llvm::Twine("Dialect namespace '") + @@ -110,6 +118,7 @@ void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, void PyGlobals::registerOperationImpl(const std::string &operationName, nb::object pyClass, bool replace) { + nb::ft_lock_guard lock(mutex); 
nb::object &found = operationClassMap[operationName]; if (found && !replace) { throw std::runtime_error((llvm::Twine("Operation '") + operationName + @@ -121,6 +130,7 @@ void PyGlobals::registerOperationImpl(const std::string &operationName, std::optional PyGlobals::lookupAttributeBuilder(const std::string &attributeKind) { + nb::ft_lock_guard lock(mutex); const auto foundIt = attributeBuilderMap.find(attributeKind); if (foundIt != attributeBuilderMap.end()) { assert(foundIt->second && "attribute builder is defined"); @@ -133,6 +143,7 @@ std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { // Try to load dialect module. (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); + nb::ft_lock_guard lock(mutex); const auto foundIt = typeCasterMap.find(mlirTypeID); if (foundIt != typeCasterMap.end()) { assert(foundIt->second && "type caster is defined"); @@ -145,6 +156,7 @@ std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { // Try to load dialect module. (void)loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); + nb::ft_lock_guard lock(mutex); const auto foundIt = valueCasterMap.find(mlirTypeID); if (foundIt != valueCasterMap.end()) { assert(foundIt->second && "value caster is defined"); @@ -158,6 +170,7 @@ PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { // Make sure dialect module is loaded. 
if (!loadDialectModule(dialectNamespace)) return std::nullopt; + nb::ft_lock_guard lock(mutex); const auto foundIt = dialectClassMap.find(dialectNamespace); if (foundIt != dialectClassMap.end()) { assert(foundIt->second && "dialect class is defined"); @@ -175,6 +188,7 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) { if (!loadDialectModule(dialectNamespace)) return std::nullopt; + nb::ft_lock_guard lock(mutex); auto foundIt = operationClassMap.find(operationName); if (foundIt != operationClassMap.end()) { assert(foundIt->second && "OpView is defined"); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 8fb32a225..f5fbb6c61 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -260,6 +260,7 @@ class PyMlirContext { // Note that this holds a handle, which does not imply ownership. // Mappings will be removed when the context is destructed. using LiveContextMap = llvm::DenseMap; + static nanobind::ft_mutex live_contexts_mutex; static LiveContextMap &getLiveContexts(); // Interns all live modules associated with this context. 
Modules tracked diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 7c4064262..6f4943100 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -30,12 +30,8 @@ NB_MODULE(_mlir, m) { .def_prop_rw("dialect_search_modules", &PyGlobals::getDialectSearchPrefixes, &PyGlobals::setDialectSearchPrefixes) - .def( - "append_dialect_search_prefix", - [](PyGlobals &self, std::string moduleName) { - self.getDialectSearchPrefixes().push_back(std::move(moduleName)); - }, - "module_name"_a) + .def("append_dialect_search_prefix", &PyGlobals::addDialectSearchPrefix, + "module_name"_a) .def( "_check_dialect_module_loaded", [](PyGlobals &self, const std::string &dialectNamespace) { @@ -76,7 +72,6 @@ NB_MODULE(_mlir, m) { nanobind::cast(opClass.attr("OPERATION_NAME")); PyGlobals::get().registerOperationImpl(operationName, opClass, replace); - // Dict-stuff the new opClass by name onto the dialect class. nb::object opClassName = opClass.attr("__name__"); dialectClass.attr(opClassName) = opClass; diff --git a/mlir/python/requirements.txt b/mlir/python/requirements.txt index f240d6ef9..1a0075e82 100644 --- a/mlir/python/requirements.txt +++ b/mlir/python/requirements.txt @@ -2,4 +2,5 @@ nanobind>=2.4, <3.0 numpy>=1.19.5, <=2.1.2 pybind11>=2.10.0, <=2.13.6 PyYAML>=5.4.0, <=6.0.1 -ml_dtypes>=0.1.0, <=0.5.0 # provides several NumPy dtype extensions, including the bf16 +ml_dtypes>=0.1.0, <=0.6.0; python_version<"3.13" # provides several NumPy dtype extensions, including the bf16 +ml_dtypes>=0.5.0, <=0.6.0; python_version>="3.13" \ No newline at end of file From 0698da65e4dd2271065c2687d52e87cabe5a3549 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 13 Jan 2025 10:49:25 -0500 Subject: [PATCH 821/915] [mlir python] Add locking around PyMlirContext::liveOperations. 
(#122720) In JAX, I observed a race between two PyOperation destructors from different threads updating the same `liveOperations` map, despite not intentionally sharing the context between different threads. Since I don't think we can be completely sure when GC happens and on which thread, it seems safest simply to add locking here. We may also want to explicitly support sharing a context between threads in the future, which would require this change or something similar. --- mlir/lib/Bindings/Python/IRCore.cpp | 43 +++++++++++++++++++++-------- mlir/lib/Bindings/Python/IRModule.h | 3 ++ 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 463ebdebb..53806ca9f 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -677,29 +677,44 @@ size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } -size_t PyMlirContext::getLiveOperationCount() { return liveOperations.size(); } +size_t PyMlirContext::getLiveOperationCount() { + nb::ft_lock_guard lock(liveOperationsMutex); + return liveOperations.size(); +} std::vector PyMlirContext::getLiveOperationObjects() { std::vector liveObjects; + nb::ft_lock_guard lock(liveOperationsMutex); for (auto &entry : liveOperations) liveObjects.push_back(entry.second.second); return liveObjects; } size_t PyMlirContext::clearLiveOperations() { - for (auto &op : liveOperations) + + LiveOperationMap operations; + { + nb::ft_lock_guard lock(liveOperationsMutex); + std::swap(operations, liveOperations); + } + for (auto &op : operations) op.second.second->setInvalid(); - size_t numInvalidated = liveOperations.size(); - liveOperations.clear(); + size_t numInvalidated = operations.size(); return numInvalidated; } void PyMlirContext::clearOperation(MlirOperation op) { - auto it = liveOperations.find(op.ptr); - if (it != liveOperations.end()) { - it->second.second->setInvalid(); + PyOperation *py_op; + { + 
nb::ft_lock_guard lock(liveOperationsMutex); + auto it = liveOperations.find(op.ptr); + if (it == liveOperations.end()) { + return; + } + py_op = it->second.second; liveOperations.erase(it); } + py_op->setInvalid(); } void PyMlirContext::clearOperationsInside(PyOperationBase &op) { @@ -1183,7 +1198,6 @@ PyOperation::~PyOperation() { PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, MlirOperation operation, nb::object parentKeepAlive) { - auto &liveOperations = contextRef->liveOperations; // Create. PyOperation *unownedOperation = new PyOperation(std::move(contextRef), operation); @@ -1195,19 +1209,22 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, if (parentKeepAlive) { unownedOperation->parentKeepAlive = std::move(parentKeepAlive); } - liveOperations[operation.ptr] = std::make_pair(pyRef, unownedOperation); return PyOperationRef(unownedOperation, std::move(pyRef)); } PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, MlirOperation operation, nb::object parentKeepAlive) { + nb::ft_lock_guard lock(contextRef->liveOperationsMutex); auto &liveOperations = contextRef->liveOperations; auto it = liveOperations.find(operation.ptr); if (it == liveOperations.end()) { // Create. - return createInstance(std::move(contextRef), operation, - std::move(parentKeepAlive)); + PyOperationRef result = createInstance(std::move(contextRef), operation, + std::move(parentKeepAlive)); + liveOperations[operation.ptr] = + std::make_pair(result.getObject(), result.get()); + return result; } // Use existing. 
PyOperation *existing = it->second.second; @@ -1218,13 +1235,15 @@ PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, MlirOperation operation, nb::object parentKeepAlive) { + nb::ft_lock_guard lock(contextRef->liveOperationsMutex); auto &liveOperations = contextRef->liveOperations; assert(liveOperations.count(operation.ptr) == 0 && "cannot create detached operation that already exists"); (void)liveOperations; - PyOperationRef created = createInstance(std::move(contextRef), operation, std::move(parentKeepAlive)); + liveOperations[operation.ptr] = + std::make_pair(created.getObject(), created.get()); created->attached = false; return created; } diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index f5fbb6c61..d1fb4308d 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -277,6 +277,9 @@ class PyMlirContext { // attempt to access it will raise an error. using LiveOperationMap = llvm::DenseMap>; + nanobind::ft_mutex liveOperationsMutex; + + // Guarded by liveOperationsMutex in free-threading mode. LiveOperationMap liveOperations; bool emitErrorDiagnostics = false; From e2a041495e557a46e4448424c3ab5b683fb6a1d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eliud=20de=20Le=C3=B3n?= Date: Tue, 14 Jan 2025 12:46:06 -0800 Subject: [PATCH 822/915] [mlir][emitc] Expose emitc dialect types (#119645) Added C API functions for the EmitC dialect types. 
--- mlir/include/mlir-c/Dialect/EmitC.h | 111 ++++++++++++++++++ mlir/lib/CAPI/Dialect/EmitC.cpp | 176 ++++++++++++++++++++++++++++ 2 files changed, 287 insertions(+) diff --git a/mlir/include/mlir-c/Dialect/EmitC.h b/mlir/include/mlir-c/Dialect/EmitC.h index 82e698344..a0e3ea08a 100644 --- a/mlir/include/mlir-c/Dialect/EmitC.h +++ b/mlir/include/mlir-c/Dialect/EmitC.h @@ -19,6 +19,117 @@ extern "C" { MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(EmitC, emitc); +enum MlirEmitCCmpPredicate : uint64_t { + MLIR_EMITC_CMP_PREDICATE_EQ = 0, + MLIR_EMITC_CMP_PREDICATE_NE = 1, + MLIR_EMITC_CMP_PREDICATE_LT = 2, + MLIR_EMITC_CMP_PREDICATE_LE = 3, + MLIR_EMITC_CMP_PREDICATE_GT = 4, + MLIR_EMITC_CMP_PREDICATE_GE = 5, + MLIR_EMITC_CMP_PREDICATE_THREE_WAY = 6, +}; + +//===---------------------------------------------------------------------===// +// ArrayType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAEmitCArrayType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCArrayTypeGetTypeID(void); + +MLIR_CAPI_EXPORTED MlirType mlirEmitCArrayTypeGet(intptr_t nDims, + int64_t *shape, + MlirType elementType); + +//===---------------------------------------------------------------------===// +// LValueType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAEmitCLValueType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCLValueTypeGetTypeID(void); + +MLIR_CAPI_EXPORTED MlirType mlirEmitCLValueTypeGet(MlirType valueType); + +//===---------------------------------------------------------------------===// +// OpaqueType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAEmitCOpaqueType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCOpaqueTypeGetTypeID(void); + +MLIR_CAPI_EXPORTED MlirType mlirEmitCOpaqueTypeGet(MlirContext ctx, + MlirStringRef value); + 
+//===---------------------------------------------------------------------===// +// PointerType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAEmitCPointerType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCPointerTypeGetTypeID(void); + +MLIR_CAPI_EXPORTED MlirType mlirEmitCPointerTypeGet(MlirType pointee); + +//===---------------------------------------------------------------------===// +// PtrDiffTType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAEmitCPtrDiffTType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCPtrDiffTTypeGetTypeID(void); + +MLIR_CAPI_EXPORTED MlirType mlirEmitCPtrDiffTTypeGet(MlirContext ctx); + +//===---------------------------------------------------------------------===// +// SignedSizeTType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAEmitCSignedSizeTType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCSignedSizeTTypeGetTypeID(void); + +MLIR_CAPI_EXPORTED MlirType mlirEmitCSignedSizeTTypeGet(MlirContext ctx); + +//===---------------------------------------------------------------------===// +// SizeTType +//===---------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirTypeIsAEmitCSizeTType(MlirType type); + +MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCSizeTTypeGetTypeID(void); + +MLIR_CAPI_EXPORTED MlirType mlirEmitCSizeTTypeGet(MlirContext ctx); + +//===----------------------------------------------------------------------===// +// CmpPredicate attribute. 
+//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirAttributeIsAEmitCCmpPredicate(MlirAttribute attr); + +MLIR_CAPI_EXPORTED MlirAttribute +mlirEmitCCmpPredicateAttrGet(MlirContext ctx, enum MlirEmitCCmpPredicate val); + +MLIR_CAPI_EXPORTED enum MlirEmitCCmpPredicate +mlirEmitCCmpPredicateAttrGetValue(MlirAttribute attr); + +MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCCmpPredicateAttrGetTypeID(void); + +//===----------------------------------------------------------------------===// +// Opaque attribute. +//===----------------------------------------------------------------------===// + +MLIR_CAPI_EXPORTED bool mlirAttributeIsAEmitCOpaque(MlirAttribute attr); + +MLIR_CAPI_EXPORTED MlirAttribute mlirEmitCOpaqueAttrGet(MlirContext ctx, + MlirStringRef value); + +MLIR_CAPI_EXPORTED MlirStringRef +mlirEmitCOpaqueAttrGetValue(MlirAttribute attr); + +MLIR_CAPI_EXPORTED MlirTypeID mlirEmitCOpaqueAttrGetTypeID(void); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/CAPI/Dialect/EmitC.cpp b/mlir/lib/CAPI/Dialect/EmitC.cpp index 3dcb7038a..b6d197366 100644 --- a/mlir/lib/CAPI/Dialect/EmitC.cpp +++ b/mlir/lib/CAPI/Dialect/EmitC.cpp @@ -10,4 +10,180 @@ #include "mlir/CAPI/Registration.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" +using namespace mlir; + MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(EmitC, emitc, mlir::emitc::EmitCDialect) + +// Ensure the C-API enums are uint64_t-castable to C++ equivalents. 
+static_assert(static_cast(MLIR_EMITC_CMP_PREDICATE_EQ) == + static_cast(emitc::CmpPredicate::eq) && + static_cast(MLIR_EMITC_CMP_PREDICATE_NE) == + static_cast(emitc::CmpPredicate::ne) && + static_cast(MLIR_EMITC_CMP_PREDICATE_LT) == + static_cast(emitc::CmpPredicate::lt) && + static_cast(MLIR_EMITC_CMP_PREDICATE_LE) == + static_cast(emitc::CmpPredicate::le) && + static_cast(MLIR_EMITC_CMP_PREDICATE_GT) == + static_cast(emitc::CmpPredicate::gt) && + static_cast(MLIR_EMITC_CMP_PREDICATE_GE) == + static_cast(emitc::CmpPredicate::ge) && + static_cast(MLIR_EMITC_CMP_PREDICATE_THREE_WAY) == + static_cast(emitc::CmpPredicate::three_way), + "MlirEmitCCmpPredicate (C-API) and CmpPredicate (C++) mismatch"); + +//===---------------------------------------------------------------------===// +// ArrayType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAEmitCArrayType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirEmitCArrayTypeGetTypeID(void) { + return wrap(emitc::ArrayType::getTypeID()); +} + +MlirType mlirEmitCArrayTypeGet(intptr_t nDims, int64_t *shape, + MlirType elementType) { + return wrap( + emitc::ArrayType::get(llvm::ArrayRef(shape, nDims), unwrap(elementType))); +} + +//===---------------------------------------------------------------------===// +// LValueType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAEmitCLValueType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirEmitCLValueTypeGetTypeID(void) { + return wrap(emitc::LValueType::getTypeID()); +} + +MlirType mlirEmitCLValueTypeGet(MlirType valueType) { + return wrap(emitc::LValueType::get(unwrap(valueType))); +} + +//===---------------------------------------------------------------------===// +// OpaqueType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAEmitCOpaqueType(MlirType type) { + return isa(unwrap(type)); +} + 
+MlirTypeID mlirEmitCOpaqueTypeGetTypeID(void) { + return wrap(emitc::OpaqueType::getTypeID()); +} + +MlirType mlirEmitCOpaqueTypeGet(MlirContext ctx, MlirStringRef value) { + return wrap(emitc::OpaqueType::get(unwrap(ctx), unwrap(value))); +} + +//===---------------------------------------------------------------------===// +// PointerType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAEmitCPointerType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirEmitCPointerTypeGetTypeID(void) { + return wrap(emitc::PointerType::getTypeID()); +} + +MlirType mlirEmitCPointerTypeGet(MlirType pointee) { + return wrap(emitc::PointerType::get(unwrap(pointee))); +} + +//===---------------------------------------------------------------------===// +// PtrDiffTType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAEmitCPtrDiffTType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirEmitCPtrDiffTTypeGetTypeID(void) { + return wrap(emitc::PtrDiffTType::getTypeID()); +} + +MlirType mlirEmitCPtrDiffTTypeGet(MlirContext ctx) { + return wrap(emitc::PtrDiffTType::get(unwrap(ctx))); +} + +//===---------------------------------------------------------------------===// +// SignedSizeTType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAEmitCSignedSizeTType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID mlirEmitCSignedSizeTTypeGetTypeID(void) { + return wrap(emitc::SignedSizeTType::getTypeID()); +} + +MlirType mlirEmitCSignedSizeTTypeGet(MlirContext ctx) { + return wrap(emitc::SignedSizeTType::get(unwrap(ctx))); +} + +//===---------------------------------------------------------------------===// +// SizeTType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAEmitCSizeTType(MlirType type) { + return isa(unwrap(type)); +} + +MlirTypeID 
mlirEmitCSizeTTypeGetTypeID(void) { + return wrap(emitc::SizeTType::getTypeID()); +} + +MlirType mlirEmitCSizeTTypeGet(MlirContext ctx) { + return wrap(emitc::SizeTType::get(unwrap(ctx))); +} + +//===----------------------------------------------------------------------===// +// CmpPredicate attribute. +//===----------------------------------------------------------------------===// + +bool mlirAttributeIsAEmitCCmpPredicate(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} + +MlirAttribute mlirEmitCCmpPredicateAttrGet(MlirContext ctx, + MlirEmitCCmpPredicate val) { + return wrap((Attribute)emitc::CmpPredicateAttr::get( + unwrap(ctx), static_cast(val))); +} + +MlirEmitCCmpPredicate mlirEmitCCmpPredicateAttrGetValue(MlirAttribute attr) { + return static_cast( + llvm::cast(unwrap(attr)).getValue()); +} + +MlirTypeID mlirEmitCCmpPredicateAttrGetTypeID(void) { + return wrap(emitc::CmpPredicateAttr::getTypeID()); +} + +//===----------------------------------------------------------------------===// +// Opaque attribute. +//===----------------------------------------------------------------------===// + +bool mlirAttributeIsAEmitCOpaque(MlirAttribute attr) { + return llvm::isa(unwrap(attr)); +} + +MlirAttribute mlirEmitCOpaqueAttrGet(MlirContext ctx, MlirStringRef value) { + return wrap((Attribute)emitc::OpaqueAttr::get(unwrap(ctx), unwrap(value))); +} + +MlirStringRef mlirEmitCOpaqueAttrGetValue(MlirAttribute attr) { + return wrap(llvm::cast(unwrap(attr)).getValue()); +} + +MlirTypeID mlirEmitCOpaqueAttrGetTypeID(void) { + return wrap(emitc::OpaqueAttr::getTypeID()); +} From a3068ffaf2ade62d83740a44b8a87e0752ec2a2d Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Thu, 16 Jan 2025 08:56:09 +0100 Subject: [PATCH 823/915] [mlir][IR] Remove factory methods from `FloatType` (#123026) This commit removes convenience methods from `FloatType` to make it independent of concrete interface implementations. 
See discussion here: https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361 Note for LLVM integration: Replace `FloatType::getF32(` with `Float32Type::get(` etc. --- mlir/lib/CAPI/IR/BuiltinTypes.cpp | 32 +++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 252ff54af..250e4a6bb 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -94,7 +94,7 @@ bool mlirTypeIsAFloat4E2M1FN(MlirType type) { } MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx) { - return wrap(FloatType::getFloat4E2M1FN(unwrap(ctx))); + return wrap(Float4E2M1FNType::get(unwrap(ctx))); } MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() { @@ -106,7 +106,7 @@ bool mlirTypeIsAFloat6E2M3FN(MlirType type) { } MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx) { - return wrap(FloatType::getFloat6E2M3FN(unwrap(ctx))); + return wrap(Float6E2M3FNType::get(unwrap(ctx))); } MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() { @@ -118,7 +118,7 @@ bool mlirTypeIsAFloat6E3M2FN(MlirType type) { } MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx) { - return wrap(FloatType::getFloat6E3M2FN(unwrap(ctx))); + return wrap(Float6E3M2FNType::get(unwrap(ctx))); } MlirTypeID mlirFloat8E5M2TypeGetTypeID() { @@ -130,7 +130,7 @@ bool mlirTypeIsAFloat8E5M2(MlirType type) { } MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) { - return wrap(FloatType::getFloat8E5M2(unwrap(ctx))); + return wrap(Float8E5M2Type::get(unwrap(ctx))); } MlirTypeID mlirFloat8E4M3TypeGetTypeID() { @@ -142,7 +142,7 @@ bool mlirTypeIsAFloat8E4M3(MlirType type) { } MlirType mlirFloat8E4M3TypeGet(MlirContext ctx) { - return wrap(FloatType::getFloat8E4M3(unwrap(ctx))); + return wrap(Float8E4M3Type::get(unwrap(ctx))); } MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() { @@ -154,7 +154,7 @@ bool mlirTypeIsAFloat8E4M3FN(MlirType type) { } MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) { - return 
wrap(FloatType::getFloat8E4M3FN(unwrap(ctx))); + return wrap(Float8E4M3FNType::get(unwrap(ctx))); } MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() { @@ -166,7 +166,7 @@ bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) { } MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) { - return wrap(FloatType::getFloat8E5M2FNUZ(unwrap(ctx))); + return wrap(Float8E5M2FNUZType::get(unwrap(ctx))); } MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() { @@ -178,7 +178,7 @@ bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) { } MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) { - return wrap(FloatType::getFloat8E4M3FNUZ(unwrap(ctx))); + return wrap(Float8E4M3FNUZType::get(unwrap(ctx))); } MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() { @@ -190,7 +190,7 @@ bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) { } MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) { - return wrap(FloatType::getFloat8E4M3B11FNUZ(unwrap(ctx))); + return wrap(Float8E4M3B11FNUZType::get(unwrap(ctx))); } MlirTypeID mlirFloat8E3M4TypeGetTypeID() { @@ -202,7 +202,7 @@ bool mlirTypeIsAFloat8E3M4(MlirType type) { } MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) { - return wrap(FloatType::getFloat8E3M4(unwrap(ctx))); + return wrap(Float8E3M4Type::get(unwrap(ctx))); } MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID() { @@ -214,7 +214,7 @@ bool mlirTypeIsAFloat8E8M0FNU(MlirType type) { } MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx) { - return wrap(FloatType::getFloat8E8M0FNU(unwrap(ctx))); + return wrap(Float8E8M0FNUType::get(unwrap(ctx))); } MlirTypeID mlirBFloat16TypeGetTypeID() { @@ -224,7 +224,7 @@ MlirTypeID mlirBFloat16TypeGetTypeID() { bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); } MlirType mlirBF16TypeGet(MlirContext ctx) { - return wrap(FloatType::getBF16(unwrap(ctx))); + return wrap(BFloat16Type::get(unwrap(ctx))); } MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); } @@ -232,7 +232,7 @@ MlirTypeID mlirFloat16TypeGetTypeID() { return 
wrap(Float16Type::getTypeID()); } bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); } MlirType mlirF16TypeGet(MlirContext ctx) { - return wrap(FloatType::getF16(unwrap(ctx))); + return wrap(Float16Type::get(unwrap(ctx))); } MlirTypeID mlirFloatTF32TypeGetTypeID() { @@ -242,7 +242,7 @@ MlirTypeID mlirFloatTF32TypeGetTypeID() { bool mlirTypeIsATF32(MlirType type) { return unwrap(type).isTF32(); } MlirType mlirTF32TypeGet(MlirContext ctx) { - return wrap(FloatType::getTF32(unwrap(ctx))); + return wrap(FloatTF32Type::get(unwrap(ctx))); } MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); } @@ -250,7 +250,7 @@ MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); } bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); } MlirType mlirF32TypeGet(MlirContext ctx) { - return wrap(FloatType::getF32(unwrap(ctx))); + return wrap(Float32Type::get(unwrap(ctx))); } MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); } @@ -258,7 +258,7 @@ MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); } bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); } MlirType mlirF64TypeGet(MlirContext ctx) { - return wrap(FloatType::getF64(unwrap(ctx))); + return wrap(Float64Type::get(unwrap(ctx))); } //===----------------------------------------------------------------------===// From 3f52ebce03f198a5d9bdd23798ad3e2f1d0a549d Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 20 Jan 2025 09:22:53 +0100 Subject: [PATCH 824/915] [mlir][IR] Remove `isF...()` type API for low-precision FP types (#123326) Remove `type.isFloat4E2M1FN()` etc. Use `isa(type)` instead. 
For details, see: https://discourse.llvm.org/t/rethink-on-approach-to-low-precision-fp-types/82361/28 --- mlir/lib/CAPI/IR/BuiltinTypes.cpp | 42 +++++++++++++++++++------------ 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 250e4a6bb..98ca9c3d2 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -90,7 +90,7 @@ MlirTypeID mlirFloat4E2M1FNTypeGetTypeID() { } bool mlirTypeIsAFloat4E2M1FN(MlirType type) { - return unwrap(type).isFloat4E2M1FN(); + return llvm::isa(unwrap(type)); } MlirType mlirFloat4E2M1FNTypeGet(MlirContext ctx) { @@ -102,7 +102,7 @@ MlirTypeID mlirFloat6E2M3FNTypeGetTypeID() { } bool mlirTypeIsAFloat6E2M3FN(MlirType type) { - return unwrap(type).isFloat6E2M3FN(); + return llvm::isa(unwrap(type)); } MlirType mlirFloat6E2M3FNTypeGet(MlirContext ctx) { @@ -114,7 +114,7 @@ MlirTypeID mlirFloat6E3M2FNTypeGetTypeID() { } bool mlirTypeIsAFloat6E3M2FN(MlirType type) { - return unwrap(type).isFloat6E3M2FN(); + return llvm::isa(unwrap(type)); } MlirType mlirFloat6E3M2FNTypeGet(MlirContext ctx) { @@ -126,7 +126,7 @@ MlirTypeID mlirFloat8E5M2TypeGetTypeID() { } bool mlirTypeIsAFloat8E5M2(MlirType type) { - return unwrap(type).isFloat8E5M2(); + return llvm::isa(unwrap(type)); } MlirType mlirFloat8E5M2TypeGet(MlirContext ctx) { @@ -138,7 +138,7 @@ MlirTypeID mlirFloat8E4M3TypeGetTypeID() { } bool mlirTypeIsAFloat8E4M3(MlirType type) { - return unwrap(type).isFloat8E4M3(); + return llvm::isa(unwrap(type)); } MlirType mlirFloat8E4M3TypeGet(MlirContext ctx) { @@ -150,7 +150,7 @@ MlirTypeID mlirFloat8E4M3FNTypeGetTypeID() { } bool mlirTypeIsAFloat8E4M3FN(MlirType type) { - return unwrap(type).isFloat8E4M3FN(); + return llvm::isa(unwrap(type)); } MlirType mlirFloat8E4M3FNTypeGet(MlirContext ctx) { @@ -162,7 +162,7 @@ MlirTypeID mlirFloat8E5M2FNUZTypeGetTypeID() { } bool mlirTypeIsAFloat8E5M2FNUZ(MlirType type) { - return 
unwrap(type).isFloat8E5M2FNUZ(); + return llvm::isa(unwrap(type)); } MlirType mlirFloat8E5M2FNUZTypeGet(MlirContext ctx) { @@ -174,7 +174,7 @@ MlirTypeID mlirFloat8E4M3FNUZTypeGetTypeID() { } bool mlirTypeIsAFloat8E4M3FNUZ(MlirType type) { - return unwrap(type).isFloat8E4M3FNUZ(); + return llvm::isa(unwrap(type)); } MlirType mlirFloat8E4M3FNUZTypeGet(MlirContext ctx) { @@ -186,7 +186,7 @@ MlirTypeID mlirFloat8E4M3B11FNUZTypeGetTypeID() { } bool mlirTypeIsAFloat8E4M3B11FNUZ(MlirType type) { - return unwrap(type).isFloat8E4M3B11FNUZ(); + return llvm::isa(unwrap(type)); } MlirType mlirFloat8E4M3B11FNUZTypeGet(MlirContext ctx) { @@ -198,7 +198,7 @@ MlirTypeID mlirFloat8E3M4TypeGetTypeID() { } bool mlirTypeIsAFloat8E3M4(MlirType type) { - return unwrap(type).isFloat8E3M4(); + return llvm::isa(unwrap(type)); } MlirType mlirFloat8E3M4TypeGet(MlirContext ctx) { @@ -210,7 +210,7 @@ MlirTypeID mlirFloat8E8M0FNUTypeGetTypeID() { } bool mlirTypeIsAFloat8E8M0FNU(MlirType type) { - return unwrap(type).isFloat8E8M0FNU(); + return llvm::isa(unwrap(type)); } MlirType mlirFloat8E8M0FNUTypeGet(MlirContext ctx) { @@ -221,7 +221,9 @@ MlirTypeID mlirBFloat16TypeGetTypeID() { return wrap(BFloat16Type::getTypeID()); } -bool mlirTypeIsABF16(MlirType type) { return unwrap(type).isBF16(); } +bool mlirTypeIsABF16(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirBF16TypeGet(MlirContext ctx) { return wrap(BFloat16Type::get(unwrap(ctx))); @@ -229,7 +231,9 @@ MlirType mlirBF16TypeGet(MlirContext ctx) { MlirTypeID mlirFloat16TypeGetTypeID() { return wrap(Float16Type::getTypeID()); } -bool mlirTypeIsAF16(MlirType type) { return unwrap(type).isF16(); } +bool mlirTypeIsAF16(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirF16TypeGet(MlirContext ctx) { return wrap(Float16Type::get(unwrap(ctx))); @@ -239,7 +243,9 @@ MlirTypeID mlirFloatTF32TypeGetTypeID() { return wrap(FloatTF32Type::getTypeID()); } -bool mlirTypeIsATF32(MlirType type) { return unwrap(type).isTF32(); 
} +bool mlirTypeIsATF32(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirTF32TypeGet(MlirContext ctx) { return wrap(FloatTF32Type::get(unwrap(ctx))); @@ -247,7 +253,9 @@ MlirType mlirTF32TypeGet(MlirContext ctx) { MlirTypeID mlirFloat32TypeGetTypeID() { return wrap(Float32Type::getTypeID()); } -bool mlirTypeIsAF32(MlirType type) { return unwrap(type).isF32(); } +bool mlirTypeIsAF32(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirF32TypeGet(MlirContext ctx) { return wrap(Float32Type::get(unwrap(ctx))); @@ -255,7 +263,9 @@ MlirType mlirF32TypeGet(MlirContext ctx) { MlirTypeID mlirFloat64TypeGetTypeID() { return wrap(Float64Type::getTypeID()); } -bool mlirTypeIsAF64(MlirType type) { return unwrap(type).isF64(); } +bool mlirTypeIsAF64(MlirType type) { + return llvm::isa(unwrap(type)); +} MlirType mlirF64TypeGet(MlirContext ctx) { return wrap(Float64Type::get(unwrap(ctx))); From c76d2971f945fe30cd02941d1ff4c0fc3a734c14 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 21 Jan 2025 08:48:09 +0100 Subject: [PATCH 825/915] [mlir][IR][NFC] Move free-standing functions to `MemRefType` (#123465) Turn free-standing `MemRefType`-related helper functions in `BuiltinTypes.h` into member functions. 
--- mlir/lib/CAPI/IR/BuiltinTypes.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index 98ca9c3d2..a080adf0f 100644 --- a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -524,7 +524,7 @@ MlirLogicalResult mlirMemRefTypeGetStridesAndOffset(MlirType type, int64_t *offset) { MemRefType memrefType = llvm::cast(unwrap(type)); SmallVector strides_; - if (failed(getStridesAndOffset(memrefType, strides_, *offset))) + if (failed(memrefType.getStridesAndOffset(strides_, *offset))) return mlirLogicalResultFailure(); (void)std::copy(strides_.begin(), strides_.end(), strides); From c945a1845caf072b1dada829c041eeb925f6ee42 Mon Sep 17 00:00:00 2001 From: Han-Chung Wang Date: Tue, 21 Jan 2025 21:23:32 -0800 Subject: [PATCH 826/915] [mlir][NFC] Avoid using braced initializer lists to call a constructor. (#123714) In the LLVM style guide, we prefer not using braced initializer lists to call a constructor. Also, we prefer using an equal before the open curly brace if we use a braced initializer list when initializing a variable. See https://llvm.org/docs/CodingStandards.html#do-not-use-braced-initializer-lists-to-call-a-constructor for more details. The style guide does not explain the reason well. There is an article from abseil, which mentions few benefits. E.g., we can avoid the most vexing parse, etc. See https://abseil.io/tips/88 for more details. 
Signed-off-by: hanhanW --- mlir/lib/Bindings/Python/IRAttributes.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 08f7d4881..7bc21a31c 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -845,7 +845,7 @@ class PyDenseElementsAttribute } shapedType = *explicitType; } else { - SmallVector shape{static_cast(numAttributes)}; + SmallVector shape = {static_cast(numAttributes)}; shapedType = mlirRankedTensorTypeGet( shape.size(), shape.data(), mlirAttributeGetType(pyTryCast(attributes[0])), From 9150732f85964c395a4706032fb66cbc0deef692 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20G=C3=B3rny?= Date: Wed, 22 Jan 2025 08:54:54 +0100 Subject: [PATCH 827/915] [mlir] Link libraries that aren't included in libMLIR to libMLIR (#123781) Use `mlir_target_link_libraries()` to link dependencies of libraries that are not included in libMLIR, to ensure that they link to the dylib when they are used in Flang. Otherwise, they implicitly pull in all their static dependencies, effectively causing Flang binaries to simultaneously link to the dylib and to static libraries, which is never a good idea. I have only covered the libraries that are used by Flang. If you wish, I can extend this approach to all non-libMLIR libraries in MLIR, making MLIR itself also link to the dylib consistently. [v2 with fixed `-DBUILD_SHARED_LIBS=ON` build] --- mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt b/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt index 0be8f2af5..15c72a15a 100644 --- a/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt +++ b/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt @@ -1,6 +1,8 @@ set(LLVM_LINK_COMPONENTS nativecodegen native + orcjit + support ) # Main API shared library. 
From a46de9dbaa866f60df2ff19025e901bc35770f92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20G=C3=B3rny?= Date: Wed, 22 Jan 2025 09:08:23 +0100 Subject: [PATCH 828/915] Revert "[mlir] Link libraries that aren't included in libMLIR to libMLIR (#123781)" This reverts commit 9150732f85964c395a4706032fb66cbc0deef692. More BUILD_SHARED_LIBS=ON regressions, sigh. --- mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt | 2 -- 1 file changed, 2 deletions(-) diff --git a/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt b/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt index 15c72a15a..0be8f2af5 100644 --- a/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt +++ b/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt @@ -1,8 +1,6 @@ set(LLVM_LINK_COMPONENTS nativecodegen native - orcjit - support ) # Main API shared library. From 1860fe801cdd1d3c59c1343bf6134836364f0c1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20G=C3=B3rny?= Date: Wed, 22 Jan 2025 10:01:50 +0100 Subject: [PATCH 829/915] Reapply "[mlir] Link libraries that aren't included in libMLIR to libMLIR" (#123910) Use `mlir_target_link_libraries()` to link dependencies of libraries that are not included in libMLIR, to ensure that they link to the dylib when they are used in Flang. Otherwise, they implicitly pull in all their static dependencies, effectively causing Flang binaries to simultaneously link to the dylib and to static libraries, which is never a good idea. I have only covered the libraries that are used by Flang. If you wish, I can extend this approach to all non-libMLIR libraries in MLIR, making MLIR itself also link to the dylib consistently. 
[v3 with more `-DBUILD_SHARED_LIBS=ON` fixes] --- mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt b/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt index 0be8f2af5..bf7dff897 100644 --- a/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt +++ b/mlir/lib/CAPI/ExecutionEngine/CMakeLists.txt @@ -1,8 +1,10 @@ set(LLVM_LINK_COMPONENTS nativecodegen native + orcjit + support ) - + # Main API shared library. add_mlir_upstream_c_api_library(MLIRCAPIExecutionEngine ExecutionEngine.cpp From dd78f4e5eb85fb9ed3d18c70f6772ee46fa53d91 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 22 Jan 2025 08:41:31 -0500 Subject: [PATCH 830/915] [mlir:python] Small optimization to get_op_result_or_results. (#123866) * We can call .results without figuring out whether we have an Operation or an OpView, and that's likely the common case anyway. * If we have one or more results, we can return them directly, with no need for a call to get_op_result_or_value. We're guaranteed that .results returns a PyOpResultList, so we have either an OpResult or sequence of OpResults, just as the API expects. This saves a few 100ms during IR construction in an LLM JAX benchmark. 
--- mlir/python/mlir/dialects/_ods_common.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index d40d936cd..5b67ab03d 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -133,15 +133,16 @@ def get_op_results_or_values( def get_op_result_or_op_results( op: _Union[_cext.ir.OpView, _cext.ir.Operation], ) -> _Union[_cext.ir.Operation, _cext.ir.OpResult, _Sequence[_cext.ir.OpResult]]: - if isinstance(op, _cext.ir.OpView): - op = op.operation - return ( - list(get_op_results_or_values(op)) - if len(op.results) > 1 - else get_op_result_or_value(op) - if len(op.results) > 0 - else op - ) + results = op.results + num_results = len(results) + if num_results == 1: + return results[0] + elif num_results > 1: + return results + elif isinstance(op, _cext.ir.OpView): + return op.operation + else: + return op ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value ResultValueT = _Union[ResultValueTypeTuple] From 3a284d4765bd5cbf821293dba5eeeef9449e0287 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 22 Jan 2025 09:21:46 -0500 Subject: [PATCH 831/915] [mlir python] Change PyOpView constructor to construct operations. (#123777) Previously ODS-generated Python operations had code like this: ``` super().__init__(self.build_generic(attributes=attributes, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip)) ``` we change it to: ``` super().__init__(self.OPERATION_NAME, self._ODS_REGIONS, self._ODS_OPERAND_SEGMENTS, self._ODS_RESULT_SEGMENTS, attributes=attributes, operands=operands, successors=_ods_successors, regions=regions, loc=loc, ip=ip) ``` This: a) avoids an extra call dispatch (to `build_generic`), and b) passes the class attributes directly to the constructor. 
Benchmarks show that it is faster to pass these as arguments rather than having the C++ code look up attributes on the class. This PR improves the timing of the following benchmark on my workstation from 5.3s to 4.5s: ``` def main(_): with ir.Context(), ir.Location.unknown(): typ = ir.IntegerType.get_signless(32) m = ir.Module.create() with ir.InsertionPoint(m.body): start = time.time() for i in range(1000000): arith.ConstantOp(typ, i) end = time.time() print(f"time: {end - start}") ``` Since this change adds an additional overload to the constructor and does not alter any existing behaviors, it should be backwards compatible. --- mlir/lib/Bindings/Python/IRCore.cpp | 73 ++++++++++++++++++++++++----- mlir/lib/Bindings/Python/IRModule.h | 18 ++++--- 2 files changed, 71 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 53806ca9f..1c9fb3d2a 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -211,6 +211,10 @@ static MlirStringRef toMlirStringRef(const std::string &s) { return mlirStringRefCreate(s.data(), s.size()); } +static MlirStringRef toMlirStringRef(std::string_view s) { + return mlirStringRefCreate(s.data(), s.size()); +} + static MlirStringRef toMlirStringRef(const nb::bytes &s) { return mlirStringRefCreate(static_cast(s.data()), s.size()); } @@ -1460,7 +1464,7 @@ static void maybeInsertOperation(PyOperationRef &op, } } -nb::object PyOperation::create(const std::string &name, +nb::object PyOperation::create(std::string_view name, std::optional> results, std::optional> operands, std::optional attributes, @@ -1506,7 +1510,7 @@ nb::object PyOperation::create(const std::string &name, } catch (nb::cast_error &err) { std::string msg = "Invalid attribute key (not a string) when " "attempting to create the operation \"" + - name + "\" (" + err.what() + ")"; + std::string(name) + "\" (" + err.what() + ")"; throw nb::type_error(msg.c_str()); } try { @@ 
-1516,13 +1520,14 @@ nb::object PyOperation::create(const std::string &name, } catch (nb::cast_error &err) { std::string msg = "Invalid attribute value for the key \"" + key + "\" when attempting to create the operation \"" + - name + "\" (" + err.what() + ")"; + std::string(name) + "\" (" + err.what() + ")"; throw nb::type_error(msg.c_str()); } catch (std::runtime_error &) { // This exception seems thrown when the value is "None". std::string msg = "Found an invalid (`None`?) attribute value for the key \"" + key + - "\" when attempting to create the operation \"" + name + "\""; + "\" when attempting to create the operation \"" + + std::string(name) + "\""; throw std::runtime_error(msg); } } @@ -1714,27 +1719,25 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList, } nb::object PyOpView::buildGeneric( - const nb::object &cls, std::optional resultTypeList, - nb::list operandList, std::optional attributes, + std::string_view name, std::tuple opRegionSpec, + nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj, + std::optional resultTypeList, nb::list operandList, + std::optional attributes, std::optional> successors, std::optional regions, DefaultingPyLocation location, const nb::object &maybeIp) { PyMlirContextRef context = location->getContext(); + // Class level operation construction metadata. - std::string name = nb::cast(cls.attr("OPERATION_NAME")); // Operand and result segment specs are either none, which does no // variadic unpacking, or a list of ints with segment sizes, where each // element is either a positive number (typically 1 for a scalar) or -1 to // indicate that it is derived from the length of the same-indexed operand // or result (implying that it is a list at that position). 
- nb::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS"); - nb::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS"); - std::vector operandSegmentLengths; std::vector resultSegmentLengths; // Validate/determine region count. - auto opRegionSpec = nb::cast>(cls.attr("_ODS_REGIONS")); int opMinRegionCount = std::get<0>(opRegionSpec); bool opHasNoVariadicRegions = std::get<1>(opRegionSpec); if (!regions) { @@ -3236,6 +3239,33 @@ void mlir::python::populateIRCore(nb::module_ &m) { auto opViewClass = nb::class_(m, "OpView") .def(nb::init(), nb::arg("operation")) + .def( + "__init__", + [](PyOpView *self, std::string_view name, + std::tuple opRegionSpec, + nb::object operandSegmentSpecObj, + nb::object resultSegmentSpecObj, + std::optional resultTypeList, nb::list operandList, + std::optional attributes, + std::optional> successors, + std::optional regions, DefaultingPyLocation location, + const nb::object &maybeIp) { + new (self) PyOpView(PyOpView::buildGeneric( + name, opRegionSpec, operandSegmentSpecObj, + resultSegmentSpecObj, resultTypeList, operandList, + attributes, successors, regions, location, maybeIp)); + }, + nb::arg("name"), nb::arg("opRegionSpec"), + nb::arg("operandSegmentSpecObj").none() = nb::none(), + nb::arg("resultSegmentSpecObj").none() = nb::none(), + nb::arg("results").none() = nb::none(), + nb::arg("operands").none() = nb::none(), + nb::arg("attributes").none() = nb::none(), + nb::arg("successors").none() = nb::none(), + nb::arg("regions").none() = nb::none(), + nb::arg("loc").none() = nb::none(), + nb::arg("ip").none() = nb::none()) + .def_prop_ro("operation", &PyOpView::getOperationObject) .def_prop_ro("opview", [](nb::object self) { return self; }) .def( @@ -3250,9 +3280,26 @@ void mlir::python::populateIRCore(nb::module_ &m) { opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true); opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none(); opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none(); + // It is faster to pass 
the operation_name, ods_regions, and + // ods_operand_segments/ods_result_segments as arguments to the constructor, + // rather than to access them as attributes. opViewClass.attr("build_generic") = classmethod( - &PyOpView::buildGeneric, nb::arg("cls"), - nb::arg("results").none() = nb::none(), + [](nb::handle cls, std::optional resultTypeList, + nb::list operandList, std::optional attributes, + std::optional> successors, + std::optional regions, DefaultingPyLocation location, + const nb::object &maybeIp) { + std::string name = nb::cast(cls.attr("OPERATION_NAME")); + std::tuple opRegionSpec = + nb::cast>(cls.attr("_ODS_REGIONS")); + nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS"); + nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS"); + return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec, + resultSegmentSpec, resultTypeList, + operandList, attributes, successors, + regions, location, maybeIp); + }, + nb::arg("cls"), nb::arg("results").none() = nb::none(), nb::arg("operands").none() = nb::none(), nb::arg("attributes").none() = nb::none(), nb::arg("successors").none() = nb::none(), diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index d1fb4308d..2228b5523 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -685,7 +685,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// Creates an operation. See corresponding python docstring. 
static nanobind::object - create(const std::string &name, std::optional> results, + create(std::string_view name, std::optional> results, std::optional> operands, std::optional attributes, std::optional> successors, int regions, @@ -739,12 +739,16 @@ class PyOpView : public PyOperationBase { nanobind::object getOperationObject() { return operationObject; } - static nanobind::object buildGeneric( - const nanobind::object &cls, std::optional resultTypeList, - nanobind::list operandList, std::optional attributes, - std::optional> successors, - std::optional regions, DefaultingPyLocation location, - const nanobind::object &maybeIp); + static nanobind::object + buildGeneric(std::string_view name, std::tuple opRegionSpec, + nanobind::object operandSegmentSpecObj, + nanobind::object resultSegmentSpecObj, + std::optional resultTypeList, + nanobind::list operandList, + std::optional attributes, + std::optional> successors, + std::optional regions, DefaultingPyLocation location, + const nanobind::object &maybeIp); /// Construct an instance of a class deriving from OpView, bypassing its /// `__init__` method. The derived class will typically define a constructor From 28fe787bde3fb0b82be8c99687d57bf05d41204a Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Wed, 22 Jan 2025 09:26:44 -0500 Subject: [PATCH 832/915] [mlir:python] Construct PyOperation objects in-place on the Python heap. (#123813) Currently we make two memory allocations for each PyOperation: a Python object, and the PyOperation class itself. With some care we can allocate the PyOperation inline inside the Python object, saving us a malloc() call per object and perhaps improving cache locality. 
--- mlir/lib/Bindings/Python/IRCore.cpp | 28 ++++++++++++++++++++-------- mlir/lib/Bindings/Python/IRModule.h | 3 ++- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 1c9fb3d2a..c862ec84f 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1199,21 +1199,33 @@ PyOperation::~PyOperation() { } } +namespace { + +// Constructs a new object of type T in-place on the Python heap, returning a +// PyObjectRef to it, loosely analogous to std::make_shared(). +template +PyObjectRef makeObjectRef(Args &&...args) { + nb::handle type = nb::type(); + nb::object instance = nb::inst_alloc(type); + T *ptr = nb::inst_ptr(instance); + new (ptr) T(std::forward(args)...); + nb::inst_mark_ready(instance); + return PyObjectRef(ptr, std::move(instance)); +} + +} // namespace + PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, MlirOperation operation, nb::object parentKeepAlive) { // Create. - PyOperation *unownedOperation = - new PyOperation(std::move(contextRef), operation); - // Note that the default return value policy on cast is automatic_reference, - // which does not take ownership (delete will not be called). - // Just be explicit. 
- nb::object pyRef = nb::cast(unownedOperation, nb::rv_policy::take_ownership); - unownedOperation->handle = pyRef; + PyOperationRef unownedOperation = + makeObjectRef(std::move(contextRef), operation); + unownedOperation->handle = unownedOperation.getObject(); if (parentKeepAlive) { unownedOperation->parentKeepAlive = std::move(parentKeepAlive); } - return PyOperationRef(unownedOperation, std::move(pyRef)); + return unownedOperation; } PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 2228b5523..fd70ac7ac 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -705,8 +705,9 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// Clones this operation. nanobind::object clone(const nanobind::object &ip); -private: PyOperation(PyMlirContextRef contextRef, MlirOperation operation); + +private: static PyOperationRef createInstance(PyMlirContextRef contextRef, MlirOperation operation, nanobind::object parentKeepAlive); From cef6bf6eaf1a833a10645865e6482f8422d5d043 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 22 Jan 2025 14:33:19 -0800 Subject: [PATCH 833/915] [mlir] Add C and Python interface for file range (#123276) Plumbs through creating file ranges to C and Python. --- mlir/include/mlir-c/IR.h | 5 +++++ mlir/lib/Bindings/Python/IRCore.cpp | 15 +++++++++++++++ mlir/lib/CAPI/IR/IR.cpp | 9 +++++++++ 3 files changed, 29 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 0a515bbea..7d2fd89e8 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -256,6 +256,11 @@ mlirLocationFromAttribute(MlirAttribute attribute); MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColGet( MlirContext context, MlirStringRef filename, unsigned line, unsigned col); +/// Creates an File/Line/Column range location owned by the given context. 
+MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColRangeGet( + MlirContext context, MlirStringRef filename, unsigned start_line, + unsigned start_col, unsigned end_line, unsigned end_col); + /// Creates a call site location with a callee and a caller. MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index c862ec84f..738f1444b 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -50,6 +50,9 @@ static const char kContextGetCallSiteLocationDocstring[] = static const char kContextGetFileLocationDocstring[] = R"(Gets a Location representing a file, line and column)"; +static const char kContextGetFileRangeDocstring[] = + R"(Gets a Location representing a file, line and column range)"; + static const char kContextGetFusedLocationDocstring[] = R"(Gets a Location representing a fused location with optional metadata)"; @@ -2917,6 +2920,18 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("filename"), nb::arg("line"), nb::arg("col"), nb::arg("context").none() = nb::none(), kContextGetFileLocationDocstring) + .def_static( + "file", + [](std::string filename, int startLine, int startCol, int endLine, + int endCol, DefaultingPyMlirContext context) { + return PyLocation(context->getRef(), + mlirLocationFileLineColRangeGet( + context->get(), toMlirStringRef(filename), + startLine, startCol, endLine, endCol)); + }, + nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"), + nb::arg("end_line"), nb::arg("end_col"), + nb::arg("context").none() = nb::none(), kContextGetFileRangeDocstring) .def_static( "fused", [](const std::vector &pyLocations, diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 24dc88540..f27af0ca9 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -264,6 +264,15 @@ MlirLocation mlirLocationFileLineColGet(MlirContext context, 
FileLineColLoc::get(unwrap(context), unwrap(filename), line, col))); } +MlirLocation +mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename, + unsigned startLine, unsigned startCol, + unsigned endLine, unsigned endCol) { + return wrap( + Location(FileLineColRange::get(unwrap(context), unwrap(filename), + startLine, startCol, endLine, endCol))); +} + MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) { return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller)))); } From 2c01b1a0c8d77eddca9a99a649e5f1edaec836fe Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 24 Jan 2025 09:26:28 -0500 Subject: [PATCH 834/915] [mlir:python] Compute get_op_result_or_value in PyOpView's constructor. (#123953) This logic is in the critical path for constructing an operation from Python. It is faster to compute this in C++ than it is in Python, and it is a minor change to do this. This change also alters the API contract of _ods_common.get_op_results_or_values to avoid calling get_op_result_or_value on each element of a sequence, since the C++ code will now do this. Most of the diff here is simply reordering the code in IRCore.cpp. 
--- mlir/lib/Bindings/Python/IRCore.cpp | 432 ++++++++++++----------- mlir/lib/Bindings/Python/IRModule.h | 2 +- mlir/python/mlir/dialects/_ods_common.py | 7 +- 3 files changed, 239 insertions(+), 202 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 738f1444b..8e351cb22 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1481,12 +1481,11 @@ static void maybeInsertOperation(PyOperationRef &op, nb::object PyOperation::create(std::string_view name, std::optional> results, - std::optional> operands, + llvm::ArrayRef operands, std::optional attributes, std::optional> successors, int regions, DefaultingPyLocation location, const nb::object &maybeIp, bool inferType) { - llvm::SmallVector mlirOperands; llvm::SmallVector mlirResults; llvm::SmallVector mlirSuccessors; llvm::SmallVector, 4> mlirAttributes; @@ -1495,16 +1494,6 @@ nb::object PyOperation::create(std::string_view name, if (regions < 0) throw nb::value_error("number of regions must be >= 0"); - // Unpack/validate operands. - if (operands) { - mlirOperands.reserve(operands->size()); - for (PyValue *operand : *operands) { - if (!operand) - throw nb::value_error("operand value cannot be None"); - mlirOperands.push_back(operand->get()); - } - } - // Unpack/validate results. if (results) { mlirResults.reserve(results->size()); @@ -1562,9 +1551,8 @@ nb::object PyOperation::create(std::string_view name, // point, exceptions cannot be thrown or else the state will leak. 
MlirOperationState state = mlirOperationStateGet(toMlirStringRef(name), location); - if (!mlirOperands.empty()) - mlirOperationStateAddOperands(&state, mlirOperands.size(), - mlirOperands.data()); + if (!operands.empty()) + mlirOperationStateAddOperands(&state, operands.size(), operands.data()); state.enableResultTypeInference = inferType; if (!mlirResults.empty()) mlirOperationStateAddResults(&state, mlirResults.size(), @@ -1632,6 +1620,143 @@ void PyOperation::erase() { mlirOperationDestroy(operation); } +namespace { +/// CRTP base class for Python MLIR values that subclass Value and should be +/// castable from it. The value hierarchy is one level deep and is not supposed +/// to accommodate other levels unless core MLIR changes. +template +class PyConcreteValue : public PyValue { +public: + // Derived classes must define statics for: + // IsAFunctionTy isaFunction + // const char *pyClassName + // and redefine bindDerived. + using ClassTy = nb::class_; + using IsAFunctionTy = bool (*)(MlirValue); + + PyConcreteValue() = default; + PyConcreteValue(PyOperationRef operationRef, MlirValue value) + : PyValue(operationRef, value) {} + PyConcreteValue(PyValue &orig) + : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} + + /// Attempts to cast the original value to the derived type and throws on + /// type mismatches. + static MlirValue castFrom(PyValue &orig) { + if (!DerivedTy::isaFunction(orig.get())) { + auto origRepr = nb::cast(nb::repr(nb::cast(orig))); + throw nb::value_error((Twine("Cannot cast value to ") + + DerivedTy::pyClassName + " (from " + origRepr + + ")") + .str() + .c_str()); + } + return orig.get(); + } + + /// Binds the Python module objects to functions of this class. 
+ static void bind(nb::module_ &m) { + auto cls = ClassTy(m, DerivedTy::pyClassName); + cls.def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value")); + cls.def_static( + "isinstance", + [](PyValue &otherValue) -> bool { + return DerivedTy::isaFunction(otherValue); + }, + nb::arg("other_value")); + cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](DerivedTy &self) { return self.maybeDownCast(); }); + DerivedTy::bindDerived(cls); + } + + /// Implemented by derived classes to add methods to the Python subclass. + static void bindDerived(ClassTy &m) {} +}; + +} // namespace + +/// Python wrapper for MlirOpResult. +class PyOpResult : public PyConcreteValue { +public: + static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; + static constexpr const char *pyClassName = "OpResult"; + using PyConcreteValue::PyConcreteValue; + + static void bindDerived(ClassTy &c) { + c.def_prop_ro("owner", [](PyOpResult &self) { + assert( + mlirOperationEqual(self.getParentOperation()->get(), + mlirOpResultGetOwner(self.get())) && + "expected the owner of the value in Python to match that in the IR"); + return self.getParentOperation().getObject(); + }); + c.def_prop_ro("result_number", [](PyOpResult &self) { + return mlirOpResultGetResultNumber(self.get()); + }); + } +}; + +/// Returns the list of types of the values held by container. +template +static std::vector getValueTypes(Container &container, + PyMlirContextRef &context) { + std::vector result; + result.reserve(container.size()); + for (int i = 0, e = container.size(); i < e; ++i) { + result.push_back(mlirValueGetType(container.getElement(i).get())); + } + return result; +} + +/// A list of operation results. Internally, these are stored as consecutive +/// elements, random access is cheap. The (returned) result list is associated +/// with the operation whose results these are, and thus extends the lifetime of +/// this operation. 
+class PyOpResultList : public Sliceable { +public: + static constexpr const char *pyClassName = "OpResultList"; + using SliceableT = Sliceable; + + PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirOperationGetNumResults(operation->get()) + : length, + step), + operation(std::move(operation)) {} + + static void bindDerived(ClassTy &c) { + c.def_prop_ro("types", [](PyOpResultList &self) { + return getValueTypes(self, self.operation->getContext()); + }); + c.def_prop_ro("owner", [](PyOpResultList &self) { + return self.operation->createOpView(); + }); + } + + PyOperationRef &getOperation() { return operation; } + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + intptr_t getRawNumElements() { + operation->checkValid(); + return mlirOperationGetNumResults(operation->get()); + } + + PyOpResult getRawElement(intptr_t index) { + PyValue value(operation, mlirOperationGetResult(operation->get(), index)); + return PyOpResult(value); + } + + PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { + return PyOpResultList(operation, startIndex, length, step); + } + + PyOperationRef operation; +}; + //------------------------------------------------------------------------------ // PyOpView //------------------------------------------------------------------------------ @@ -1733,6 +1858,40 @@ static void populateResultTypes(StringRef name, nb::list resultTypeList, } } +static MlirValue getUniqueResult(MlirOperation operation) { + auto numResults = mlirOperationGetNumResults(operation); + if (numResults != 1) { + auto name = mlirIdentifierStr(mlirOperationGetName(operation)); + throw nb::value_error((Twine("Cannot call .result on operation ") + + StringRef(name.data, name.length) + " which has " + + Twine(numResults) + + " results (it is only valid for operations with a " + "single result)") + 
.str() + .c_str()); + } + return mlirOperationGetResult(operation, 0); +} + +static MlirValue getOpResultOrValue(nb::handle operand) { + if (operand.is_none()) { + throw nb::value_error("contained a None item"); + } + PyOperationBase *op; + if (nb::try_cast(operand, op)) { + return getUniqueResult(op->getOperation()); + } + PyOpResultList *opResultList; + if (nb::try_cast(operand, opResultList)) { + return getUniqueResult(opResultList->getOperation()->get()); + } + PyValue *value; + if (nb::try_cast(operand, value)) { + return value->get(); + } + throw nb::value_error("is not a Value"); +} + nb::object PyOpView::buildGeneric( std::string_view name, std::tuple opRegionSpec, nb::object operandSegmentSpecObj, nb::object resultSegmentSpecObj, @@ -1783,16 +1942,14 @@ nb::object PyOpView::buildGeneric( } // Unpack operands. - std::vector operands; + llvm::SmallVector operands; operands.reserve(operands.size()); if (operandSegmentSpecObj.is_none()) { // Non-sized operand unpacking. for (const auto &it : llvm::enumerate(operandList)) { try { - operands.push_back(nb::cast(it.value())); - if (!operands.back()) - throw nb::cast_error(); - } catch (nb::cast_error &err) { + operands.push_back(getOpResultOrValue(it.value())); + } catch (nb::builtin_exception &err) { throw nb::value_error((llvm::Twine("Operand ") + llvm::Twine(it.index()) + " of operation \"" + name + "\" must be a Value (" + err.what() + ")") @@ -1818,29 +1975,31 @@ nb::object PyOpView::buildGeneric( int segmentSpec = std::get<1>(it.value()); if (segmentSpec == 1 || segmentSpec == 0) { // Unpack unary element. - try { - auto *operandValue = nb::cast(std::get<0>(it.value())); - if (operandValue) { - operands.push_back(operandValue); - operandSegmentLengths.push_back(1); - } else if (segmentSpec == 0) { - // Allowed to be optional. 
- operandSegmentLengths.push_back(0); - } else { - throw nb::value_error( - (llvm::Twine("Operand ") + llvm::Twine(it.index()) + - " of operation \"" + name + - "\" must be a Value (was None and operand is not optional)") - .str() - .c_str()); + auto &operand = std::get<0>(it.value()); + if (!operand.is_none()) { + try { + + operands.push_back(getOpResultOrValue(operand)); + } catch (nb::builtin_exception &err) { + throw nb::value_error((llvm::Twine("Operand ") + + llvm::Twine(it.index()) + + " of operation \"" + name + + "\" must be a Value (" + err.what() + ")") + .str() + .c_str()); } - } catch (nb::cast_error &err) { - throw nb::value_error((llvm::Twine("Operand ") + - llvm::Twine(it.index()) + " of operation \"" + - name + "\" must be a Value (" + err.what() + - ")") - .str() - .c_str()); + + operandSegmentLengths.push_back(1); + } else if (segmentSpec == 0) { + // Allowed to be optional. + operandSegmentLengths.push_back(0); + } else { + throw nb::value_error( + (llvm::Twine("Operand ") + llvm::Twine(it.index()) + + " of operation \"" + name + + "\" must be a Value (was None and operand is not optional)") + .str() + .c_str()); } } else if (segmentSpec == -1) { // Unpack sequence by appending. @@ -1852,10 +2011,7 @@ nb::object PyOpView::buildGeneric( // Unpack the list. auto segment = nb::cast(std::get<0>(it.value())); for (nb::handle segmentItem : segment) { - operands.push_back(nb::cast(segmentItem)); - if (!operands.back()) { - throw nb::type_error("contained a None item"); - } + operands.push_back(getOpResultOrValue(segmentItem)); } operandSegmentLengths.push_back(nb::len(segment)); } @@ -2269,57 +2425,6 @@ void PySymbolTable::walkSymbolTables(PyOperationBase &from, } namespace { -/// CRTP base class for Python MLIR values that subclass Value and should be -/// castable from it. The value hierarchy is one level deep and is not supposed -/// to accommodate other levels unless core MLIR changes. 
-template -class PyConcreteValue : public PyValue { -public: - // Derived classes must define statics for: - // IsAFunctionTy isaFunction - // const char *pyClassName - // and redefine bindDerived. - using ClassTy = nb::class_; - using IsAFunctionTy = bool (*)(MlirValue); - - PyConcreteValue() = default; - PyConcreteValue(PyOperationRef operationRef, MlirValue value) - : PyValue(operationRef, value) {} - PyConcreteValue(PyValue &orig) - : PyConcreteValue(orig.getParentOperation(), castFrom(orig)) {} - - /// Attempts to cast the original value to the derived type and throws on - /// type mismatches. - static MlirValue castFrom(PyValue &orig) { - if (!DerivedTy::isaFunction(orig.get())) { - auto origRepr = nb::cast(nb::repr(nb::cast(orig))); - throw nb::value_error((Twine("Cannot cast value to ") + - DerivedTy::pyClassName + " (from " + origRepr + - ")") - .str() - .c_str()); - } - return orig.get(); - } - - /// Binds the Python module objects to functions of this class. - static void bind(nb::module_ &m) { - auto cls = ClassTy(m, DerivedTy::pyClassName); - cls.def(nb::init(), nb::keep_alive<0, 1>(), nb::arg("value")); - cls.def_static( - "isinstance", - [](PyValue &otherValue) -> bool { - return DerivedTy::isaFunction(otherValue); - }, - nb::arg("other_value")); - cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, - [](DerivedTy &self) { return self.maybeDownCast(); }); - DerivedTy::bindDerived(cls); - } - - /// Implemented by derived classes to add methods to the Python subclass. - static void bindDerived(ClassTy &m) {} -}; /// Python wrapper for MlirBlockArgument. class PyBlockArgument : public PyConcreteValue { @@ -2345,39 +2450,6 @@ class PyBlockArgument : public PyConcreteValue { } }; -/// Python wrapper for MlirOpResult. 
-class PyOpResult : public PyConcreteValue { -public: - static constexpr IsAFunctionTy isaFunction = mlirValueIsAOpResult; - static constexpr const char *pyClassName = "OpResult"; - using PyConcreteValue::PyConcreteValue; - - static void bindDerived(ClassTy &c) { - c.def_prop_ro("owner", [](PyOpResult &self) { - assert( - mlirOperationEqual(self.getParentOperation()->get(), - mlirOpResultGetOwner(self.get())) && - "expected the owner of the value in Python to match that in the IR"); - return self.getParentOperation().getObject(); - }); - c.def_prop_ro("result_number", [](PyOpResult &self) { - return mlirOpResultGetResultNumber(self.get()); - }); - } -}; - -/// Returns the list of types of the values held by container. -template -static std::vector getValueTypes(Container &container, - PyMlirContextRef &context) { - std::vector result; - result.reserve(container.size()); - for (int i = 0, e = container.size(); i < e; ++i) { - result.push_back(mlirValueGetType(container.getElement(i).get())); - } - return result; -} - /// A list of block arguments. Internally, these are stored as consecutive /// elements, random access is cheap. The argument list is associated with the /// operation that contains the block (detached blocks are not allowed in @@ -2484,53 +2556,6 @@ class PyOpOperandList : public Sliceable { PyOperationRef operation; }; -/// A list of operation results. Internally, these are stored as consecutive -/// elements, random access is cheap. The (returned) result list is associated -/// with the operation whose results these are, and thus extends the lifetime of -/// this operation. -class PyOpResultList : public Sliceable { -public: - static constexpr const char *pyClassName = "OpResultList"; - using SliceableT = Sliceable; - - PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, - intptr_t length = -1, intptr_t step = 1) - : Sliceable(startIndex, - length == -1 ? 
mlirOperationGetNumResults(operation->get()) - : length, - step), - operation(std::move(operation)) {} - - static void bindDerived(ClassTy &c) { - c.def_prop_ro("types", [](PyOpResultList &self) { - return getValueTypes(self, self.operation->getContext()); - }); - c.def_prop_ro("owner", [](PyOpResultList &self) { - return self.operation->createOpView(); - }); - } - -private: - /// Give the parent CRTP class access to hook implementations below. - friend class Sliceable; - - intptr_t getRawNumElements() { - operation->checkValid(); - return mlirOperationGetNumResults(operation->get()); - } - - PyOpResult getRawElement(intptr_t index) { - PyValue value(operation, mlirOperationGetResult(operation->get(), index)); - return PyOpResult(value); - } - - PyOpResultList slice(intptr_t startIndex, intptr_t length, intptr_t step) { - return PyOpResultList(operation, startIndex, length, step); - } - - PyOperationRef operation; -}; - /// A list of operation successors. Internally, these are stored as consecutive /// elements, random access is cheap. 
The (returned) successor list is /// associated with the operation whose successors these are, and thus extends @@ -3123,20 +3148,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { "result", [](PyOperationBase &self) { auto &operation = self.getOperation(); - auto numResults = mlirOperationGetNumResults(operation); - if (numResults != 1) { - auto name = mlirIdentifierStr(mlirOperationGetName(operation)); - throw nb::value_error( - (Twine("Cannot call .result on operation ") + - StringRef(name.data, name.length) + " which has " + - Twine(numResults) + - " results (it is only valid for operations with a " - "single result)") - .str() - .c_str()); - } - return PyOpResult(operation.getRef(), - mlirOperationGetResult(operation, 0)) + return PyOpResult(operation.getRef(), getUniqueResult(operation)) .maybeDownCast(); }, "Shortcut to get an op result if it has only one (throws an error " @@ -3233,14 +3245,36 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("walk_order") = MlirWalkPostOrder); nb::class_(m, "Operation") - .def_static("create", &PyOperation::create, nb::arg("name"), - nb::arg("results").none() = nb::none(), - nb::arg("operands").none() = nb::none(), - nb::arg("attributes").none() = nb::none(), - nb::arg("successors").none() = nb::none(), - nb::arg("regions") = 0, nb::arg("loc").none() = nb::none(), - nb::arg("ip").none() = nb::none(), - nb::arg("infer_type") = false, kOperationCreateDocstring) + .def_static( + "create", + [](std::string_view name, + std::optional> results, + std::optional> operands, + std::optional attributes, + std::optional> successors, int regions, + DefaultingPyLocation location, const nb::object &maybeIp, + bool inferType) { + // Unpack/validate operands. 
+ llvm::SmallVector mlirOperands; + if (operands) { + mlirOperands.reserve(operands->size()); + for (PyValue *operand : *operands) { + if (!operand) + throw nb::value_error("operand value cannot be None"); + mlirOperands.push_back(operand->get()); + } + } + + return PyOperation::create(name, results, mlirOperands, attributes, + successors, regions, location, maybeIp, + inferType); + }, + nb::arg("name"), nb::arg("results").none() = nb::none(), + nb::arg("operands").none() = nb::none(), + nb::arg("attributes").none() = nb::none(), + nb::arg("successors").none() = nb::none(), nb::arg("regions") = 0, + nb::arg("loc").none() = nb::none(), nb::arg("ip").none() = nb::none(), + nb::arg("infer_type") = false, kOperationCreateDocstring) .def_static( "parse", [](const std::string &sourceStr, const std::string &sourceName, diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index fd70ac7ac..dd6e7ef91 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -686,7 +686,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject { /// Creates an operation. See corresponding python docstring. 
static nanobind::object create(std::string_view name, std::optional> results, - std::optional> operands, + llvm::ArrayRef operands, std::optional attributes, std::optional> successors, int regions, DefaultingPyLocation location, const nanobind::object &ip, diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 5b67ab03d..d3dbdc604 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -115,7 +115,10 @@ def get_op_results_or_values( _cext.ir.Operation, _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]], ] -) -> _Union[_Sequence[_cext.ir.Value], _cext.ir.OpResultList]: +) -> _Union[ + _Sequence[_Union[_cext.ir.OpView, _cext.ir.Operation, _cext.ir.Value]], + _cext.ir.OpResultList, +]: """Returns the given sequence of values or the results of the given op. This is useful to implement op constructors so that they can take other ops as @@ -127,7 +130,7 @@ def get_op_results_or_values( elif isinstance(arg, _cext.ir.Operation): return arg.results else: - return [get_op_result_or_value(element) for element in arg] + return arg def get_op_result_or_op_results( From e9e8aa2ab12ba0e20608e91a719f98b499c876ba Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Tue, 28 Jan 2025 11:02:26 -0600 Subject: [PATCH 835/915] [mlir][python] implement GenericOp bindings (#124496) --- mlir/python/mlir/dialects/linalg/__init__.py | 45 ++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index 8fb1227ee..742262a9c 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -10,6 +10,7 @@ # DSL -> YAML -> tblgen -> pytblgen -> build/.../_linalg_ops_gen.py. 
from .._linalg_ops_gen import * from .._linalg_enum_gen import * +from .._linalg_enum_gen import _iteratortypeenum # These are the ground truth functions defined as: # ``` @@ -58,6 +59,7 @@ from ...ir import * from .._ods_common import get_op_result_or_value as _get_op_result_or_value +from ...extras.meta import region_op def transpose( @@ -102,3 +104,46 @@ def broadcast( ) fill_builtin_region(op.operation) return op + + +@register_attribute_builder("IteratorTypeArrayAttr") +def _IteratorTypeArrayAttr(x, context): + return ArrayAttr.get([_iteratortypeenum(v, context) for v in x]) + + +# The underscore is needed here so that there's no collision with opdsl generation. +class GenericOp_(GenericOp): + def __init__( + self, + inputs, + outputs, + indexing_maps, + iterator_types, + *, + doc=None, + library_call=None, + loc=None, + ip=None, + ): + result_types = [] + if isinstance(outputs[0].type, RankedTensorType): + result_types = [o.type for o in outputs] + + super().__init__( + result_types, + inputs, + outputs, + indexing_maps, + iterator_types, + doc=doc, + library_call=library_call, + loc=loc, + ip=ip, + ) + element_types = [i.type.element_type for i in inputs] + [ + o.type.element_type for o in outputs + ] + self.regions[0].blocks.append(*element_types) + + +generic = region_op(GenericOp_, terminator=YieldOp) From 965673465a596efcd8a2648da6ebc939ad528450 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Tue, 28 Jan 2025 18:31:58 +0100 Subject: [PATCH 836/915] [mlir][python] allow DenseIntElementsAttr for index type (#118947) Model the `IndexType` as `uint64_t` when converting to a python integer. With the python bindings, ```python DenseIntElementsAttr(op.attributes["attr"]) ``` used to `assert` when `attr` had `index` type like `dense<[1, 2, 3, 4]> : vector<4xindex>`. 
--------- Co-authored-by: Christopher McGirr Co-authored-by: Tiago Trevisan Jost --- mlir/include/mlir-c/BuiltinAttributes.h | 2 ++ mlir/lib/Bindings/Python/IRAttributes.cpp | 10 ++++++++-- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 3 +++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 7c8c84e55..1d0edf9ea 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -556,6 +556,8 @@ MLIR_CAPI_EXPORTED int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED uint64_t +mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED double diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 7bc21a31c..dcd098c20 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -1372,13 +1372,19 @@ class PyDenseIntElementsAttribute MlirType type = mlirAttributeGetType(*this); type = mlirShapedTypeGetElementType(type); - assert(mlirTypeIsAInteger(type) && - "expected integer element type in dense int elements attribute"); + // Index type can also appear as a DenseIntElementsAttr and therefore can be + // casted to integer. + assert(mlirTypeIsAInteger(type) || + mlirTypeIsAIndex(type) && "expected integer/index element type in " + "dense int elements attribute"); // Dispatch element extraction to an appropriate C function based on the // elemental type of the attribute. nb::int_ is implicitly constructible // from any C++ integral type and handles bitwidth correctly. // TODO: consider caching the type properties in the constructor to avoid // querying them on each element access. 
+ if (mlirTypeIsAIndex(type)) { + return mlirDenseElementsAttrGetIndexValue(*this, pos); + } unsigned width = mlirIntegerTypeGetWidth(type); bool isUnsigned = mlirIntegerTypeIsUnsigned(type); if (isUnsigned) { diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 11d1ade55..8d57ab6b5 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -758,6 +758,9 @@ int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; } +uint64_t mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr)).getValues()[pos]; +} float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; } From 092254ace78fcf3a5de488377404d3ceab2822dc Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Tue, 28 Jan 2025 18:35:50 +0100 Subject: [PATCH 837/915] Revert "[mlir][python] allow DenseIntElementsAttr for index type (#118947)" This reverts commit 965673465a596efcd8a2648da6ebc939ad528450. 
--- mlir/include/mlir-c/BuiltinAttributes.h | 2 -- mlir/lib/Bindings/Python/IRAttributes.cpp | 10 ++-------- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 3 --- 3 files changed, 2 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 1d0edf9ea..7c8c84e55 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -556,8 +556,6 @@ MLIR_CAPI_EXPORTED int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos); -MLIR_CAPI_EXPORTED uint64_t -mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED double diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index dcd098c20..7bc21a31c 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -1372,19 +1372,13 @@ class PyDenseIntElementsAttribute MlirType type = mlirAttributeGetType(*this); type = mlirShapedTypeGetElementType(type); - // Index type can also appear as a DenseIntElementsAttr and therefore can be - // casted to integer. - assert(mlirTypeIsAInteger(type) || - mlirTypeIsAIndex(type) && "expected integer/index element type in " - "dense int elements attribute"); + assert(mlirTypeIsAInteger(type) && + "expected integer element type in dense int elements attribute"); // Dispatch element extraction to an appropriate C function based on the // elemental type of the attribute. nb::int_ is implicitly constructible // from any C++ integral type and handles bitwidth correctly. // TODO: consider caching the type properties in the constructor to avoid // querying them on each element access. 
- if (mlirTypeIsAIndex(type)) { - return mlirDenseElementsAttrGetIndexValue(*this, pos); - } unsigned width = mlirIntegerTypeGetWidth(type); bool isUnsigned = mlirIntegerTypeIsUnsigned(type); if (isUnsigned) { diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 8d57ab6b5..11d1ade55 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -758,9 +758,6 @@ int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; } -uint64_t mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos) { - return llvm::cast(unwrap(attr)).getValues()[pos]; -} float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; } From 0fb48b4465821e6068062f8281d9739e1a1c53a1 Mon Sep 17 00:00:00 2001 From: Fabian Tschopp Date: Wed, 29 Jan 2025 00:56:00 +0100 Subject: [PATCH 838/915] [MLIR] Fix thread safety of the deleter in PyDenseResourceElementsAttribute (#124832) In general, `PyDenseResourceElementsAttribute` can get deleted at any time and any thread, where unlike the `getFromBuffer` call, the Python interpreter may not be initialized and the GIL may not be held. This PR fixes segfaults caused by `PyBuffer_Release` when the GIL is not being held by the thread calling the deleter. --- mlir/lib/Bindings/Python/IRAttributes.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 7bc21a31c..142b6eca1 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -1468,7 +1468,10 @@ class PyDenseResourceElementsAttribute // The userData is a Py_buffer* that the deleter owns. 
auto deleter = [](void *userData, const void *data, size_t size, size_t align) { + if (!Py_IsInitialized()) + Py_Initialize(); Py_buffer *ownedView = static_cast(userData); + nb::gil_scoped_acquire gil; PyBuffer_Release(ownedView); delete ownedView; }; From bba6858816bf5c7cf41438bcf5e511d0935a1c87 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Wed, 29 Jan 2025 09:14:37 +0100 Subject: [PATCH 839/915] Reapply "[mlir][python] allow DenseIntElementsAttr for index type (#118947)" (#124804) This reapplies #118947 and adapts to nanobind. --- mlir/include/mlir-c/BuiltinAttributes.h | 2 ++ mlir/lib/Bindings/Python/IRAttributes.cpp | 10 ++++++++-- mlir/lib/CAPI/IR/BuiltinAttributes.cpp | 3 +++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinAttributes.h b/mlir/include/mlir-c/BuiltinAttributes.h index 7c8c84e55..1d0edf9ea 100644 --- a/mlir/include/mlir-c/BuiltinAttributes.h +++ b/mlir/include/mlir-c/BuiltinAttributes.h @@ -556,6 +556,8 @@ MLIR_CAPI_EXPORTED int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos); +MLIR_CAPI_EXPORTED uint64_t +mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos); MLIR_CAPI_EXPORTED double diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 142b6eca1..12725a0ed 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -1372,13 +1372,19 @@ class PyDenseIntElementsAttribute MlirType type = mlirAttributeGetType(*this); type = mlirShapedTypeGetElementType(type); - assert(mlirTypeIsAInteger(type) && - "expected integer element type in dense int elements attribute"); + // Index type can also appear as a DenseIntElementsAttr and therefore can be + // casted to integer. 
+ assert(mlirTypeIsAInteger(type) || + mlirTypeIsAIndex(type) && "expected integer/index element type in " + "dense int elements attribute"); // Dispatch element extraction to an appropriate C function based on the // elemental type of the attribute. nb::int_ is implicitly constructible // from any C++ integral type and handles bitwidth correctly. // TODO: consider caching the type properties in the constructor to avoid // querying them on each element access. + if (mlirTypeIsAIndex(type)) { + return nb::int_(mlirDenseElementsAttrGetIndexValue(*this, pos)); + } unsigned width = mlirIntegerTypeGetWidth(type); bool isUnsigned = mlirIntegerTypeIsUnsigned(type); if (isUnsigned) { diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 11d1ade55..8d57ab6b5 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -758,6 +758,9 @@ int64_t mlirDenseElementsAttrGetInt64Value(MlirAttribute attr, intptr_t pos) { uint64_t mlirDenseElementsAttrGetUInt64Value(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; } +uint64_t mlirDenseElementsAttrGetIndexValue(MlirAttribute attr, intptr_t pos) { + return llvm::cast(unwrap(attr)).getValues()[pos]; +} float mlirDenseElementsAttrGetFloatValue(MlirAttribute attr, intptr_t pos) { return llvm::cast(unwrap(attr)).getValues()[pos]; } From bc89d63ba62e01885cb8abc36f455833a3020c64 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Thu, 30 Jan 2025 11:43:22 -0600 Subject: [PATCH 840/915] [mlir][llvmir] expose Type(To/From)LLVMIRTranslator C API (#124864) --- mlir/include/mlir-c/Dialect/LLVM.h | 3 ++ mlir/include/mlir-c/Target/LLVMIR.h | 43 +++++++++++++++++++++++++ mlir/lib/CAPI/Dialect/LLVM.cpp | 4 +++ mlir/lib/CAPI/Target/LLVMIR.cpp | 49 +++++++++++++++++++++++++++-- 4 files changed, 96 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index 
26c414075..65b14254e 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -52,6 +52,9 @@ MLIR_CAPI_EXPORTED intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type); MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetInput(MlirType type, intptr_t pos); +/// Returns the return type of the function type. +MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type); + /// Returns `true` if the type is an LLVM dialect struct type. MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type); diff --git a/mlir/include/mlir-c/Target/LLVMIR.h b/mlir/include/mlir-c/Target/LLVMIR.h index effa74b90..b5f948961 100644 --- a/mlir/include/mlir-c/Target/LLVMIR.h +++ b/mlir/include/mlir-c/Target/LLVMIR.h @@ -16,6 +16,7 @@ #include "mlir-c/IR.h" #include "mlir-c/Support.h" +#include "llvm-c/Core.h" #include "llvm-c/Support.h" #ifdef __cplusplus @@ -32,6 +33,48 @@ extern "C" { MLIR_CAPI_EXPORTED LLVMModuleRef mlirTranslateModuleToLLVMIR(MlirOperation module, LLVMContextRef context); +struct MlirTypeFromLLVMIRTranslator { + void *ptr; +}; + +typedef struct MlirTypeFromLLVMIRTranslator MlirTypeFromLLVMIRTranslator; + +/// Create an LLVM::TypeFromLLVMIRTranslator and transfer ownership to the +/// caller. +MLIR_CAPI_EXPORTED MlirTypeFromLLVMIRTranslator +mlirTypeFromLLVMIRTranslatorCreate(MlirContext ctx); + +/// Takes an LLVM::TypeFromLLVMIRTranslator owned by the caller and destroys it. +/// It is the responsibility of the user to only pass an +/// LLVM::TypeFromLLVMIRTranslator class. +MLIR_CAPI_EXPORTED void +mlirTypeFromLLVMIRTranslatorDestroy(MlirTypeFromLLVMIRTranslator translator); + +/// Translates the given LLVM IR type to the MLIR LLVM dialect. 
+MLIR_CAPI_EXPORTED MlirType mlirTypeFromLLVMIRTranslatorTranslateType( + MlirTypeFromLLVMIRTranslator translator, LLVMTypeRef llvmType); + +struct MlirTypeToLLVMIRTranslator { + void *ptr; +}; + +typedef struct MlirTypeToLLVMIRTranslator MlirTypeToLLVMIRTranslator; + +/// Create an LLVM::TypeToLLVMIRTranslator and transfer ownership to the +/// caller. +MLIR_CAPI_EXPORTED MlirTypeToLLVMIRTranslator +mlirTypeToLLVMIRTranslatorCreate(LLVMContextRef ctx); + +/// Takes an LLVM::TypeToLLVMIRTranslator owned by the caller and destroys it. +/// It is the responsibility of the user to only pass an +/// LLVM::TypeToLLVMIRTranslator class. +MLIR_CAPI_EXPORTED void +mlirTypeToLLVMIRTranslatorDestroy(MlirTypeToLLVMIRTranslator translator); + +/// Translates the given MLIR LLVM dialect to the LLVM IR type. +MLIR_CAPI_EXPORTED LLVMTypeRef mlirTypeToLLVMIRTranslatorTranslateType( + MlirTypeToLLVMIRTranslator translator, MlirType mlirType); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index da450dd3f..69c804b76 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -65,6 +65,10 @@ MlirType mlirLLVMFunctionTypeGetInput(MlirType type, intptr_t pos) { .getParamType(static_cast(pos))); } +MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type) { + return wrap(llvm::cast(unwrap(type)).getReturnType()); +} + bool mlirTypeIsALLVMStructType(MlirType type) { return isa(unwrap(type)); } diff --git a/mlir/lib/CAPI/Target/LLVMIR.cpp b/mlir/lib/CAPI/Target/LLVMIR.cpp index dc798372b..5e2bba8be 100644 --- a/mlir/lib/CAPI/Target/LLVMIR.cpp +++ b/mlir/lib/CAPI/Target/LLVMIR.cpp @@ -8,16 +8,15 @@ //===----------------------------------------------------------------------===// #include "mlir-c/Target/LLVMIR.h" -#include "llvm-c/Support.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" -#include +#include "llvm/IR/Type.h" #include "mlir/CAPI/IR.h" -#include "mlir/CAPI/Support.h" #include 
"mlir/CAPI/Wrap.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "mlir/Target/LLVMIR/TypeFromLLVM.h" using namespace mlir; @@ -34,3 +33,47 @@ LLVMModuleRef mlirTranslateModuleToLLVMIR(MlirOperation module, return moduleRef; } + +DEFINE_C_API_PTR_METHODS(MlirTypeFromLLVMIRTranslator, + mlir::LLVM::TypeFromLLVMIRTranslator); + +MlirTypeFromLLVMIRTranslator +mlirTypeFromLLVMIRTranslatorCreate(MlirContext ctx) { + MLIRContext *context = unwrap(ctx); + auto *translator = new LLVM::TypeFromLLVMIRTranslator(*context); + return wrap(translator); +} + +void mlirTypeFromLLVMIRTranslatorDestroy( + MlirTypeFromLLVMIRTranslator translator) { + delete static_cast(unwrap(translator)); +} + +MlirType mlirTypeFromLLVMIRTranslatorTranslateType( + MlirTypeFromLLVMIRTranslator translator, LLVMTypeRef llvmType) { + LLVM::TypeFromLLVMIRTranslator *translator_ = unwrap(translator); + mlir::Type type = translator_->translateType(llvm::unwrap(llvmType)); + return wrap(type); +} + +DEFINE_C_API_PTR_METHODS(MlirTypeToLLVMIRTranslator, + mlir::LLVM::TypeToLLVMIRTranslator); + +MlirTypeToLLVMIRTranslator +mlirTypeToLLVMIRTranslatorCreate(LLVMContextRef ctx) { + llvm::LLVMContext *context = llvm::unwrap(ctx); + auto *translator = new LLVM::TypeToLLVMIRTranslator(*context); + return wrap(translator); +} + +void mlirTypeToLLVMIRTranslatorDestroy(MlirTypeToLLVMIRTranslator translator) { + delete static_cast(unwrap(translator)); +} + +LLVMTypeRef +mlirTypeToLLVMIRTranslatorTranslateType(MlirTypeToLLVMIRTranslator translator, + MlirType mlirType) { + LLVM::TypeToLLVMIRTranslator *translator_ = unwrap(translator); + llvm::Type *type = translator_->translateType(unwrap(mlirType)); + return llvm::wrap(type); +} From 24cf42c6ab0da6178fe7b1cdf66ef43e8f4c1adf Mon Sep 17 00:00:00 2001 From: Kazu Hirata Date: Thu, 30 Jan 2025 09:50:33 -0800 Subject: [PATCH 841/915] [mlir] Fix warnings This patch fixes: mlir/lib/CAPI/Target/LLVMIR.cpp:38:63: error: extra ';' outside of a function is 
incompatible with C++98 [-Werror,-Wc++98-compat-extra-semi] mlir/lib/CAPI/Target/LLVMIR.cpp:60:61: error: extra ';' outside of a function is incompatible with C++98 [-Werror,-Wc++98-compat-extra-semi] --- mlir/lib/CAPI/Target/LLVMIR.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/CAPI/Target/LLVMIR.cpp b/mlir/lib/CAPI/Target/LLVMIR.cpp index 5e2bba8be..1c1912aec 100644 --- a/mlir/lib/CAPI/Target/LLVMIR.cpp +++ b/mlir/lib/CAPI/Target/LLVMIR.cpp @@ -35,7 +35,7 @@ LLVMModuleRef mlirTranslateModuleToLLVMIR(MlirOperation module, } DEFINE_C_API_PTR_METHODS(MlirTypeFromLLVMIRTranslator, - mlir::LLVM::TypeFromLLVMIRTranslator); + mlir::LLVM::TypeFromLLVMIRTranslator) MlirTypeFromLLVMIRTranslator mlirTypeFromLLVMIRTranslatorCreate(MlirContext ctx) { @@ -57,7 +57,7 @@ MlirType mlirTypeFromLLVMIRTranslatorTranslateType( } DEFINE_C_API_PTR_METHODS(MlirTypeToLLVMIRTranslator, - mlir::LLVM::TypeToLLVMIRTranslator); + mlir::LLVM::TypeToLLVMIRTranslator) MlirTypeToLLVMIRTranslator mlirTypeToLLVMIRTranslatorCreate(LLVMContextRef ctx) { From 7cbf770be2316642d56d9157403a291dfcd1a72b Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Thu, 30 Jan 2025 14:09:14 -0500 Subject: [PATCH 842/915] Revert "[mlir][llvmir] expose Type(To/From)LLVMIRTranslator C API (#124864)" This reverts commit bc89d63ba62e01885cb8abc36f455833a3020c64. Revert "[mlir] Fix warnings" This reverts commit 24cf42c6ab0da6178fe7b1cdf66ef43e8f4c1adf. 
--- mlir/include/mlir-c/Dialect/LLVM.h | 3 -- mlir/include/mlir-c/Target/LLVMIR.h | 43 ------------------------- mlir/lib/CAPI/Dialect/LLVM.cpp | 4 --- mlir/lib/CAPI/Target/LLVMIR.cpp | 49 ++--------------------------- 4 files changed, 3 insertions(+), 96 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index 65b14254e..26c414075 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -52,9 +52,6 @@ MLIR_CAPI_EXPORTED intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type); MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetInput(MlirType type, intptr_t pos); -/// Returns the return type of the function type. -MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type); - /// Returns `true` if the type is an LLVM dialect struct type. MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type); diff --git a/mlir/include/mlir-c/Target/LLVMIR.h b/mlir/include/mlir-c/Target/LLVMIR.h index b5f948961..effa74b90 100644 --- a/mlir/include/mlir-c/Target/LLVMIR.h +++ b/mlir/include/mlir-c/Target/LLVMIR.h @@ -16,7 +16,6 @@ #include "mlir-c/IR.h" #include "mlir-c/Support.h" -#include "llvm-c/Core.h" #include "llvm-c/Support.h" #ifdef __cplusplus @@ -33,48 +32,6 @@ extern "C" { MLIR_CAPI_EXPORTED LLVMModuleRef mlirTranslateModuleToLLVMIR(MlirOperation module, LLVMContextRef context); -struct MlirTypeFromLLVMIRTranslator { - void *ptr; -}; - -typedef struct MlirTypeFromLLVMIRTranslator MlirTypeFromLLVMIRTranslator; - -/// Create an LLVM::TypeFromLLVMIRTranslator and transfer ownership to the -/// caller. -MLIR_CAPI_EXPORTED MlirTypeFromLLVMIRTranslator -mlirTypeFromLLVMIRTranslatorCreate(MlirContext ctx); - -/// Takes an LLVM::TypeFromLLVMIRTranslator owned by the caller and destroys it. -/// It is the responsibility of the user to only pass an -/// LLVM::TypeFromLLVMIRTranslator class. 
-MLIR_CAPI_EXPORTED void -mlirTypeFromLLVMIRTranslatorDestroy(MlirTypeFromLLVMIRTranslator translator); - -/// Translates the given LLVM IR type to the MLIR LLVM dialect. -MLIR_CAPI_EXPORTED MlirType mlirTypeFromLLVMIRTranslatorTranslateType( - MlirTypeFromLLVMIRTranslator translator, LLVMTypeRef llvmType); - -struct MlirTypeToLLVMIRTranslator { - void *ptr; -}; - -typedef struct MlirTypeToLLVMIRTranslator MlirTypeToLLVMIRTranslator; - -/// Create an LLVM::TypeToLLVMIRTranslator and transfer ownership to the -/// caller. -MLIR_CAPI_EXPORTED MlirTypeToLLVMIRTranslator -mlirTypeToLLVMIRTranslatorCreate(LLVMContextRef ctx); - -/// Takes an LLVM::TypeToLLVMIRTranslator owned by the caller and destroys it. -/// It is the responsibility of the user to only pass an -/// LLVM::TypeToLLVMIRTranslator class. -MLIR_CAPI_EXPORTED void -mlirTypeToLLVMIRTranslatorDestroy(MlirTypeToLLVMIRTranslator translator); - -/// Translates the given MLIR LLVM dialect to the LLVM IR type. -MLIR_CAPI_EXPORTED LLVMTypeRef mlirTypeToLLVMIRTranslatorTranslateType( - MlirTypeToLLVMIRTranslator translator, MlirType mlirType); - #ifdef __cplusplus } #endif diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index 69c804b76..da450dd3f 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -65,10 +65,6 @@ MlirType mlirLLVMFunctionTypeGetInput(MlirType type, intptr_t pos) { .getParamType(static_cast(pos))); } -MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type) { - return wrap(llvm::cast(unwrap(type)).getReturnType()); -} - bool mlirTypeIsALLVMStructType(MlirType type) { return isa(unwrap(type)); } diff --git a/mlir/lib/CAPI/Target/LLVMIR.cpp b/mlir/lib/CAPI/Target/LLVMIR.cpp index 1c1912aec..dc798372b 100644 --- a/mlir/lib/CAPI/Target/LLVMIR.cpp +++ b/mlir/lib/CAPI/Target/LLVMIR.cpp @@ -8,15 +8,16 @@ //===----------------------------------------------------------------------===// #include "mlir-c/Target/LLVMIR.h" +#include "llvm-c/Support.h" 
#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" -#include "llvm/IR/Type.h" +#include #include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" #include "mlir/CAPI/Wrap.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" -#include "mlir/Target/LLVMIR/TypeFromLLVM.h" using namespace mlir; @@ -33,47 +34,3 @@ LLVMModuleRef mlirTranslateModuleToLLVMIR(MlirOperation module, return moduleRef; } - -DEFINE_C_API_PTR_METHODS(MlirTypeFromLLVMIRTranslator, - mlir::LLVM::TypeFromLLVMIRTranslator) - -MlirTypeFromLLVMIRTranslator -mlirTypeFromLLVMIRTranslatorCreate(MlirContext ctx) { - MLIRContext *context = unwrap(ctx); - auto *translator = new LLVM::TypeFromLLVMIRTranslator(*context); - return wrap(translator); -} - -void mlirTypeFromLLVMIRTranslatorDestroy( - MlirTypeFromLLVMIRTranslator translator) { - delete static_cast(unwrap(translator)); -} - -MlirType mlirTypeFromLLVMIRTranslatorTranslateType( - MlirTypeFromLLVMIRTranslator translator, LLVMTypeRef llvmType) { - LLVM::TypeFromLLVMIRTranslator *translator_ = unwrap(translator); - mlir::Type type = translator_->translateType(llvm::unwrap(llvmType)); - return wrap(type); -} - -DEFINE_C_API_PTR_METHODS(MlirTypeToLLVMIRTranslator, - mlir::LLVM::TypeToLLVMIRTranslator) - -MlirTypeToLLVMIRTranslator -mlirTypeToLLVMIRTranslatorCreate(LLVMContextRef ctx) { - llvm::LLVMContext *context = llvm::unwrap(ctx); - auto *translator = new LLVM::TypeToLLVMIRTranslator(*context); - return wrap(translator); -} - -void mlirTypeToLLVMIRTranslatorDestroy(MlirTypeToLLVMIRTranslator translator) { - delete static_cast(unwrap(translator)); -} - -LLVMTypeRef -mlirTypeToLLVMIRTranslatorTranslateType(MlirTypeToLLVMIRTranslator translator, - MlirType mlirType) { - LLVM::TypeToLLVMIRTranslator *translator_ = unwrap(translator); - llvm::Type *type = translator_->translateType(unwrap(mlirType)); - return llvm::wrap(type); -} From 570d15f0fd35866c8f414a1976399a24d3548683 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Fri, 31 Jan 2025 
09:40:56 -0600 Subject: [PATCH 843/915] [mlir][llvmir][reland] expose Type(To/From)LLVMIRTranslator C API (#125110) --- mlir/include/mlir-c/Dialect/LLVM.h | 3 ++ mlir/include/mlir-c/Target/LLVMIR.h | 43 +++++++++++++++++++++++++ mlir/lib/CAPI/Dialect/LLVM.cpp | 4 +++ mlir/lib/CAPI/Target/CMakeLists.txt | 1 + mlir/lib/CAPI/Target/LLVMIR.cpp | 49 +++++++++++++++++++++++++++-- 5 files changed, 97 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/LLVM.h b/mlir/include/mlir-c/Dialect/LLVM.h index 26c414075..65b14254e 100644 --- a/mlir/include/mlir-c/Dialect/LLVM.h +++ b/mlir/include/mlir-c/Dialect/LLVM.h @@ -52,6 +52,9 @@ MLIR_CAPI_EXPORTED intptr_t mlirLLVMFunctionTypeGetNumInputs(MlirType type); MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetInput(MlirType type, intptr_t pos); +/// Returns the return type of the function type. +MLIR_CAPI_EXPORTED MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type); + /// Returns `true` if the type is an LLVM dialect struct type. MLIR_CAPI_EXPORTED bool mlirTypeIsALLVMStructType(MlirType type); diff --git a/mlir/include/mlir-c/Target/LLVMIR.h b/mlir/include/mlir-c/Target/LLVMIR.h index effa74b90..b5f948961 100644 --- a/mlir/include/mlir-c/Target/LLVMIR.h +++ b/mlir/include/mlir-c/Target/LLVMIR.h @@ -16,6 +16,7 @@ #include "mlir-c/IR.h" #include "mlir-c/Support.h" +#include "llvm-c/Core.h" #include "llvm-c/Support.h" #ifdef __cplusplus @@ -32,6 +33,48 @@ extern "C" { MLIR_CAPI_EXPORTED LLVMModuleRef mlirTranslateModuleToLLVMIR(MlirOperation module, LLVMContextRef context); +struct MlirTypeFromLLVMIRTranslator { + void *ptr; +}; + +typedef struct MlirTypeFromLLVMIRTranslator MlirTypeFromLLVMIRTranslator; + +/// Create an LLVM::TypeFromLLVMIRTranslator and transfer ownership to the +/// caller. +MLIR_CAPI_EXPORTED MlirTypeFromLLVMIRTranslator +mlirTypeFromLLVMIRTranslatorCreate(MlirContext ctx); + +/// Takes an LLVM::TypeFromLLVMIRTranslator owned by the caller and destroys it. 
+/// It is the responsibility of the user to only pass an +/// LLVM::TypeFromLLVMIRTranslator class. +MLIR_CAPI_EXPORTED void +mlirTypeFromLLVMIRTranslatorDestroy(MlirTypeFromLLVMIRTranslator translator); + +/// Translates the given LLVM IR type to the MLIR LLVM dialect. +MLIR_CAPI_EXPORTED MlirType mlirTypeFromLLVMIRTranslatorTranslateType( + MlirTypeFromLLVMIRTranslator translator, LLVMTypeRef llvmType); + +struct MlirTypeToLLVMIRTranslator { + void *ptr; +}; + +typedef struct MlirTypeToLLVMIRTranslator MlirTypeToLLVMIRTranslator; + +/// Create an LLVM::TypeToLLVMIRTranslator and transfer ownership to the +/// caller. +MLIR_CAPI_EXPORTED MlirTypeToLLVMIRTranslator +mlirTypeToLLVMIRTranslatorCreate(LLVMContextRef ctx); + +/// Takes an LLVM::TypeToLLVMIRTranslator owned by the caller and destroys it. +/// It is the responsibility of the user to only pass an +/// LLVM::TypeToLLVMIRTranslator class. +MLIR_CAPI_EXPORTED void +mlirTypeToLLVMIRTranslatorDestroy(MlirTypeToLLVMIRTranslator translator); + +/// Translates the given MLIR LLVM dialect to the LLVM IR type. 
+MLIR_CAPI_EXPORTED LLVMTypeRef mlirTypeToLLVMIRTranslatorTranslateType( + MlirTypeToLLVMIRTranslator translator, MlirType mlirType); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/CAPI/Dialect/LLVM.cpp b/mlir/lib/CAPI/Dialect/LLVM.cpp index da450dd3f..69c804b76 100644 --- a/mlir/lib/CAPI/Dialect/LLVM.cpp +++ b/mlir/lib/CAPI/Dialect/LLVM.cpp @@ -65,6 +65,10 @@ MlirType mlirLLVMFunctionTypeGetInput(MlirType type, intptr_t pos) { .getParamType(static_cast(pos))); } +MlirType mlirLLVMFunctionTypeGetReturnType(MlirType type) { + return wrap(llvm::cast(unwrap(type)).getReturnType()); +} + bool mlirTypeIsALLVMStructType(MlirType type) { return isa(unwrap(type)); } diff --git a/mlir/lib/CAPI/Target/CMakeLists.txt b/mlir/lib/CAPI/Target/CMakeLists.txt index ce86fd3de..ea617da72 100644 --- a/mlir/lib/CAPI/Target/CMakeLists.txt +++ b/mlir/lib/CAPI/Target/CMakeLists.txt @@ -8,5 +8,6 @@ add_mlir_upstream_c_api_library(MLIRCAPITarget MLIRToLLVMIRTranslationRegistration MLIRCAPIIR MLIRLLVMToLLVMIRTranslation + MLIRLLVMIRToLLVMTranslation MLIRSupport ) diff --git a/mlir/lib/CAPI/Target/LLVMIR.cpp b/mlir/lib/CAPI/Target/LLVMIR.cpp index dc798372b..1c1912aec 100644 --- a/mlir/lib/CAPI/Target/LLVMIR.cpp +++ b/mlir/lib/CAPI/Target/LLVMIR.cpp @@ -8,16 +8,15 @@ //===----------------------------------------------------------------------===// #include "mlir-c/Target/LLVMIR.h" -#include "llvm-c/Support.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" -#include +#include "llvm/IR/Type.h" #include "mlir/CAPI/IR.h" -#include "mlir/CAPI/Support.h" #include "mlir/CAPI/Wrap.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" +#include "mlir/Target/LLVMIR/TypeFromLLVM.h" using namespace mlir; @@ -34,3 +33,47 @@ LLVMModuleRef mlirTranslateModuleToLLVMIR(MlirOperation module, return moduleRef; } + +DEFINE_C_API_PTR_METHODS(MlirTypeFromLLVMIRTranslator, + mlir::LLVM::TypeFromLLVMIRTranslator) + +MlirTypeFromLLVMIRTranslator +mlirTypeFromLLVMIRTranslatorCreate(MlirContext ctx) { 
+ MLIRContext *context = unwrap(ctx); + auto *translator = new LLVM::TypeFromLLVMIRTranslator(*context); + return wrap(translator); +} + +void mlirTypeFromLLVMIRTranslatorDestroy( + MlirTypeFromLLVMIRTranslator translator) { + delete static_cast(unwrap(translator)); +} + +MlirType mlirTypeFromLLVMIRTranslatorTranslateType( + MlirTypeFromLLVMIRTranslator translator, LLVMTypeRef llvmType) { + LLVM::TypeFromLLVMIRTranslator *translator_ = unwrap(translator); + mlir::Type type = translator_->translateType(llvm::unwrap(llvmType)); + return wrap(type); +} + +DEFINE_C_API_PTR_METHODS(MlirTypeToLLVMIRTranslator, + mlir::LLVM::TypeToLLVMIRTranslator) + +MlirTypeToLLVMIRTranslator +mlirTypeToLLVMIRTranslatorCreate(LLVMContextRef ctx) { + llvm::LLVMContext *context = llvm::unwrap(ctx); + auto *translator = new LLVM::TypeToLLVMIRTranslator(*context); + return wrap(translator); +} + +void mlirTypeToLLVMIRTranslatorDestroy(MlirTypeToLLVMIRTranslator translator) { + delete static_cast(unwrap(translator)); +} + +LLVMTypeRef +mlirTypeToLLVMIRTranslatorTranslateType(MlirTypeToLLVMIRTranslator translator, + MlirType mlirType) { + LLVM::TypeToLLVMIRTranslator *translator_ = unwrap(translator); + llvm::Type *type = translator_->translateType(unwrap(mlirType)); + return llvm::wrap(type); +} From 6ffd509500947cbf849f965d8809f44c0a2e86a9 Mon Sep 17 00:00:00 2001 From: Nikhil Kalra Date: Wed, 5 Feb 2025 11:48:11 -0800 Subject: [PATCH 844/915] [mlir] Python: Extend print large elements limit to resources (#125738) If the large element limit is specified, large elements are hidden from the asm but large resources are not. This change extends the large elements limit to apply to printed resources as well. 
--- mlir/lib/Bindings/Python/IRCore.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 8e351cb22..47a85c2a4 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1296,8 +1296,10 @@ void PyOperationBase::print(std::optional largeElementsLimit, fileObject = nb::module_::import_("sys").attr("stdout"); MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); - if (largeElementsLimit) + if (largeElementsLimit) { mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); + mlirOpPrintingFlagsElideLargeResourceString(flags, *largeElementsLimit); + } if (enableDebugInfo) mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, /*prettyForm=*/prettyDebugInfo); From 89d3da814e0e06f7d4d0c45cbd769d259a1462b3 Mon Sep 17 00:00:00 2001 From: Nikhil Kalra Date: Wed, 5 Feb 2025 11:48:37 -0800 Subject: [PATCH 845/915] [mlir] Python: Parse ModuleOp from file path (#125736) For extremely large models, it may be inefficient to load the model into memory in Python prior to passing it to the MLIR C APIs for deserialization. This change adds an API to parse a ModuleOp directly from a file path. --- mlir/include/mlir-c/IR.h | 4 ++++ mlir/include/mlir/Bindings/Python/Nanobind.h | 1 + mlir/lib/Bindings/Python/IRCore.cpp | 16 +++++++++++++++- mlir/lib/CAPI/IR/IR.cpp | 10 ++++++++++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 3 ++- 5 files changed, 32 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 7d2fd89e8..14ccae650 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -309,6 +309,10 @@ MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateEmpty(MlirLocation location); MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module); +/// Parses a module from file and transfers ownership to the caller. 
+MLIR_CAPI_EXPORTED MlirModule +mlirModuleCreateParseFromFile(MlirContext context, MlirStringRef fileName); + /// Gets the context that a module was created with. MLIR_CAPI_EXPORTED MlirContext mlirModuleGetContext(MlirModule module); diff --git a/mlir/include/mlir/Bindings/Python/Nanobind.h b/mlir/include/mlir/Bindings/Python/Nanobind.h index ca942c83d..bc8bddf4c 100644 --- a/mlir/include/mlir/Bindings/Python/Nanobind.h +++ b/mlir/include/mlir/Bindings/Python/Nanobind.h @@ -23,6 +23,7 @@ #endif #include #include +#include #include #include #include diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 47a85c2a4..2e4b6d1ce 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include #include #include @@ -299,7 +300,7 @@ struct PyAttrBuilderMap { return *builder; } static void dunderSetItemNamed(const std::string &attributeKind, - nb::callable func, bool replace) { + nb::callable func, bool replace) { PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), replace); } @@ -3049,6 +3050,19 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::arg("asm"), nb::arg("context").none() = nb::none(), kModuleParseDocstring) + .def_static( + "parse", + [](const std::filesystem::path &path, + DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirModule module = mlirModuleCreateParseFromFile( + context->get(), toMlirStringRef(path.string())); + if (mlirModuleIsNull(module)) + throw MLIRError("Unable to parse module assembly", errors.take()); + return PyModule::forModule(module).releaseObject(); + }, + nb::arg("asm"), nb::arg("context").none() = nb::none(), + kModuleParseDocstring) .def_static( "create", [](DefaultingPyLocation loc) { diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index f27af0ca9..999e8cbda 100644 --- 
a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/Location.h" #include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" +#include "mlir/IR/OwningOpRef.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" @@ -328,6 +329,15 @@ MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) { return MlirModule{owning.release().getOperation()}; } +MlirModule mlirModuleCreateParseFromFile(MlirContext context, + MlirStringRef fileName) { + OwningOpRef owning = + parseSourceFile(unwrap(fileName), unwrap(context)); + if (!owning) + return MlirModule{nullptr}; + return MlirModule{owning.release().getOperation()}; +} + MlirContext mlirModuleGetContext(MlirModule module) { return wrap(unwrap(module).getContext()); } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index fb7efb8cd..096b87b36 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -46,6 +46,7 @@ import abc import collections from collections.abc import Callable, Sequence import io +from pathlib import Path from typing import Any, ClassVar, TypeVar, overload __all__ = [ @@ -2123,7 +2124,7 @@ class Module: Creates an empty module """ @staticmethod - def parse(asm: str | bytes, context: Context | None = None) -> Module: + def parse(asm: str | bytes | Path, context: Context | None = None) -> Module: """ Parses a module's assembly format from a string. From 08254432f7962cdc4271747d2789d8a16caa0c7f Mon Sep 17 00:00:00 2001 From: Md Asghar Ahmad Shahid Date: Fri, 7 Feb 2025 00:38:50 +0530 Subject: [PATCH 846/915] [MLIR][Linalg] Introduce broadcast/transpose semantic to batch_matmul (#122275) Goals: 1. To add syntax and semantic to 'batch_matmul' without changing any of the existing syntax expectations for current usage. batch_matmul is still just batch_matmul. 2. 
Move the definition of batch_matmul from linalg OpDsl to tablegen ODS infra. Scope of this patch: To expose broadcast and transpose semantics on the 'batch_matmul'. The broadcast and transpose semantic are as follows: By default, 'linalg.batch_matmul' behavior will remain as is. Broadcast and Transpose semantics can be applied by specifying the explicit attribute 'indexing_maps' as shown below. This is a list attribute, so the list must include all the maps if specified. Example Transpose: ``` linalg.batch_matmul indexing_maps = [ affine_map< (d0, d1, d2, d3) -> (d0, d3, d1)>, //transpose affine_map< (d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map< (d0, d1, d2, d3) -> (d0, d1, d2)> ] ins (%arg0, %arg1: memref<2x5x3xf32>,memref<2x5x7xf32>) outs (%arg2: memref<2x3x7xf32>) ``` Example Broadcast: ``` linalg.batch_matmul indexing_maps = [ affine_map< (d0, d1, d2, d3) -> (d3)>, //broadcast affine_map< (d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map< (d0, d1, d2, d3) -> (d0, d1, d2)> ] ins (%arg0, %arg1: memref<5xf32>,memref<2x5x7xf32>) outs (%arg2: memref<2x3x7xf32>) ``` Example Broadcast and transpose: ``` linalg.batch_matmul indexing_maps = [ affine_map< (d0, d1, d2, d3) -> (d1, d3)>, //broadcast affine_map< (d0, d1, d2, d3) -> (d0, d2, d3)>, //transpose affine_map< (d0, d1, d2, d3) -> (d0, d1, d2)> ] ins (%arg0, %arg1: memref<3x5xf32>, memref<2x7x5xf32>) outs (%arg2: memref<2x3x7xf32>) ``` RFCs and related PR: https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149 https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863 https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586 https://github.com/llvm/llvm-project/pull/115319 --- .../linalg/opdsl/ops/core_named_ops.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index c95cd5eec..040663c88 100644 --- 
a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -484,24 +484,6 @@ def batch_mmt4d( ) * TypeFn.cast_signed(TV.AccumType, rhs[D.b, D.n, D.k, D.n0, D.k0]) -@linalg_structured_op -def batch_matmul( - A=TensorDef(T1, Batch, S.M, S.K), - B=TensorDef(T2, Batch, S.K, S.N), - C=TensorDef(U, Batch, S.M, S.N, output=True), -): - """Performs a batched matrix multiplication of two 3D inputs. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.b, D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( - U, B[D.b, D.k, D.n] - ) - - @linalg_structured_op def batch_matmul_transpose_a( A=TensorDef(T1, Batch, S.K, S.M), From 716d836724c38d5a788a717f2ab46a408949854e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 6 Feb 2025 18:33:12 -0500 Subject: [PATCH 847/915] [mlir] feat: add `mlirFuncSetResultAttr` (#125972) cc @ftynse @wsmoses --- mlir/include/mlir-c/Dialect/Func.h | 4 ++++ mlir/lib/CAPI/Dialect/Func.cpp | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/mlir/include/mlir-c/Dialect/Func.h b/mlir/include/mlir-c/Dialect/Func.h index 1df759f0e..001f915af 100644 --- a/mlir/include/mlir-c/Dialect/Func.h +++ b/mlir/include/mlir-c/Dialect/Func.h @@ -35,6 +35,10 @@ MLIR_CAPI_EXPORTED void mlirFuncSetArgAttr(MlirOperation op, intptr_t pos, MlirStringRef name, MlirAttribute attr); +MLIR_CAPI_EXPORTED void mlirFuncSetResultAttr(MlirOperation op, intptr_t pos, + MlirStringRef name, + MlirAttribute attr); + #ifdef __cplusplus } #endif diff --git a/mlir/lib/CAPI/Dialect/Func.cpp b/mlir/lib/CAPI/Dialect/Func.cpp index 942e090fd..8265b61b9 100644 --- a/mlir/lib/CAPI/Dialect/Func.cpp +++ b/mlir/lib/CAPI/Dialect/Func.cpp @@ -19,3 +19,9 @@ void mlirFuncSetArgAttr(MlirOperation op, intptr_t pos, MlirStringRef name, 
llvm::cast(unwrap(op)) .setArgAttr(pos, unwrap(name), unwrap(attr)); } + +void mlirFuncSetResultAttr(MlirOperation op, intptr_t pos, MlirStringRef name, + MlirAttribute attr) { + llvm::cast(unwrap(op)) + .setResultAttr(pos, unwrap(name), unwrap(attr)); +} From 06ef50aabea23da3cb079c19c3abfa08fa5b3003 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Mon, 10 Feb 2025 09:09:58 +0100 Subject: [PATCH 848/915] Revert "[mlir] Python: Parse ModuleOp from file path" (#126482) Reverts llvm/llvm-project#125736 The gcc7 Bot is broken at the moment. --- mlir/include/mlir-c/IR.h | 4 ---- mlir/include/mlir/Bindings/Python/Nanobind.h | 1 - mlir/lib/Bindings/Python/IRCore.cpp | 16 +--------------- mlir/lib/CAPI/IR/IR.cpp | 10 ---------- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 3 +-- 5 files changed, 2 insertions(+), 32 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 14ccae650..7d2fd89e8 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -309,10 +309,6 @@ MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateEmpty(MlirLocation location); MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module); -/// Parses a module from file and transfers ownership to the caller. -MLIR_CAPI_EXPORTED MlirModule -mlirModuleCreateParseFromFile(MlirContext context, MlirStringRef fileName); - /// Gets the context that a module was created with. 
MLIR_CAPI_EXPORTED MlirContext mlirModuleGetContext(MlirModule module); diff --git a/mlir/include/mlir/Bindings/Python/Nanobind.h b/mlir/include/mlir/Bindings/Python/Nanobind.h index bc8bddf4c..ca942c83d 100644 --- a/mlir/include/mlir/Bindings/Python/Nanobind.h +++ b/mlir/include/mlir/Bindings/Python/Nanobind.h @@ -23,7 +23,6 @@ #endif #include #include -#include #include #include #include diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 2e4b6d1ce..47a85c2a4 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -6,7 +6,6 @@ // //===----------------------------------------------------------------------===// -#include #include #include @@ -300,7 +299,7 @@ struct PyAttrBuilderMap { return *builder; } static void dunderSetItemNamed(const std::string &attributeKind, - nb::callable func, bool replace) { + nb::callable func, bool replace) { PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), replace); } @@ -3050,19 +3049,6 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::arg("asm"), nb::arg("context").none() = nb::none(), kModuleParseDocstring) - .def_static( - "parse", - [](const std::filesystem::path &path, - DefaultingPyMlirContext context) { - PyMlirContext::ErrorCapture errors(context->getRef()); - MlirModule module = mlirModuleCreateParseFromFile( - context->get(), toMlirStringRef(path.string())); - if (mlirModuleIsNull(module)) - throw MLIRError("Unable to parse module assembly", errors.take()); - return PyModule::forModule(module).releaseObject(); - }, - nb::arg("asm"), nb::arg("context").none() = nb::none(), - kModuleParseDocstring) .def_static( "create", [](DefaultingPyLocation loc) { diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 999e8cbda..f27af0ca9 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -22,7 +22,6 @@ #include "mlir/IR/Location.h" #include "mlir/IR/Operation.h" #include 
"mlir/IR/OperationSupport.h" -#include "mlir/IR/OwningOpRef.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" @@ -329,15 +328,6 @@ MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) { return MlirModule{owning.release().getOperation()}; } -MlirModule mlirModuleCreateParseFromFile(MlirContext context, - MlirStringRef fileName) { - OwningOpRef owning = - parseSourceFile(unwrap(fileName), unwrap(context)); - if (!owning) - return MlirModule{nullptr}; - return MlirModule{owning.release().getOperation()}; -} - MlirContext mlirModuleGetContext(MlirModule module) { return wrap(unwrap(module).getContext()); } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 096b87b36..fb7efb8cd 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -46,7 +46,6 @@ import abc import collections from collections.abc import Callable, Sequence import io -from pathlib import Path from typing import Any, ClassVar, TypeVar, overload __all__ = [ @@ -2124,7 +2123,7 @@ class Module: Creates an empty module """ @staticmethod - def parse(asm: str | bytes | Path, context: Context | None = None) -> Module: + def parse(asm: str | bytes, context: Context | None = None) -> Module: """ Parses a module's assembly format from a string. From 1a51cc258ad7bd414830f2386386a274433a3ff4 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Mon, 10 Feb 2025 13:05:13 +0100 Subject: [PATCH 849/915] [MLIR][Linalg] Expose linalg.matmul and linalg.contract via Python API (#126377) Now that linalg.matmul is in tablegen, "hand write" the Python wrapper that OpDSL used to derive. Similarly, add a Python wrapper for the new linalg.contract op. Required following misc. fixes: 1) make linalg.matmul's parsing and printing consistent w.r.t. whether indexing_maps occurs before or after operands, i.e. per the tests cases it comes _before_. 
2) tablegen for linalg.contract did not state it accepted an optional cast attr. 3) In ODS's C++-generating code, expand partial support for `$_builder` access in `Attr::defaultValue` to full support. This enables access to the current `MlirContext` when constructing the default value (as is required when the default value consists of affine maps). --- mlir/python/mlir/dialects/linalg/__init__.py | 46 ++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index 742262a9c..5cda4769d 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -147,3 +147,49 @@ def __init__( generic = region_op(GenericOp_, terminator=YieldOp) + + +def matmul( + *ins: Union[Operation, OpView, Value], + outs: Sequence[Union[Operation, OpView, Value]], + indexing_maps: Optional[Sequence[AffineMapAttr]] = None, + cast: Optional[Union[TypeFn, Attribute]] = None, +): + ins = [_get_op_result_or_value(input) for input in ins] + if len(outs) > 1: + raise ValueError(f"{outs=} must have length 1.") + init = _get_op_result_or_value(outs[0]) + result_types = [init.type] if isinstance(init.type, RankedTensorType) else [] + + op = MatmulOp( + result_tensors=result_types, + inputs=ins, + outputs=[init], + indexing_maps=indexing_maps, + cast=cast, + ) + fill_builtin_region(op.operation) + return op + + +def contract( + *ins: Union[Operation, OpView, Value], + outs: Sequence[Union[Operation, OpView, Value]], + indexing_maps: Sequence[AffineMapAttr], + cast: Optional[Union[TypeFn, Attribute]] = None, +): + ins = [_get_op_result_or_value(input) for input in ins] + if len(outs) > 1: + raise ValueError(f"{outs=} must have length 1.") + init = _get_op_result_or_value(outs[0]) + result_types = [init.type] if isinstance(init.type, RankedTensorType) else [] + + op = ContractOp( + result_tensors=result_types, + inputs=ins, + outputs=[init], + 
indexing_maps=indexing_maps, + cast=cast, + ) + fill_builtin_region(op.operation) + return op From 4ae0d0ba7039f4d9a09d0a0ed15d30ac7f1074cb Mon Sep 17 00:00:00 2001 From: Nikhil Kalra Date: Wed, 12 Feb 2025 14:02:41 -0800 Subject: [PATCH 850/915] [mlir] Python: Parse ModuleOp from file path (#126572) For extremely large models, it may be inefficient to load the model into memory in Python prior to passing it to the MLIR C APIs for deserialization. This change adds an API to parse a ModuleOp directly from a file path. Re-lands [89d3da8](https://github.com/llvm/llvm-project/commit/89d3da814e0e06f7d4d0c45cbd769d259a1462b3). --- mlir/include/mlir-c/IR.h | 4 ++++ mlir/lib/Bindings/Python/IRCore.cpp | 14 +++++++++++++- mlir/lib/CAPI/IR/IR.cpp | 10 ++++++++++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 10 ++++++++++ 4 files changed, 37 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 7d2fd89e8..14ccae650 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -309,6 +309,10 @@ MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateEmpty(MlirLocation location); MLIR_CAPI_EXPORTED MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module); +/// Parses a module from file and transfers ownership to the caller. +MLIR_CAPI_EXPORTED MlirModule +mlirModuleCreateParseFromFile(MlirContext context, MlirStringRef fileName); + /// Gets the context that a module was created with. 
MLIR_CAPI_EXPORTED MlirContext mlirModuleGetContext(MlirModule module); diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 47a85c2a4..827db5f3e 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -299,7 +299,7 @@ struct PyAttrBuilderMap { return *builder; } static void dunderSetItemNamed(const std::string &attributeKind, - nb::callable func, bool replace) { + nb::callable func, bool replace) { PyGlobals::get().registerAttributeBuilder(attributeKind, std::move(func), replace); } @@ -3049,6 +3049,18 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::arg("asm"), nb::arg("context").none() = nb::none(), kModuleParseDocstring) + .def_static( + "parseFile", + [](const std::string &path, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirModule module = mlirModuleCreateParseFromFile( + context->get(), toMlirStringRef(path)); + if (mlirModuleIsNull(module)) + throw MLIRError("Unable to parse module assembly", errors.take()); + return PyModule::forModule(module).releaseObject(); + }, + nb::arg("path"), nb::arg("context").none() = nb::none(), + kModuleParseDocstring) .def_static( "create", [](DefaultingPyLocation loc) { diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index f27af0ca9..999e8cbda 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -22,6 +22,7 @@ #include "mlir/IR/Location.h" #include "mlir/IR/Operation.h" #include "mlir/IR/OperationSupport.h" +#include "mlir/IR/OwningOpRef.h" #include "mlir/IR/Types.h" #include "mlir/IR/Value.h" #include "mlir/IR/Verifier.h" @@ -328,6 +329,15 @@ MlirModule mlirModuleCreateParse(MlirContext context, MlirStringRef module) { return MlirModule{owning.release().getOperation()}; } +MlirModule mlirModuleCreateParseFromFile(MlirContext context, + MlirStringRef fileName) { + OwningOpRef owning = + parseSourceFile(unwrap(fileName), unwrap(context)); + if (!owning) + 
return MlirModule{nullptr}; + return MlirModule{owning.release().getOperation()}; +} + MlirContext mlirModuleGetContext(MlirModule module) { return wrap(unwrap(module).getContext()); } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index fb7efb8cd..ab975a695 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -46,6 +46,7 @@ import abc import collections from collections.abc import Callable, Sequence import io +from pathlib import Path from typing import Any, ClassVar, TypeVar, overload __all__ = [ @@ -2129,6 +2130,15 @@ class Module: Returns a new MlirModule or raises an MLIRError if the parsing fails. + See also: https://mlir.llvm.org/docs/LangRef/ + """ + @staticmethod + def parseFile(path: str, context: Context | None = None) -> Module: + """ + Parses a module's assembly format from file. + + Returns a new MlirModule or raises an MLIRError if the parsing fails. + See also: https://mlir.llvm.org/docs/LangRef/ """ def _CAPICreate(self) -> Any: ... From c32c13fead725eadee3ee6ad072fa6a6f78a0bd3 Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Thu, 13 Feb 2025 17:37:49 +0000 Subject: [PATCH 851/915] [mlir][index] Add CAPI (#127039) --- mlir/include/mlir-c/Dialect/Index.h | 24 ++++++++++++++++++++++++ mlir/lib/CAPI/Dialect/CMakeLists.txt | 9 +++++++++ mlir/lib/CAPI/Dialect/Index.cpp | 13 +++++++++++++ 3 files changed, 46 insertions(+) create mode 100644 mlir/include/mlir-c/Dialect/Index.h create mode 100644 mlir/lib/CAPI/Dialect/Index.cpp diff --git a/mlir/include/mlir-c/Dialect/Index.h b/mlir/include/mlir-c/Dialect/Index.h new file mode 100644 index 000000000..3f05694ac --- /dev/null +++ b/mlir/include/mlir-c/Dialect/Index.h @@ -0,0 +1,24 @@ +//===-- mlir-c/Dialect/Index.h - C API for Index dialect ----------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_INDEX_H +#define MLIR_C_DIALECT_INDEX_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Index, index); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_INDEX_H diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index 5ad4bafed..ddd3d6629 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -81,6 +81,15 @@ add_mlir_upstream_c_api_library(MLIRCAPIGPU MLIRPass ) +add_mlir_upstream_c_api_library(MLIRCAPIIndex + Index.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRIndexDialect +) + add_mlir_upstream_c_api_library(MLIRCAPIIRDL IRDL.cpp diff --git a/mlir/lib/CAPI/Dialect/Index.cpp b/mlir/lib/CAPI/Dialect/Index.cpp new file mode 100644 index 000000000..845791436 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/Index.cpp @@ -0,0 +1,13 @@ +//===- Index.cpp - C Interface for Index dialect --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir-c/Dialect/Index.h"
+#include "mlir/CAPI/Registration.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+
+MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Index, index, mlir::index::IndexDialect)

From c304a406f6fb02df6a4a8a3c0c46f30faae60f4d Mon Sep 17 00:00:00 2001
From: Edgar
Date: Sat, 15 Feb 2025 12:21:20 +0100
Subject: [PATCH 852/915] [MLIR] Fix mlirExecutionEngineLookup throwing assert
 on lookup fail (#123924)

Apparently trying to look up a function pointer using the C API
`mlirExecutionEngineLookup` will throw an assert instead of just returning a
nullptr on builds with asserts.

The docs themselves say it returns a nullptr when no function is found, so it
is sensible to not throw an assert in this case.
---
 mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp | 14 ++++++++------
 1 file changed, 8 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
index 507be9171..306cebd23 100644
--- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
+++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
@@ -85,18 +85,20 @@ mlirExecutionEngineInvokePacked(MlirExecutionEngine jit, MlirStringRef name,
 extern "C" void *mlirExecutionEngineLookupPacked(MlirExecutionEngine jit,
                                                  MlirStringRef name) {
-  auto expectedFPtr = unwrap(jit)->lookupPacked(unwrap(name));
-  if (!expectedFPtr)
+  auto optionalFPtr =
+      llvm::expectedToOptional(unwrap(jit)->lookupPacked(unwrap(name)));
+  if (!optionalFPtr)
     return nullptr;
-  return reinterpret_cast(*expectedFPtr);
+  return reinterpret_cast(*optionalFPtr);
 }
 
 extern "C" void *mlirExecutionEngineLookup(MlirExecutionEngine jit,
                                            MlirStringRef name) {
-  auto expectedFPtr = unwrap(jit)->lookup(unwrap(name));
-  if (!expectedFPtr)
+  auto optionalFPtr =
+
llvm::expectedToOptional(unwrap(jit)->lookup(unwrap(name))); + if (!optionalFPtr) return nullptr; - return reinterpret_cast(*expectedFPtr); + return *optionalFPtr; } extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit, From bed39aad83d6e06b8a87b4fd0a8039600e76a2bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ingo=20M=C3=BCller?= Date: Wed, 19 Feb 2025 08:07:49 +0100 Subject: [PATCH 853/915] [mlir:python] Improve `mlir_(attribute|type|value)_subclass` for `nanobind`s `stubgen` (#127584) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR makes several improvements to the stubs that are created by `mlir_(attribute|type|value)_subclass`. First, the PR sets the `__module__` attribute of the classes generated by the nanobind adaptors for attributes, types, and values (via `mlir_(attribute|type|value)_subclass`). By default, the `__module__` property is set to `importlib._bootstrap`, which isn't where we want the new class to live. The new logic sets the property to the name of the module provided as `scope` instead. This also makes nanobind's `stubgen` generate stubs for those classes properly, which ignores classes whose `__module__` does not correspond to the module it is generating stubs for. This resolves #127518. Second, the PR overwrites the function signatures generated by `stubgen` to a format that uses the desired type names (e.g., `mlir.ir.Attribute` instead of `MlirAttribute`). Finally, the PR piggy-backs some minor doc and style improvements to `PythonAdaptors.h`. 
--------- Signed-off-by: Ingo Müller --- .../mlir/Bindings/Python/NanobindAdaptors.h | 29 ++++++++++++++----- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h index 517351cac..0608182f0 100644 --- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h @@ -23,8 +23,10 @@ #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" +// clang-format off #include "mlir/Bindings/Python/Nanobind.h" #include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. +// clang-format on #include "llvm/ADT/Twine.h" // Raw CAPI type casters need to be declared before use, so always include them @@ -349,6 +351,7 @@ class pure_subclass { thisClass = metaclass(derivedClassName, nanobind::make_tuple(superClass), attributes); scope.attr(derivedClassName) = thisClass; + thisClass.attr("__module__") = scope.attr("__name__"); } template @@ -434,7 +437,7 @@ class mlir_attribute_subclass : public pure_subclass { const nanobind::object &superCls, GetTypeIDFunctionTy getTypeIDFunction = nullptr) : pure_subclass(scope, typeClassName, superCls) { - // Casting constructor. Note that it hard, if not impossible, to properly + // Casting constructor. Note that it is hard, if not impossible, to properly // call chain to parent `__init__` in nanobind due to its special handling // for init functions that don't have a fully constructed self-reference, // which makes it impossible to forward it to `__init__` of a superclass. @@ -465,10 +468,13 @@ class mlir_attribute_subclass : public pure_subclass { thisClass.attr("__new__") = newCf; // 'isinstance' method. 
+ static const char kIsinstanceSig[] = + "def isinstance(other_attribute: " MAKE_MLIR_PYTHON_QUALNAME( + "ir") ".Attribute) -> bool"; def_staticmethod( "isinstance", [isaFunction](MlirAttribute other) { return isaFunction(other); }, - nanobind::arg("other_attribute")); + nanobind::arg("other_attribute"), nanobind::sig(kIsinstanceSig)); def("__repr__", [superCls, captureTypeName](nanobind::object self) { return nanobind::repr(superCls(self)) .attr("replace")(superCls.attr("__name__"), captureTypeName); @@ -512,7 +518,7 @@ class mlir_type_subclass : public pure_subclass { const nanobind::object &superCls, GetTypeIDFunctionTy getTypeIDFunction = nullptr) : pure_subclass(scope, typeClassName, superCls) { - // Casting constructor. Note that it hard, if not impossible, to properly + // Casting constructor. Note that it is hard, if not impossible, to properly // call chain to parent `__init__` in nanobind due to its special handling // for init functions that don't have a fully constructed self-reference, // which makes it impossible to forward it to `__init__` of a superclass. @@ -542,13 +548,17 @@ class mlir_type_subclass : public pure_subclass { thisClass.attr("__new__") = newCf; // 'isinstance' method. + static const char kIsinstanceSig[] = + "def isinstance(other_type: " MAKE_MLIR_PYTHON_QUALNAME( + "ir") ".Type) -> bool"; def_staticmethod( "isinstance", [isaFunction](MlirType other) { return isaFunction(other); }, - nanobind::arg("other_type")); + nanobind::arg("other_type"), nanobind::sig(kIsinstanceSig)); def("__repr__", [superCls, captureTypeName](nanobind::object self) { - return nanobind::repr(superCls(self)) - .attr("replace")(superCls.attr("__name__"), captureTypeName); + return nanobind::cast( + nanobind::repr(superCls(self)) + .attr("replace")(superCls.attr("__name__"), captureTypeName)); }); if (getTypeIDFunction) { // 'get_static_typeid' method. 
@@ -590,7 +600,7 @@ class mlir_value_subclass : public pure_subclass {
                       IsAFunctionTy isaFunction,
                       const nanobind::object &superCls)
       : pure_subclass(scope, valueClassName, superCls) {
-    // Casting constructor. Note that it hard, if not impossible, to properly
+    // Casting constructor. Note that it is hard, if not impossible, to properly
     // call chain to parent `__init__` in nanobind due to its special handling
     // for init functions that don't have a fully constructed self-reference,
     // which makes it impossible to forward it to `__init__` of a superclass.
@@ -620,10 +630,13 @@ class mlir_value_subclass : public pure_subclass {
     thisClass.attr("__new__") = newCf;
 
     // 'isinstance' method.
+    static const char kIsinstanceSig[] =
+        "def isinstance(other_value: " MAKE_MLIR_PYTHON_QUALNAME(
+            "ir") ".Value) -> bool";
     def_staticmethod(
         "isinstance",
         [isaFunction](MlirValue other) { return isaFunction(other); },
-        nanobind::arg("other_value"));
+        nanobind::arg("other_value"), nanobind::sig(kIsinstanceSig));
   }
 };
 

From c2dc4f503a8ddd05a3bf12557bd459e9439f5d3c Mon Sep 17 00:00:00 2001
From: Md Asghar Ahmad Shahid
Date: Wed, 19 Feb 2025 19:45:02 +0530
Subject: [PATCH 854/915] [MLIR][Linalg] Introduce Python API for
 linalg.batch_matmul Ops. (#127614)

As linalg.batch_matmul has been moved into tablegen from OpDSL, its
derived python wrapper no longer exists. This patch adds the required
python wrapper.

Also refactors the BatchMatmulOp printer to make it consistent with its
parser.
--- mlir/python/mlir/dialects/linalg/__init__.py | 41 ++++++++++++-------- 1 file changed, 25 insertions(+), 16 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index 5cda4769d..c5fbb833e 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -149,7 +149,8 @@ def __init__( generic = region_op(GenericOp_, terminator=YieldOp) -def matmul( +def create_op( + op_type, *ins: Union[Operation, OpView, Value], outs: Sequence[Union[Operation, OpView, Value]], indexing_maps: Optional[Sequence[AffineMapAttr]] = None, @@ -161,7 +162,7 @@ def matmul( init = _get_op_result_or_value(outs[0]) result_types = [init.type] if isinstance(init.type, RankedTensorType) else [] - op = MatmulOp( + op = op_type( result_tensors=result_types, inputs=ins, outputs=[init], @@ -172,24 +173,32 @@ def matmul( return op +def matmul( + *ins: Union[Operation, OpView, Value], + outs: Sequence[Union[Operation, OpView, Value]], + indexing_maps: Optional[Sequence[AffineMapAttr]] = None, + cast: Optional[Union[TypeFn, Attribute]] = None, +): + return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast) + + +def batch_matmul( + *ins: Union[Operation, OpView, Value], + outs: Sequence[Union[Operation, OpView, Value]], + indexing_maps: Optional[Sequence[AffineMapAttr]] = None, + cast: Optional[Union[TypeFn, Attribute]] = None, +): + return create_op( + BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + ) + + def contract( *ins: Union[Operation, OpView, Value], outs: Sequence[Union[Operation, OpView, Value]], indexing_maps: Sequence[AffineMapAttr], cast: Optional[Union[TypeFn, Attribute]] = None, ): - ins = [_get_op_result_or_value(input) for input in ins] - if len(outs) > 1: - raise ValueError(f"{outs=} must have length 1.") - init = _get_op_result_or_value(outs[0]) - result_types = [init.type] if isinstance(init.type, RankedTensorType) else 
[] - - op = ContractOp( - result_tensors=result_types, - inputs=ins, - outputs=[init], - indexing_maps=indexing_maps, - cast=cast, + return create_op( + ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast ) - fill_builtin_region(op.operation) - return op From 1a42b4f3638669f49e36e0579efedc7cc3999c86 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Thu, 20 Feb 2025 10:02:36 -0600 Subject: [PATCH 855/915] [mlir][python] fix linalg.pack/unpack (#127729) This PR https://github.com/llvm/llvm-project/pull/123902 broke python bindings for `tensor.pack`/`unpack`. This PR fixes that. It also 1. adds convenience wrappers for pack/unpack 2. cleans up matmul-like ops in the linalg bindings 3. fixes linalg docs missing pack/unpack --- mlir/python/mlir/dialects/LinalgOps.td | 1 + mlir/python/mlir/dialects/linalg/__init__.py | 90 ++++++++++++++++++-- 2 files changed, 84 insertions(+), 7 deletions(-) diff --git a/mlir/python/mlir/dialects/LinalgOps.td b/mlir/python/mlir/dialects/LinalgOps.td index b7658c85a..89fb3f219 100644 --- a/mlir/python/mlir/dialects/LinalgOps.td +++ b/mlir/python/mlir/dialects/LinalgOps.td @@ -11,5 +11,6 @@ include "mlir/Dialect/Linalg/IR/LinalgOps.td" include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.td" +include "mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td" #endif diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index c5fbb833e..63586a5bb 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -58,7 +58,11 @@ from .opdsl.ops.core_named_ops import * from ...ir import * -from .._ods_common import get_op_result_or_value as _get_op_result_or_value +from .._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_result_or_op_results as _get_op_result_or_op_results, + _dispatch_mixed_values, +) from ...extras.meta import region_op @@ -149,7 +153,7 @@ def __init__( generic = region_op(GenericOp_, 
terminator=YieldOp) -def create_op( +def _create_matmul_like_op( op_type, *ins: Union[Operation, OpView, Value], outs: Sequence[Union[Operation, OpView, Value]], @@ -179,7 +183,11 @@ def matmul( indexing_maps: Optional[Sequence[AffineMapAttr]] = None, cast: Optional[Union[TypeFn, Attribute]] = None, ): - return create_op(MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast) + return _get_op_result_or_op_results( + _create_matmul_like_op( + MatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + ) + ) def batch_matmul( @@ -188,8 +196,10 @@ def batch_matmul( indexing_maps: Optional[Sequence[AffineMapAttr]] = None, cast: Optional[Union[TypeFn, Attribute]] = None, ): - return create_op( - BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + return _get_op_result_or_op_results( + _create_matmul_like_op( + BatchMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + ) ) @@ -199,6 +209,72 @@ def contract( indexing_maps: Sequence[AffineMapAttr], cast: Optional[Union[TypeFn, Attribute]] = None, ): - return create_op( - ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + return _get_op_result_or_op_results( + _create_matmul_like_op( + ContractOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + ) + ) + + +def pack( + source, + dest, + inner_dims_pos, + inner_tiles, + *, + padding_value=None, + outer_dims_perm=None, + loc=None, + ip=None, +) -> ir.Value: + ( + dynamic_inner_tiles, + # packed here means %1:2 packing (results packing) + _inner_tiles, + static_inner_tiles, + ) = _dispatch_mixed_values(inner_tiles) + + return _get_op_result_or_op_results( + PackOp( + source=source, + dest=dest, + inner_dims_pos=inner_dims_pos, + inner_tiles=dynamic_inner_tiles, + static_inner_tiles=static_inner_tiles, + padding_value=padding_value, + outer_dims_perm=outer_dims_perm, + loc=loc, + ip=ip, + ) + ) + + +def unpack( + source, + dest, + inner_dims_pos, + inner_tiles, + *, + outer_dims_perm=None, + 
loc=None, + ip=None, +) -> ir.Value: + ( + dynamic_inner_tiles, + # packed here means %1:2 packing (results packing) + _inner_tiles, + static_inner_tiles, + ) = _dispatch_mixed_values(inner_tiles) + + return _get_op_result_or_op_results( + UnPackOp( + source=source, + dest=dest, + inner_dims_pos=inner_dims_pos, + inner_tiles=dynamic_inner_tiles, + static_inner_tiles=static_inner_tiles, + outer_dims_perm=outer_dims_perm, + loc=loc, + ip=ip, + ) ) From 755b96e1201eee87ed7cdc57aaad30213a3bd92a Mon Sep 17 00:00:00 2001 From: Nikhil Kalra Date: Mon, 24 Feb 2025 17:51:49 -0800 Subject: [PATCH 856/915] [mlir] Python: write bytecode to a file path (#127118) The current `write_bytecode` implementation necessarily requires the serialized module to be duplicated in memory when the python `bytes` object is created and sent over the binding. For modules with large resources, we may want to avoid this in-memory copy by serializing directly to a file instead of sending bytes across the boundary. --- mlir/lib/Bindings/Python/IRCore.cpp | 15 +++++--- mlir/lib/Bindings/Python/NanobindUtils.h | 49 +++++++++++++++++++----- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 6 +-- 3 files changed, 52 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 827db5f3e..b13a429d4 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -6,12 +6,10 @@ // //===----------------------------------------------------------------------===// -#include -#include - #include "Globals.h" #include "IRModule.h" #include "NanobindUtils.h" +#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. 
#include "mlir-c/BuiltinAttributes.h" #include "mlir-c/Debug.h" #include "mlir-c/Diagnostics.h" @@ -19,9 +17,14 @@ #include "mlir-c/Support.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" -#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. +#include "nanobind/nanobind.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include namespace nb = nanobind; using namespace nb::literals; @@ -1329,11 +1332,11 @@ void PyOperationBase::print(PyAsmState &state, nb::object fileObject, accum.getUserData()); } -void PyOperationBase::writeBytecode(const nb::object &fileObject, +void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject, std::optional bytecodeVersion) { PyOperation &operation = getOperation(); operation.checkValid(); - PyFileAccumulator accum(fileObject, /*binary=*/true); + PyFileAccumulator accum(fileOrStringObject, /*binary=*/true); if (!bytecodeVersion.has_value()) return mlirOperationWriteBytecode(operation, accum.getCallback(), diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h index ee193cf9f..64ea4329f 100644 --- a/mlir/lib/Bindings/Python/NanobindUtils.h +++ b/mlir/lib/Bindings/Python/NanobindUtils.h @@ -13,8 +13,13 @@ #include "mlir-c/Support.h" #include "mlir/Bindings/Python/Nanobind.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/DataTypes.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include template <> struct std::iterator_traits { @@ -128,33 +133,59 @@ struct PyPrintAccumulator { } }; -/// Accumulates int a python file-like object, either writing text (default) -/// or binary. +/// Accumulates into a file, either writing text (default) +/// or binary. The file may be a Python file-like object or a path to a file. 
class PyFileAccumulator { public: - PyFileAccumulator(const nanobind::object &fileObject, bool binary) - : pyWriteFunction(fileObject.attr("write")), binary(binary) {} + PyFileAccumulator(const nanobind::object &fileOrStringObject, bool binary) + : binary(binary) { + std::string filePath; + if (nanobind::try_cast(fileOrStringObject, filePath)) { + std::error_code ec; + writeTarget.emplace(filePath, ec); + if (ec) { + throw nanobind::value_error( + (std::string("Unable to open file for writing: ") + ec.message()) + .c_str()); + } + } else { + writeTarget.emplace(fileOrStringObject.attr("write")); + } + } + + MlirStringCallback getCallback() { + return writeTarget.index() == 0 ? getPyWriteCallback() + : getOstreamCallback(); + } void *getUserData() { return this; } - MlirStringCallback getCallback() { +private: + MlirStringCallback getPyWriteCallback() { return [](MlirStringRef part, void *userData) { nanobind::gil_scoped_acquire acquire; PyFileAccumulator *accum = static_cast(userData); if (accum->binary) { // Note: Still has to copy and not avoidable with this API. nanobind::bytes pyBytes(part.data, part.length); - accum->pyWriteFunction(pyBytes); + std::get(accum->writeTarget)(pyBytes); } else { nanobind::str pyStr(part.data, part.length); // Decodes as UTF-8 by default. 
- accum->pyWriteFunction(pyStr); + std::get(accum->writeTarget)(pyStr); } }; } -private: - nanobind::object pyWriteFunction; + MlirStringCallback getOstreamCallback() { + return [](MlirStringRef part, void *userData) { + PyFileAccumulator *accum = static_cast(userData); + std::get(accum->writeTarget) + .write(part.data, part.length); + }; + } + + std::variant writeTarget; bool binary; }; diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index ab975a695..c93de2fe3 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -47,7 +47,7 @@ import collections from collections.abc import Callable, Sequence import io from pathlib import Path -from typing import Any, ClassVar, TypeVar, overload +from typing import Any, BinaryIO, ClassVar, TypeVar, overload __all__ = [ "AffineAddExpr", @@ -285,12 +285,12 @@ class _OperationBase: """ Verify the operation. Raises MLIRError if verification fails, and returns true otherwise. """ - def write_bytecode(self, file: Any, desired_version: int | None = None) -> None: + def write_bytecode(self, file: BinaryIO | str, desired_version: int | None = None) -> None: """ Write the bytecode form of the operation to a file like object. Args: - file: The file like object to write to. + file: The file like object or path to write to. desired_version: The version of bytecode to emit. Returns: The bytecode writer status. 
From 5157bd00cd8d4025bda2a6db6a2faf46222b8bb6 Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Tue, 4 Mar 2025 11:49:34 -0800 Subject: [PATCH 857/915] [mlir][py] Plumb OpPrintingFlags::printNameLocAsPrefix() through the C/Python APIs (#129607) --- mlir/include/mlir-c/IR.h | 4 ++++ mlir/lib/Bindings/Python/IRCore.cpp | 16 ++++++++++++---- mlir/lib/Bindings/Python/IRModule.h | 7 ++++--- mlir/lib/CAPI/IR/IR.cpp | 4 ++++ 4 files changed, 24 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 14ccae650..d562da1f9 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -456,6 +456,10 @@ mlirOpPrintingFlagsEnableDebugInfo(MlirOpPrintingFlags flags, bool enable, MLIR_CAPI_EXPORTED void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags); +/// Print the name and location, if NamedLoc, as a prefix to the SSA ID. +MLIR_CAPI_EXPORTED void +mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags); + /// Use local scope when printing the operation. 
This allows for using the /// printer in a more localized and thread-safe setting, but may not /// necessarily be identical to what the IR will look like when dumping diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index b13a429d4..12793f7dd 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1291,8 +1291,9 @@ void PyOperation::checkValid() const { void PyOperationBase::print(std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified, nb::object fileObject, - bool binary, bool skipRegions) { + bool useNameLocAsPrefix, bool assumeVerified, + nb::object fileObject, bool binary, + bool skipRegions) { PyOperation &operation = getOperation(); operation.checkValid(); if (fileObject.is_none()) @@ -1314,6 +1315,8 @@ void PyOperationBase::print(std::optional largeElementsLimit, mlirOpPrintingFlagsAssumeVerified(flags); if (skipRegions) mlirOpPrintingFlagsSkipRegions(flags); + if (useNameLocAsPrefix) + mlirOpPrintingFlagsPrintNameLocAsPrefix(flags); PyFileAccumulator accum(fileObject, binary); mlirOperationPrintWithFlags(operation, flags, accum.getCallback(), @@ -1390,7 +1393,8 @@ nb::object PyOperationBase::getAsm(bool binary, std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified, bool skipRegions) { + bool useNameLocAsPrefix, bool assumeVerified, + bool skipRegions) { nb::object fileObject; if (binary) { fileObject = nb::module_::import_("io").attr("BytesIO")(); @@ -1402,6 +1406,7 @@ nb::object PyOperationBase::getAsm(bool binary, /*prettyDebugInfo=*/prettyDebugInfo, /*printGenericOpForm=*/printGenericOpForm, /*useLocalScope=*/useLocalScope, + /*useNameLocAsPrefix=*/useNameLocAsPrefix, /*assumeVerified=*/assumeVerified, /*fileObject=*/fileObject, /*binary=*/binary, @@ -3195,6 +3200,7 @@ void 
mlir::python::populateIRCore(nb::module_ &m) { /*prettyDebugInfo=*/false, /*printGenericOpForm=*/false, /*useLocalScope=*/false, + /*useNameLocAsPrefix=*/false, /*assumeVerified=*/false, /*skipRegions=*/false); }, @@ -3206,7 +3212,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("binary") = false, kOperationPrintStateDocstring) .def("print", nb::overload_cast, bool, bool, bool, bool, - bool, nb::object, bool, bool>( + bool, bool, nb::object, bool, bool>( &PyOperationBase::print), // Careful: Lots of arguments must match up with print method. nb::arg("large_elements_limit").none() = nb::none(), @@ -3214,6 +3220,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("pretty_debug_info") = false, nb::arg("print_generic_op_form") = false, nb::arg("use_local_scope") = false, + nb::arg("use_name_loc_as_prefix") = false, nb::arg("assume_verified") = false, nb::arg("file").none() = nb::none(), nb::arg("binary") = false, nb::arg("skip_regions") = false, kOperationPrintDocstring) @@ -3228,6 +3235,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("pretty_debug_info") = false, nb::arg("print_generic_op_form") = false, nb::arg("use_local_scope") = false, + nb::arg("use_name_loc_as_prefix") = false, nb::arg("assume_verified") = false, nb::arg("skip_regions") = false, kOperationGetAsmDocstring) .def("verify", &PyOperationBase::verify, diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index dd6e7ef91..1ed6240a6 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -576,15 +576,16 @@ class PyOperationBase { /// Implements the bound 'print' method and helps with others. 
void print(std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified, nanobind::object fileObject, bool binary, - bool skipRegions); + bool useNameLocAsPrefix, bool assumeVerified, + nanobind::object fileObject, bool binary, bool skipRegions); void print(PyAsmState &state, nanobind::object fileObject, bool binary); nanobind::object getAsm(bool binary, std::optional largeElementsLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, - bool assumeVerified, bool skipRegions); + bool useNameLocAsPrefix, bool assumeVerified, + bool skipRegions); // Implement the bound 'writeBytecode' method. void writeBytecode(const nanobind::object &fileObject, diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 999e8cbda..6cd9ba2ae 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -218,6 +218,10 @@ void mlirOpPrintingFlagsPrintGenericOpForm(MlirOpPrintingFlags flags) { unwrap(flags)->printGenericOpForm(); } +void mlirOpPrintingFlagsPrintNameLocAsPrefix(MlirOpPrintingFlags flags) { + unwrap(flags)->printNameLocAsPrefix(); +} + void mlirOpPrintingFlagsUseLocalScope(MlirOpPrintingFlags flags) { unwrap(flags)->useLocalScope(); } From fdf6e4fa2a19cdf308d904dabec09932b8dddf71 Mon Sep 17 00:00:00 2001 From: lonely eagle <2020382038@qq.com> Date: Fri, 7 Mar 2025 13:00:05 +0800 Subject: [PATCH 858/915] [mlir][nvgpu] separate ops, types, attribute definitions in NVGPU dialect. (#129846) It is hoped that the Ops, Types, and Attribute of the NVGPU dialect can be defined in separate files.If downstream projects extend NVGPU and define other Ops, the types and attributes will be used.This PR was raised to avoid including the definition of NVGPU Ops. 
--- mlir/python/mlir/dialects/NVGPUOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/NVGPUOps.td b/mlir/python/mlir/dialects/NVGPUOps.td index ae54822cd..cdf651901 100644 --- a/mlir/python/mlir/dialects/NVGPUOps.td +++ b/mlir/python/mlir/dialects/NVGPUOps.td @@ -9,6 +9,6 @@ #ifndef PYTHON_BINDINGS_NVGPU_OPS #define PYTHON_BINDINGS_NVGPU_OPS -include "mlir/Dialect/NVGPU/IR/NVGPU.td" +include "mlir/Dialect/NVGPU/IR/NVGPUOps.td" #endif From 131aa875a682f7883b23cdbcc10c430261a5ef5b Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 10 Mar 2025 05:10:34 -0400 Subject: [PATCH 859/915] [mlir][CAPI][python] bind CallSiteLoc, FileLineColRange, FusedLoc, NameLoc (#129351) This PR extends the python bindings for CallSiteLoc, FileLineColRange, FusedLoc, NameLoc with field accessors. It also adds the missing `value.location` accessor. I also did some "spring cleaning" here (`cast` -> `dyn_cast`) after running into some of my own illegal casts. --- mlir/include/mlir-c/IR.h | 80 ++++++++++++ .../mlir/Bindings/Python/NanobindAdaptors.h | 10 ++ mlir/lib/Bindings/Python/IRCore.cpp | 47 +++++++- mlir/lib/CAPI/IR/IR.cpp | 114 +++++++++++++++++- 4 files changed, 239 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index d562da1f9..7fd6a41fb 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -261,15 +261,75 @@ MLIR_CAPI_EXPORTED MlirLocation mlirLocationFileLineColRangeGet( MlirContext context, MlirStringRef filename, unsigned start_line, unsigned start_col, unsigned end_line, unsigned end_col); +/// Getter for filename of FileLineColRange. +MLIR_CAPI_EXPORTED MlirIdentifier +mlirLocationFileLineColRangeGetFilename(MlirLocation location); + +/// Getter for start_line of FileLineColRange. +MLIR_CAPI_EXPORTED int +mlirLocationFileLineColRangeGetStartLine(MlirLocation location); + +/// Getter for start_column of FileLineColRange. 
+MLIR_CAPI_EXPORTED int +mlirLocationFileLineColRangeGetStartColumn(MlirLocation location); + +/// Getter for end_line of FileLineColRange. +MLIR_CAPI_EXPORTED int +mlirLocationFileLineColRangeGetEndLine(MlirLocation location); + +/// Getter for end_column of FileLineColRange. +MLIR_CAPI_EXPORTED int +mlirLocationFileLineColRangeGetEndColumn(MlirLocation location); + +/// TypeID Getter for FileLineColRange. +MLIR_CAPI_EXPORTED MlirTypeID mlirLocationFileLineColRangeGetTypeID(void); + +/// Checks whether the given location is an FileLineColRange. +MLIR_CAPI_EXPORTED bool mlirLocationIsAFileLineColRange(MlirLocation location); + /// Creates a call site location with a callee and a caller. MLIR_CAPI_EXPORTED MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller); +/// Getter for callee of CallSite. +MLIR_CAPI_EXPORTED MlirLocation +mlirLocationCallSiteGetCallee(MlirLocation location); + +/// Getter for caller of CallSite. +MLIR_CAPI_EXPORTED MlirLocation +mlirLocationCallSiteGetCaller(MlirLocation location); + +/// TypeID Getter for CallSite. +MLIR_CAPI_EXPORTED MlirTypeID mlirLocationCallSiteGetTypeID(void); + +/// Checks whether the given location is an CallSite. +MLIR_CAPI_EXPORTED bool mlirLocationIsACallSite(MlirLocation location); + /// Creates a fused location with an array of locations and metadata. MLIR_CAPI_EXPORTED MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, MlirLocation const *locations, MlirAttribute metadata); +/// Getter for number of locations fused together. +MLIR_CAPI_EXPORTED unsigned +mlirLocationFusedGetNumLocations(MlirLocation location); + +/// Getter for locations of Fused. Requires pre-allocated memory of +/// #fusedLocations X sizeof(MlirLocation). +MLIR_CAPI_EXPORTED void +mlirLocationFusedGetLocations(MlirLocation location, + MlirLocation *locationsCPtr); + +/// Getter for metadata of Fused. 
+MLIR_CAPI_EXPORTED MlirAttribute +mlirLocationFusedGetMetadata(MlirLocation location); + +/// TypeID Getter for Fused. +MLIR_CAPI_EXPORTED MlirTypeID mlirLocationFusedGetTypeID(void); + +/// Checks whether the given location is an Fused. +MLIR_CAPI_EXPORTED bool mlirLocationIsAFused(MlirLocation location); + /// Creates a name location owned by the given context. Providing null location /// for childLoc is allowed and if childLoc is null location, then the behavior /// is the same as having unknown child location. @@ -277,6 +337,20 @@ MLIR_CAPI_EXPORTED MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, MlirLocation childLoc); +/// Getter for name of Name. +MLIR_CAPI_EXPORTED MlirIdentifier +mlirLocationNameGetName(MlirLocation location); + +/// Getter for childLoc of Name. +MLIR_CAPI_EXPORTED MlirLocation +mlirLocationNameGetChildLoc(MlirLocation location); + +/// TypeID Getter for Name. +MLIR_CAPI_EXPORTED MlirTypeID mlirLocationNameGetTypeID(void); + +/// Checks whether the given location is an Name. +MLIR_CAPI_EXPORTED bool mlirLocationIsAName(MlirLocation location); + /// Creates a location with unknown position owned by the given context. MLIR_CAPI_EXPORTED MlirLocation mlirLocationUnknownGet(MlirContext context); @@ -978,6 +1052,12 @@ mlirValueReplaceAllUsesExcept(MlirValue of, MlirValue with, intptr_t numExceptions, MlirOperation *exceptions); +/// Gets the location of the value. +MLIR_CAPI_EXPORTED MlirLocation mlirValueGetLocation(MlirValue v); + +/// Gets the context that a value was created with. +MLIR_CAPI_EXPORTED MlirContext mlirValueGetContext(MlirValue v); + //===----------------------------------------------------------------------===// // OpOperand API. 
//===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h index 0608182f0..3646bf42e 100644 --- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h @@ -321,6 +321,16 @@ struct type_caster { } }; +/// Casts MlirStringRef -> object. +template <> +struct type_caster { + NB_TYPE_CASTER(MlirStringRef, const_name("MlirStringRef")) + static handle from_cpp(MlirStringRef s, rv_policy, + cleanup_list *cleanup) noexcept { + return nanobind::str(s.data, s.length).release(); + } +}; + } // namespace detail } // namespace nanobind diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 12793f7dd..9fd061d1c 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2943,6 +2943,9 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("callee"), nb::arg("frames"), nb::arg("context").none() = nb::none(), kContextGetCallSiteLocationDocstring) + .def("is_a_callsite", mlirLocationIsACallSite) + .def_prop_ro("callee", mlirLocationCallSiteGetCallee) + .def_prop_ro("caller", mlirLocationCallSiteGetCaller) .def_static( "file", [](std::string filename, int line, int col, @@ -2967,6 +2970,16 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("filename"), nb::arg("start_line"), nb::arg("start_col"), nb::arg("end_line"), nb::arg("end_col"), nb::arg("context").none() = nb::none(), kContextGetFileRangeDocstring) + .def("is_a_file", mlirLocationIsAFileLineColRange) + .def_prop_ro("filename", + [](MlirLocation loc) { + return mlirIdentifierStr( + mlirLocationFileLineColRangeGetFilename(loc)); + }) + .def_prop_ro("start_line", mlirLocationFileLineColRangeGetStartLine) + .def_prop_ro("start_col", mlirLocationFileLineColRangeGetStartColumn) + .def_prop_ro("end_line", 
mlirLocationFileLineColRangeGetEndLine) + .def_prop_ro("end_col", mlirLocationFileLineColRangeGetEndColumn) .def_static( "fused", [](const std::vector &pyLocations, @@ -2984,6 +2997,16 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("locations"), nb::arg("metadata").none() = nb::none(), nb::arg("context").none() = nb::none(), kContextGetFusedLocationDocstring) + .def("is_a_fused", mlirLocationIsAFused) + .def_prop_ro("locations", + [](MlirLocation loc) { + unsigned numLocations = + mlirLocationFusedGetNumLocations(loc); + std::vector locations(numLocations); + if (numLocations) + mlirLocationFusedGetLocations(loc, locations.data()); + return locations; + }) .def_static( "name", [](std::string name, std::optional childLoc, @@ -2998,6 +3021,12 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("name"), nb::arg("childLoc").none() = nb::none(), nb::arg("context").none() = nb::none(), kContextGetNameLocationDocString) + .def("is_a_name", mlirLocationIsAName) + .def_prop_ro("name_str", + [](MlirLocation loc) { + return mlirIdentifierStr(mlirLocationNameGetName(loc)); + }) + .def_prop_ro("child_loc", mlirLocationNameGetChildLoc) .def_static( "from_attr", [](PyAttribute &attribute, DefaultingPyMlirContext context) { @@ -3148,9 +3177,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { auto &concreteOperation = self.getOperation(); concreteOperation.checkValid(); MlirOperation operation = concreteOperation.get(); - MlirStringRef name = - mlirIdentifierStr(mlirOperationGetName(operation)); - return nb::str(name.data, name.length); + return mlirIdentifierStr(mlirOperationGetName(operation)); }) .def_prop_ro("operands", [](PyOperationBase &self) { @@ -3738,8 +3765,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { .def_prop_ro( "name", [](PyNamedAttribute &self) { - return nb::str(mlirIdentifierStr(self.namedAttr.name).data, - mlirIdentifierStr(self.namedAttr.name).length); + return mlirIdentifierStr(self.namedAttr.name); }, "The name of 
the NamedAttribute binding") .def_prop_ro( @@ -3972,7 +3998,16 @@ void mlir::python::populateIRCore(nb::module_ &m) { nb::arg("with"), nb::arg("exceptions"), kValueReplaceAllUsesExceptDocstring) .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, - [](PyValue &self) { return self.maybeDownCast(); }); + [](PyValue &self) { return self.maybeDownCast(); }) + .def_prop_ro( + "location", + [](MlirValue self) { + return PyLocation( + PyMlirContext::forContext(mlirValueGetContext(self)), + mlirValueGetLocation(self)); + }, + "Returns the source location the value"); + PyBlockArgument::bind(m); PyOpResult::bind(m); PyOpOperand::bind(m); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 6cd9ba2ae..378d7d739 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -259,7 +259,7 @@ MlirAttribute mlirLocationGetAttribute(MlirLocation location) { } MlirLocation mlirLocationFromAttribute(MlirAttribute attribute) { - return wrap(Location(llvm::cast(unwrap(attribute)))); + return wrap(Location(llvm::dyn_cast(unwrap(attribute)))); } MlirLocation mlirLocationFileLineColGet(MlirContext context, @@ -278,10 +278,64 @@ mlirLocationFileLineColRangeGet(MlirContext context, MlirStringRef filename, startLine, startCol, endLine, endCol))); } +MlirIdentifier mlirLocationFileLineColRangeGetFilename(MlirLocation location) { + return wrap(llvm::dyn_cast(unwrap(location)).getFilename()); +} + +int mlirLocationFileLineColRangeGetStartLine(MlirLocation location) { + if (auto loc = llvm::dyn_cast(unwrap(location))) + return loc.getStartLine(); + return -1; +} + +int mlirLocationFileLineColRangeGetStartColumn(MlirLocation location) { + if (auto loc = llvm::dyn_cast(unwrap(location))) + return loc.getStartColumn(); + return -1; +} + +int mlirLocationFileLineColRangeGetEndLine(MlirLocation location) { + if (auto loc = llvm::dyn_cast(unwrap(location))) + return loc.getEndLine(); + return -1; +} + +int mlirLocationFileLineColRangeGetEndColumn(MlirLocation location) { + if (auto loc = 
llvm::dyn_cast(unwrap(location))) + return loc.getEndColumn(); + return -1; +} + +MlirTypeID mlirLocationFileLineColRangeGetTypeID() { + return wrap(FileLineColRange::getTypeID()); +} + +bool mlirLocationIsAFileLineColRange(MlirLocation location) { + return isa(unwrap(location)); +} + MlirLocation mlirLocationCallSiteGet(MlirLocation callee, MlirLocation caller) { return wrap(Location(CallSiteLoc::get(unwrap(callee), unwrap(caller)))); } +MlirLocation mlirLocationCallSiteGetCallee(MlirLocation location) { + return wrap( + Location(llvm::dyn_cast(unwrap(location)).getCallee())); +} + +MlirLocation mlirLocationCallSiteGetCaller(MlirLocation location) { + return wrap( + Location(llvm::dyn_cast(unwrap(location)).getCaller())); +} + +MlirTypeID mlirLocationCallSiteGetTypeID() { + return wrap(CallSiteLoc::getTypeID()); +} + +bool mlirLocationIsACallSite(MlirLocation location) { + return isa(unwrap(location)); +} + MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, MlirLocation const *locations, MlirAttribute metadata) { @@ -290,6 +344,30 @@ MlirLocation mlirLocationFusedGet(MlirContext ctx, intptr_t nLocations, return wrap(FusedLoc::get(unwrappedLocs, unwrap(metadata), unwrap(ctx))); } +unsigned mlirLocationFusedGetNumLocations(MlirLocation location) { + if (auto locationsArrRef = llvm::dyn_cast(unwrap(location))) + return locationsArrRef.getLocations().size(); + return 0; +} + +void mlirLocationFusedGetLocations(MlirLocation location, + MlirLocation *locationsCPtr) { + if (auto locationsArrRef = llvm::dyn_cast(unwrap(location))) { + for (auto [i, location] : llvm::enumerate(locationsArrRef.getLocations())) + locationsCPtr[i] = wrap(location); + } +} + +MlirAttribute mlirLocationFusedGetMetadata(MlirLocation location) { + return wrap(llvm::dyn_cast(unwrap(location)).getMetadata()); +} + +MlirTypeID mlirLocationFusedGetTypeID() { return wrap(FusedLoc::getTypeID()); } + +bool mlirLocationIsAFused(MlirLocation location) { + return isa(unwrap(location)); 
+} + MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, MlirLocation childLoc) { if (mlirLocationIsNull(childLoc)) @@ -299,6 +377,21 @@ MlirLocation mlirLocationNameGet(MlirContext context, MlirStringRef name, StringAttr::get(unwrap(context), unwrap(name)), unwrap(childLoc)))); } +MlirIdentifier mlirLocationNameGetName(MlirLocation location) { + return wrap((llvm::dyn_cast(unwrap(location)).getName())); +} + +MlirLocation mlirLocationNameGetChildLoc(MlirLocation location) { + return wrap( + Location(llvm::dyn_cast(unwrap(location)).getChildLoc())); +} + +MlirTypeID mlirLocationNameGetTypeID() { return wrap(NameLoc::getTypeID()); } + +bool mlirLocationIsAName(MlirLocation location) { + return isa(unwrap(location)); +} + MlirLocation mlirLocationUnknownGet(MlirContext context) { return wrap(Location(UnknownLoc::get(unwrap(context)))); } @@ -975,25 +1068,26 @@ bool mlirValueIsAOpResult(MlirValue value) { } MlirBlock mlirBlockArgumentGetOwner(MlirValue value) { - return wrap(llvm::cast(unwrap(value)).getOwner()); + return wrap(llvm::dyn_cast(unwrap(value)).getOwner()); } intptr_t mlirBlockArgumentGetArgNumber(MlirValue value) { return static_cast( - llvm::cast(unwrap(value)).getArgNumber()); + llvm::dyn_cast(unwrap(value)).getArgNumber()); } void mlirBlockArgumentSetType(MlirValue value, MlirType type) { - llvm::cast(unwrap(value)).setType(unwrap(type)); + if (auto blockArg = llvm::dyn_cast(unwrap(value))) + blockArg.setType(unwrap(type)); } MlirOperation mlirOpResultGetOwner(MlirValue value) { - return wrap(llvm::cast(unwrap(value)).getOwner()); + return wrap(llvm::dyn_cast(unwrap(value)).getOwner()); } intptr_t mlirOpResultGetResultNumber(MlirValue value) { return static_cast( - llvm::cast(unwrap(value)).getResultNumber()); + llvm::dyn_cast(unwrap(value)).getResultNumber()); } MlirType mlirValueGetType(MlirValue value) { @@ -1047,6 +1141,14 @@ void mlirValueReplaceAllUsesExcept(MlirValue oldValue, MlirValue newValue, 
oldValueCpp.replaceAllUsesExcept(newValueCpp, exceptionSet); } +MlirLocation mlirValueGetLocation(MlirValue v) { + return wrap(unwrap(v).getLoc()); +} + +MlirContext mlirValueGetContext(MlirValue v) { + return wrap(unwrap(v).getContext()); +} + //===----------------------------------------------------------------------===// // OpOperand API. //===----------------------------------------------------------------------===// From 28358265fc16157e8641b2b4108306da2a15e899 Mon Sep 17 00:00:00 2001 From: vfdev Date: Mon, 10 Mar 2025 11:19:23 +0100 Subject: [PATCH 860/915] [MLIR][py] Add PyThreadPool as wrapper around MlirLlvmThreadPool in MLIR python bindings (#130109) In some projects like JAX ir.Context are used with disabled multi-threading to avoid caching multiple threading pools: https://github.com/jax-ml/jax/blob/623865fe9538100d877ba9d36f788d0f95a11ed2/jax/_src/interpreters/mlir.py#L606-L611 However, when context has enabled multithreading it also uses locks on the StorageUniquers and this can be helpful to avoid data races in the multi-threaded execution (for example with free-threaded cpython, https://github.com/jax-ml/jax/issues/26272). With this PR user can enable the multi-threading: 1) enables additional locking and 2) set a shared threading pool such that cached contexts can have one global pool. 
--- mlir/include/mlir-c/IR.h | 9 +++++++++ mlir/lib/Bindings/Python/IRCore.cpp | 26 ++++++++++++++++++++++++ mlir/lib/Bindings/Python/IRModule.h | 27 ++++++++++++++++++++++++- mlir/lib/CAPI/IR/IR.cpp | 8 ++++++++ mlir/python/mlir/_mlir_libs/__init__.py | 16 +++++++++++++-- 5 files changed, 83 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 7fd6a41fb..1a8e8737f 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -162,6 +162,15 @@ MLIR_CAPI_EXPORTED bool mlirContextIsRegisteredOperation(MlirContext context, MLIR_CAPI_EXPORTED void mlirContextSetThreadPool(MlirContext context, MlirLlvmThreadPool threadPool); +/// Gets the number of threads of the thread pool of the context when +/// multithreading is enabled. Returns 1 if no multithreading. +MLIR_CAPI_EXPORTED unsigned mlirContextGetNumThreads(MlirContext context); + +/// Gets the thread pool of the context when enabled multithreading, otherwise +/// an assertion is raised. +MLIR_CAPI_EXPORTED MlirLlvmThreadPool +mlirContextGetThreadPool(MlirContext context); + //===----------------------------------------------------------------------===// // Dialect API. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 9fd061d1c..78ba144ac 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2743,6 +2743,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { // __init__.py will subclass it with site-specific functionality and set a // "Context" attribute on this module. 
//---------------------------------------------------------------------------- + + // Expose DefaultThreadPool to python + nb::class_(m, "ThreadPool") + .def("__init__", [](PyThreadPool &self) { new (&self) PyThreadPool(); }) + .def("get_max_concurrency", &PyThreadPool::getMaxConcurrency) + .def("_mlir_thread_pool_ptr", &PyThreadPool::_mlir_thread_pool_ptr); + nb::class_(m, "_BaseContext") .def("__init__", [](PyMlirContext &self) { @@ -2814,6 +2821,25 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirContextEnableMultithreading(self.get(), enable); }, nb::arg("enable")) + .def("set_thread_pool", + [](PyMlirContext &self, PyThreadPool &pool) { + // we should disable multi-threading first before setting + // new thread pool otherwise the assert in + // MLIRContext::setThreadPool will be raised. + mlirContextEnableMultithreading(self.get(), false); + mlirContextSetThreadPool(self.get(), pool.get()); + }) + .def("get_num_threads", + [](PyMlirContext &self) { + return mlirContextGetNumThreads(self.get()); + }) + .def("_mlir_thread_pool_ptr", + [](PyMlirContext &self) { + MlirLlvmThreadPool pool = mlirContextGetThreadPool(self.get()); + std::stringstream ss; + ss << pool.ptr; + return ss.str(); + }) .def( "is_registered_operation", [](PyMlirContext &self, std::string &name) { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 1ed6240a6..9befcce72 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -11,6 +11,7 @@ #define MLIR_BINDINGS_PYTHON_IRMODULES_H #include +#include #include #include @@ -22,9 +23,10 @@ #include "mlir-c/IR.h" #include "mlir-c/IntegerSet.h" #include "mlir-c/Transforms.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/Support/ThreadPool.h" namespace mlir { namespace python { @@ -158,6 +160,29 @@ class 
PyThreadContextEntry { FrameKind frameKind; }; +/// Wrapper around MlirLlvmThreadPool +/// Python object owns the C++ thread pool +class PyThreadPool { +public: + PyThreadPool() { + ownedThreadPool = std::make_unique(); + } + PyThreadPool(const PyThreadPool &) = delete; + PyThreadPool(PyThreadPool &&) = delete; + + int getMaxConcurrency() const { return ownedThreadPool->getMaxConcurrency(); } + MlirLlvmThreadPool get() { return wrap(ownedThreadPool.get()); } + + std::string _mlir_thread_pool_ptr() const { + std::stringstream ss; + ss << ownedThreadPool.get(); + return ss.str(); + } + +private: + std::unique_ptr ownedThreadPool; +}; + /// Wrapper around MlirContext. using PyMlirContextRef = PyObjectRef; class PyMlirContext { diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 378d7d739..e0e386d55 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -114,6 +114,14 @@ void mlirContextSetThreadPool(MlirContext context, unwrap(context)->setThreadPool(*unwrap(threadPool)); } +unsigned mlirContextGetNumThreads(MlirContext context) { + return unwrap(context)->getNumThreads(); +} + +MlirLlvmThreadPool mlirContextGetThreadPool(MlirContext context) { + return wrap(&unwrap(context)->getThreadPool()); +} + //===----------------------------------------------------------------------===// // Dialect API. 
//===----------------------------------------------------------------------===// diff --git a/mlir/python/mlir/_mlir_libs/__init__.py b/mlir/python/mlir/_mlir_libs/__init__.py index d021dde05..083a9075f 100644 --- a/mlir/python/mlir/_mlir_libs/__init__.py +++ b/mlir/python/mlir/_mlir_libs/__init__.py @@ -148,13 +148,25 @@ def process_initializer_module(module_name): break class Context(ir._BaseContext): - def __init__(self, load_on_create_dialects=None, *args, **kwargs): + def __init__( + self, load_on_create_dialects=None, thread_pool=None, *args, **kwargs + ): super().__init__(*args, **kwargs) self.append_dialect_registry(get_dialect_registry()) for hook in post_init_hooks: hook(self) + if disable_multithreading and thread_pool is not None: + raise ValueError( + "Context constructor has given thread_pool argument, " + "but disable_multithreading flag is True. " + "Please, set thread_pool argument to None or " + "set disable_multithreading flag to False." + ) if not disable_multithreading: - self.enable_multithreading(True) + if thread_pool is None: + self.enable_multithreading(True) + else: + self.set_thread_pool(thread_pool) if load_on_create_dialects is not None: logger.debug( "Loading all dialects from load_on_create_dialects arg %r", From 2b13af71fb11fcfa277ad32d5dc5b0af6470729b Mon Sep 17 00:00:00 2001 From: Nikhil Kalra Date: Mon, 10 Mar 2025 15:59:47 -0700 Subject: [PATCH 861/915] [mlir] Better Python diagnostics (#128581) Updated the Python diagnostics handler to emit notes (in addition to errors) into the output stream so that users have more context as to where in the IR the error is occurring. 
--- .../mlir/Bindings/Python/Diagnostics.h | 39 ++++++++++++------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/Diagnostics.h b/mlir/include/mlir/Bindings/Python/Diagnostics.h index ea80e14dd..167002d56 100644 --- a/mlir/include/mlir/Bindings/Python/Diagnostics.h +++ b/mlir/include/mlir/Bindings/Python/Diagnostics.h @@ -9,12 +9,13 @@ #ifndef MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H #define MLIR_BINDINGS_PYTHON_DIAGNOSTICS_H -#include -#include - #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" -#include "llvm/ADT/StringRef.h" +#include "llvm/Support/raw_ostream.h" + +#include +#include +#include namespace mlir { namespace python { @@ -24,33 +25,45 @@ namespace python { class CollectDiagnosticsToStringScope { public: explicit CollectDiagnosticsToStringScope(MlirContext ctx) : context(ctx) { - handlerID = mlirContextAttachDiagnosticHandler(ctx, &handler, &errorMessage, - /*deleteUserData=*/nullptr); + handlerID = + mlirContextAttachDiagnosticHandler(ctx, &handler, &messageStream, + /*deleteUserData=*/nullptr); } ~CollectDiagnosticsToStringScope() { - assert(errorMessage.empty() && "unchecked error message"); + assert(message.empty() && "unchecked error message"); mlirContextDetachDiagnosticHandler(context, handlerID); } - [[nodiscard]] std::string takeMessage() { return std::move(errorMessage); } + [[nodiscard]] std::string takeMessage() { + std::string newMessage; + std::swap(message, newMessage); + return newMessage; + } private: static MlirLogicalResult handler(MlirDiagnostic diag, void *data) { auto printer = +[](MlirStringRef message, void *data) { - *static_cast(data) += - llvm::StringRef(message.data, message.length); + *static_cast(data) + << std::string_view(message.data, message.length); }; MlirLocation loc = mlirDiagnosticGetLocation(diag); - *static_cast(data) += "at "; + *static_cast(data) << "at "; mlirLocationPrint(loc, printer, data); - *static_cast(data) += ": "; + *static_cast(data) << ": "; 
mlirDiagnosticPrint(diag, printer, data); + for (intptr_t i = 0; i < mlirDiagnosticGetNumNotes(diag); i++) { + *static_cast(data) << "\n"; + MlirDiagnostic note = mlirDiagnosticGetNote(diag, i); + handler(note, data); + } return mlirLogicalResultSuccess(); } MlirContext context; MlirDiagnosticHandlerID handlerID; - std::string errorMessage = ""; + + std::string message; + llvm::raw_string_ostream messageStream{message}; }; } // namespace python From 084572663dd81afddd92be07622d01601ae54d28 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 14 Mar 2025 11:10:42 -0400 Subject: [PATCH 862/915] [mlir][python] Small optimization to mlirApiObjectToCapsule. (#131160) Call nb::getattr(...) rather than using nb::hasattr() and .attr(). Saves a Python string allocation and a dictionary lookup when using a recent nanobind. Optimization only, no changes in behavior expected. --- mlir/include/mlir/Bindings/Python/NanobindAdaptors.h | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h index 3646bf42e..2dd35c097 100644 --- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h @@ -46,14 +46,16 @@ namespace detail { static nanobind::object mlirApiObjectToCapsule(nanobind::handle apiObject) { if (PyCapsule_CheckExact(apiObject.ptr())) return nanobind::borrow(apiObject); - if (!nanobind::hasattr(apiObject, MLIR_PYTHON_CAPI_PTR_ATTR)) { + nanobind::object api = + nanobind::getattr(apiObject, MLIR_PYTHON_CAPI_PTR_ATTR, nanobind::none()); + if (api.is_none()) { std::string repr = nanobind::cast(nanobind::repr(apiObject)); throw nanobind::type_error( (llvm::Twine("Expected an MLIR object (got ") + repr + ").") .str() .c_str()); } - return apiObject.attr(MLIR_PYTHON_CAPI_PTR_ATTR); + return api; } // Note: Currently all of the following support cast from nanobind::object to From 
febaf3d5717ca709198041d3973c3d443d52b27c Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sun, 16 Mar 2025 17:57:56 +0100 Subject: [PATCH 863/915] [mlir] Expose `AffineExpr.shift_dims/shift_symbols` through C and Python bindings (#131521) --- mlir/include/mlir-c/AffineExpr.h | 12 ++++++++++++ mlir/lib/Bindings/Python/IRAffine.cpp | 19 +++++++++++++++++++ mlir/lib/CAPI/IR/AffineExpr.cpp | 12 ++++++++++++ 3 files changed, 43 insertions(+) diff --git a/mlir/include/mlir-c/AffineExpr.h b/mlir/include/mlir-c/AffineExpr.h index 14e951dde..ab768eb2e 100644 --- a/mlir/include/mlir-c/AffineExpr.h +++ b/mlir/include/mlir-c/AffineExpr.h @@ -92,6 +92,18 @@ MLIR_CAPI_EXPORTED bool mlirAffineExprIsFunctionOfDim(MlirAffineExpr affineExpr, MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineExprCompose( MlirAffineExpr affineExpr, struct MlirAffineMap affineMap); +/// Replace dims[offset ... numDims) +/// by dims[offset + shift ... shift + numDims). +MLIR_CAPI_EXPORTED MlirAffineExpr +mlirAffineExprShiftDims(MlirAffineExpr affineExpr, uint32_t numDims, + uint32_t shift, uint32_t offset); + +/// Replace symbols[offset ... numSymbols) +/// by symbols[offset + shift ... shift + numSymbols). +MLIR_CAPI_EXPORTED MlirAffineExpr +mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr, uint32_t numSymbols, + uint32_t shift, uint32_t offset); + //===----------------------------------------------------------------------===// // Affine Dimension Expression. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index a2df824f5..3c95d29c4 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -580,6 +580,25 @@ void mlir::python::populateIRAffine(nb::module_ &m) { return PyAffineExpr(self.getContext(), mlirAffineExprCompose(self, other)); }) + .def( + "shift_dims", + [](PyAffineExpr &self, uint32_t numDims, uint32_t shift, + uint32_t offset) { + return PyAffineExpr( + self.getContext(), + mlirAffineExprShiftDims(self, numDims, shift, offset)); + }, + nb::arg("num_dims"), nb::arg("shift"), nb::arg("offset").none() = 0) + .def( + "shift_symbols", + [](PyAffineExpr &self, uint32_t numSymbols, uint32_t shift, + uint32_t offset) { + return PyAffineExpr( + self.getContext(), + mlirAffineExprShiftSymbols(self, numSymbols, shift, offset)); + }, + nb::arg("num_symbols"), nb::arg("shift"), + nb::arg("offset").none() = 0) .def_static( "get_add", &PyAffineAddExpr::get, "Gets an affine expression containing a sum of two expressions.") diff --git a/mlir/lib/CAPI/IR/AffineExpr.cpp b/mlir/lib/CAPI/IR/AffineExpr.cpp index 6e3328b65..bc3dcd417 100644 --- a/mlir/lib/CAPI/IR/AffineExpr.cpp +++ b/mlir/lib/CAPI/IR/AffineExpr.cpp @@ -61,6 +61,18 @@ MlirAffineExpr mlirAffineExprCompose(MlirAffineExpr affineExpr, return wrap(unwrap(affineExpr).compose(unwrap(affineMap))); } +MlirAffineExpr mlirAffineExprShiftDims(MlirAffineExpr affineExpr, + uint32_t numDims, uint32_t shift, + uint32_t offset) { + return wrap(unwrap(affineExpr).shiftDims(numDims, shift, offset)); +} + +MlirAffineExpr mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr, + uint32_t numSymbols, uint32_t shift, + uint32_t offset) { + return wrap(unwrap(affineExpr).shiftSymbols(numSymbols, shift, offset)); +} + //===----------------------------------------------------------------------===// // Affine Dimension Expression. 
//===----------------------------------------------------------------------===// From 45539c624ee693aee2a37fee971e9866e3dcc6f6 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev <185856+superbobry@users.noreply.github.com> Date: Fri, 21 Mar 2025 04:13:13 +0000 Subject: [PATCH 864/915] [MLIR] [python] A few improvements to the Python bindings (#131686) * `PyRegionList` is now sliceable. The dialect bindings generator seems to assume it is sliceable already (!), yet accessing e.g. `cases` on `scf.IndexedSwitchOp` raises a `TypeError` at runtime. * `PyBlockList` and `PyOperationList` support negative indexing. It is common for containers to do that in Python, and most container in the MLIR Python bindings already allow the index to be negative. --- mlir/lib/Bindings/Python/IRCore.cpp | 49 ++++++++++++++++-------- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 3 ++ 2 files changed, 36 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 78ba144ac..5ffcf6717 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -361,37 +361,45 @@ class PyRegionIterator { /// Regions of an op are fixed length and indexed numerically so are represented /// with a sequence-like container. -class PyRegionList { +class PyRegionList : public Sliceable { public: - PyRegionList(PyOperationRef operation) : operation(std::move(operation)) {} + static constexpr const char *pyClassName = "RegionSequence"; + + PyRegionList(PyOperationRef operation, intptr_t startIndex = 0, + intptr_t length = -1, intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? 
mlirOperationGetNumRegions(operation->get()) + : length, + step), + operation(std::move(operation)) {} PyRegionIterator dunderIter() { operation->checkValid(); return PyRegionIterator(operation); } - intptr_t dunderLen() { + static void bindDerived(ClassTy &c) { + c.def("__iter__", &PyRegionList::dunderIter); + } + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + intptr_t getRawNumElements() { operation->checkValid(); return mlirOperationGetNumRegions(operation->get()); } - PyRegion dunderGetItem(intptr_t index) { - // dunderLen checks validity. - if (index < 0 || index >= dunderLen()) { - throw nb::index_error("attempt to access out of bounds region"); - } - MlirRegion region = mlirOperationGetRegion(operation->get(), index); - return PyRegion(operation, region); + PyRegion getRawElement(intptr_t pos) { + operation->checkValid(); + return PyRegion(operation, mlirOperationGetRegion(operation->get(), pos)); } - static void bind(nb::module_ &m) { - nb::class_(m, "RegionSequence") - .def("__len__", &PyRegionList::dunderLen) - .def("__iter__", &PyRegionList::dunderIter) - .def("__getitem__", &PyRegionList::dunderGetItem); + PyRegionList slice(intptr_t startIndex, intptr_t length, intptr_t step) { + return PyRegionList(operation, startIndex, length, step); } -private: PyOperationRef operation; }; @@ -450,6 +458,9 @@ class PyBlockList { PyBlock dunderGetItem(intptr_t index) { operation->checkValid(); + if (index < 0) { + index += dunderLen(); + } if (index < 0) { throw nb::index_error("attempt to access out of bounds block"); } @@ -546,6 +557,9 @@ class PyOperationList { nb::object dunderGetItem(intptr_t index) { parentOperation->checkValid(); + if (index < 0) { + index += dunderLen(); + } if (index < 0) { throw nb::index_error("attempt to access out of bounds operation"); } @@ -2629,6 +2643,9 @@ class PyOpAttributeMap { } PyNamedAttribute dunderGetItemIndexed(intptr_t index) { + if (index < 0) { + index += 
dunderLen(); + } if (index < 0 || index >= dunderLen()) { throw nb::index_error("attempt to access out of bounds attribute"); } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index c93de2fe3..c60ff72ff 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -2466,7 +2466,10 @@ class RegionIterator: def __next__(self) -> Region: ... class RegionSequence: + @overload def __getitem__(self, arg0: int) -> Region: ... + @overload + def __getitem__(self, arg0: slice) -> Sequence[Region]: ... def __iter__(self) -> RegionIterator: ... def __len__(self) -> int: ... From 8cc42ae53195a0a7f7fa1c28fdb0a4ba9e74d0b7 Mon Sep 17 00:00:00 2001 From: Sandeep Dasgupta Date: Sun, 23 Mar 2025 05:37:55 -0700 Subject: [PATCH 865/915] Sub-channel quantized type implementation (#120172) This is an implementation for [RFC: Supporting Sub-Channel Quantization in MLIR](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694). In order to make the review process easier, the PR has been divided into the following commit labels: 1. **Add implementation for sub-channel type:** Includes the class design for `UniformQuantizedSubChannelType`, printer/parser and bytecode read/write support. The existing types (per-tensor and per-axis) are unaltered. 2. **Add implementation for sub-channel type:** Lowering of `quant.qcast` and `quant.dcast` operations to Linalg operations. 3. **Adding C/Python Apis:** We first define he C-APIs and build the Python-APIs on top of those. 4. **Add pass to normalize generic ....:** This pass normalizes sub-channel quantized types to per-tensor per-axis types, if possible. 
A design note: - **Explicitly storing the `quantized_dimensions`, even when they can be derived for ranked tensor.** While it's possible to infer quantized dimensions from the static shape of the scales (or zero-points) tensor for ranked data tensors ([ref](https://discourse.llvm.org/t/rfc-supporting-sub-channel-quantization-in-mlir/82694/3) for background), there are cases where this can lead to ambiguity and issues with round-tripping. ``` Consider the example: tensor<2x4x!quant.uniform> ``` The shape of the scales tensor is [1, 2], which might suggest that only axis 1 is quantized. While this inference is technically correct, as the block size for axis 0 is a degenerate case (equal to the dimension size), it can cause problems with round-tripping. Therefore, even for ranked tensors, we are explicitly storing the quantized dimensions. Suggestions welcome! PS: I understand that the upcoming holidays may impact your schedule, so please take your time with the review. There's no rush. --- mlir/include/mlir-c/Dialect/Quant.h | 41 ++++++++++ mlir/lib/Bindings/Python/DialectQuant.cpp | 76 ++++++++++++++++++- mlir/lib/CAPI/Dialect/Quant.cpp | 56 ++++++++++++++ .../mlir/_mlir_libs/_mlir/dialects/quant.pyi | 22 +++++- 4 files changed, 193 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/Quant.h b/mlir/include/mlir-c/Dialect/Quant.h index a7d98dc3c..dc0989e53 100644 --- a/mlir/include/mlir-c/Dialect/Quant.h +++ b/mlir/include/mlir-c/Dialect/Quant.h @@ -172,6 +172,47 @@ mlirUniformQuantizedPerAxisTypeGetQuantizedDimension(MlirType type); MLIR_CAPI_EXPORTED bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type); +//===---------------------------------------------------------------------===// +// UniformQuantizedSubChannelType +//===---------------------------------------------------------------------===// + +/// Returns `true` if the given type is a UniformQuantizedSubChannel. 
+MLIR_CAPI_EXPORTED bool +mlirTypeIsAUniformQuantizedSubChannelType(MlirType type); + +/// Creates a UniformQuantizedSubChannelType with the given parameters. +/// +/// The type is owned by the context. `scalesAttr` and `zeroPointsAttr` must be +/// DenseElementsAttrs. `quantizedDimensions` and `blockSizes` +/// point to `blockSizeInfoLength` number of elements, describing respectively +/// the quantization axis and corresponding block size. +MLIR_CAPI_EXPORTED MlirType mlirUniformQuantizedSubChannelTypeGet( + unsigned flags, MlirType storageType, MlirType expressedType, + MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr, + intptr_t blockSizeInfoLength, int32_t *quantizedDimensions, + int64_t *blockSizes, int64_t storageTypeMin, int64_t storageTypeMax); + +/// Returns the number of block sizes provided in type. +MLIR_CAPI_EXPORTED intptr_t +mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(MlirType type); + +/// Returns the quantized dimension at the given position. +MLIR_CAPI_EXPORTED int32_t +mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(MlirType type, + intptr_t pos); + +/// Returns the block size at the given position. +MLIR_CAPI_EXPORTED int64_t +mlirUniformQuantizedSubChannelTypeGetBlockSize(MlirType type, intptr_t pos); + +/// Returns the scales of the quantized type. +MLIR_CAPI_EXPORTED MlirAttribute +mlirUniformQuantizedSubChannelTypeGetScales(MlirType type); + +/// Returns the zero-points of the quantized type. 
+MLIR_CAPI_EXPORTED MlirAttribute +mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type); + //===---------------------------------------------------------------------===// // CalibratedQuantizedType //===---------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/DialectQuant.cpp b/mlir/lib/Bindings/Python/DialectQuant.cpp index 29f19c9c5..55571cd1e 100644 --- a/mlir/lib/Bindings/Python/DialectQuant.cpp +++ b/mlir/lib/Bindings/Python/DialectQuant.cpp @@ -9,10 +9,11 @@ #include #include +#include "mlir-c/BuiltinAttributes.h" #include "mlir-c/Dialect/Quant.h" #include "mlir-c/IR.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; using namespace llvm; @@ -284,6 +285,79 @@ static void populateDialectQuantSubmodule(const nb::module_ &m) { }, "Fixed point values are real numbers divided by a scale."); + //===-------------------------------------------------------------------===// + // UniformQuantizedSubChannelType + //===-------------------------------------------------------------------===// + auto uniformQuantizedSubChannelType = mlir_type_subclass( + m, "UniformQuantizedSubChannelType", + mlirTypeIsAUniformQuantizedSubChannelType, quantizedType.get_class()); + uniformQuantizedSubChannelType.def_classmethod( + "get", + [](nb::object cls, unsigned flags, MlirType storageType, + MlirType expressedType, MlirAttribute scales, MlirAttribute zeroPoints, + std::vector quantizedDimensions, + std::vector blockSizes, int64_t storageTypeMin, + int64_t storageTypeMax) { + return cls(mlirUniformQuantizedSubChannelTypeGet( + flags, storageType, expressedType, scales, zeroPoints, + static_cast(blockSizes.size()), + quantizedDimensions.data(), blockSizes.data(), storageTypeMin, + storageTypeMax)); + }, + "Gets an instance of UniformQuantizedSubChannel in the same context as " + "the provided storage 
type.", + nb::arg("cls"), nb::arg("flags"), nb::arg("storage_type"), + nb::arg("expressed_type"), nb::arg("scales"), nb::arg("zero_points"), + nb::arg("quantized_dimensions"), nb::arg("block_sizes"), + nb::arg("storage_type_min"), nb::arg("storage_type_max")); + uniformQuantizedSubChannelType.def_property_readonly( + "quantized_dimensions", + [](MlirType type) { + intptr_t nDim = + mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type); + std::vector quantizedDimensions; + quantizedDimensions.reserve(nDim); + for (intptr_t i = 0; i < nDim; ++i) { + quantizedDimensions.push_back( + mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(type, i)); + } + return quantizedDimensions; + }, + "Gets the quantized dimensions. Each element in the returned list " + "represents an axis of the quantized data tensor that has a specified " + "block size. The order of elements corresponds to the order of block " + "sizes returned by 'block_sizes' method. It means that the data tensor " + "is quantized along the i-th dimension in the returned list using the " + "i-th block size from block_sizes method."); + uniformQuantizedSubChannelType.def_property_readonly( + "block_sizes", + [](MlirType type) { + intptr_t nDim = + mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(type); + std::vector blockSizes; + blockSizes.reserve(nDim); + for (intptr_t i = 0; i < nDim; ++i) { + blockSizes.push_back( + mlirUniformQuantizedSubChannelTypeGetBlockSize(type, i)); + } + return blockSizes; + }, + "Gets the block sizes for the quantized dimensions. 
The i-th element in " + "the returned list corresponds to the block size for the i-th dimension " + "in the list returned by quantized_dimensions method."); + uniformQuantizedSubChannelType.def_property_readonly( + "scales", + [](MlirType type) -> MlirAttribute { + return mlirUniformQuantizedSubChannelTypeGetScales(type); + }, + "The scales of the quantized type."); + uniformQuantizedSubChannelType.def_property_readonly( + "zero_points", + [](MlirType type) -> MlirAttribute { + return mlirUniformQuantizedSubChannelTypeGetZeroPoints(type); + }, + "The zero points of the quantized type."); + //===-------------------------------------------------------------------===// // CalibratedQuantizedType //===-------------------------------------------------------------------===// diff --git a/mlir/lib/CAPI/Dialect/Quant.cpp b/mlir/lib/CAPI/Dialect/Quant.cpp index c94dbb569..01a6a948f 100644 --- a/mlir/lib/CAPI/Dialect/Quant.cpp +++ b/mlir/lib/CAPI/Dialect/Quant.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir-c/Dialect/Quant.h" +#include "mlir-c/BuiltinAttributes.h" #include "mlir/CAPI/Registration.h" #include "mlir/Dialect/Quant/IR/Quant.h" #include "mlir/Dialect/Quant/IR/QuantTypes.h" @@ -194,6 +195,61 @@ bool mlirUniformQuantizedPerAxisTypeIsFixedPoint(MlirType type) { return cast(unwrap(type)).isFixedPoint(); } +//===---------------------------------------------------------------------===// +// UniformQuantizedSubChannelType +//===---------------------------------------------------------------------===// + +bool mlirTypeIsAUniformQuantizedSubChannelType(MlirType type) { + return isa(unwrap(type)); +} + +MlirType mlirUniformQuantizedSubChannelTypeGet( + unsigned flags, MlirType storageType, MlirType expressedType, + MlirAttribute scalesAttr, MlirAttribute zeroPointsAttr, intptr_t nDims, + int32_t *quantizedDimensions, int64_t *blockSizes, int64_t storageTypeMin, + int64_t storageTypeMax) { + auto scales = 
dyn_cast(unwrap(scalesAttr)); + auto zeroPoints = dyn_cast(unwrap(zeroPointsAttr)); + + if (!scales || !zeroPoints) { + return {}; + } + + return wrap(quant::UniformQuantizedSubChannelType::get( + flags, unwrap(storageType), unwrap(expressedType), scales, zeroPoints, + llvm::ArrayRef(quantizedDimensions, nDims), + llvm::ArrayRef(blockSizes, nDims), storageTypeMin, + storageTypeMax)); +} + +intptr_t mlirUniformQuantizedSubChannelTypeGetNumBlockSizes(MlirType type) { + return cast(unwrap(type)) + .getBlockSizes() + .size(); +} + +int32_t mlirUniformQuantizedSubChannelTypeGetQuantizedDimension(MlirType type, + intptr_t pos) { + return cast(unwrap(type)) + .getQuantizedDimensions()[pos]; +} + +int64_t mlirUniformQuantizedSubChannelTypeGetBlockSize(MlirType type, + intptr_t pos) { + return cast(unwrap(type)) + .getBlockSizes()[pos]; +} + +MlirAttribute mlirUniformQuantizedSubChannelTypeGetScales(MlirType type) { + return wrap( + cast(unwrap(type)).getScales()); +} + +MlirAttribute mlirUniformQuantizedSubChannelTypeGetZeroPoints(MlirType type) { + return wrap(cast(unwrap(type)) + .getZeroPoints()); +} + //===---------------------------------------------------------------------===// // CalibratedQuantizedType //===---------------------------------------------------------------------===// diff --git a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi index 47168d49c..3f5304584 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/dialects/quant.pyi @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from mlir.ir import Type +from mlir.ir import DenseElementsAttr, Type __all__ = [ "QuantizedType", @@ -109,6 +109,26 @@ class UniformQuantizedPerAxisType(QuantizedType): @property def is_fixed_point(self) -> bool: ... 
+class UniformQuantizedSubChannelType(QuantizedType): + + @classmethod + def get(cls, flags: int, storage_type: Type, expressed_type: Type, + scales: DenseElementsAttr, zero_points: DenseElementsAttr, + quantized_dimensions: list[int], block_sizes: list[int], + storage_type_min: int, storage_type_max: int): + ... + + @property + def quantized_dimensions(self) -> list[int]: ... + + @property + def block_sizes(self) -> list[int]: ... + + @property + def scales(self) -> DenseElementsAttr: ... + + @property + def zero_points(self) -> DenseElementsAttr: ... def CalibratedQuantizedType(QuantizedType): From 55a74ab74cf1e8e19cc930be51f7b2574334e7c1 Mon Sep 17 00:00:00 2001 From: Han-Chung Wang Date: Mon, 31 Mar 2025 09:29:54 -0700 Subject: [PATCH 866/915] [MLIR][NFC] Fix incomplete boundary comments. (#133516) I observed that we have the boundary comments in the codebase like: ``` //===----------------------------------------------------------------------===// // ... //===----------------------------------------------------------------------===// ``` I also observed that there are incomplete boundary comments. The revision is generated by a script that completes the boundary comments. ``` //===----------------------------------------------------------------------===// // ... ... 
``` Signed-off-by: hanhanW --- mlir/include/mlir-c/Rewrite.h | 2 ++ mlir/lib/CAPI/Transforms/Rewrite.cpp | 2 ++ 2 files changed, 4 insertions(+) diff --git a/mlir/include/mlir-c/Rewrite.h b/mlir/include/mlir-c/Rewrite.h index d8f2275b6..61d344631 100644 --- a/mlir/include/mlir-c/Rewrite.h +++ b/mlir/include/mlir-c/Rewrite.h @@ -48,6 +48,7 @@ mlirRewriterBaseGetContext(MlirRewriterBase rewriter); //===----------------------------------------------------------------------===// /// Insertion points methods +//===----------------------------------------------------------------------===// // These do not include functions using Block::iterator or Region::iterator, as // they are not exposed by the C API yet. Similarly for methods using @@ -101,6 +102,7 @@ mlirRewriterBaseGetBlock(MlirRewriterBase rewriter); //===----------------------------------------------------------------------===// /// Block and operation creation/insertion/cloning +//===----------------------------------------------------------------------===// // These functions do not include the IRMapper, as it is not yet exposed by the // C API. 
diff --git a/mlir/lib/CAPI/Transforms/Rewrite.cpp b/mlir/lib/CAPI/Transforms/Rewrite.cpp index c4717ca61..a4df97f7b 100644 --- a/mlir/lib/CAPI/Transforms/Rewrite.cpp +++ b/mlir/lib/CAPI/Transforms/Rewrite.cpp @@ -29,6 +29,7 @@ MlirContext mlirRewriterBaseGetContext(MlirRewriterBase rewriter) { //===----------------------------------------------------------------------===// /// Insertion points methods +//===----------------------------------------------------------------------===// void mlirRewriterBaseClearInsertionPoint(MlirRewriterBase rewriter) { unwrap(rewriter)->clearInsertionPoint(); @@ -69,6 +70,7 @@ MlirBlock mlirRewriterBaseGetBlock(MlirRewriterBase rewriter) { //===----------------------------------------------------------------------===// /// Block and operation creation/insertion/cloning +//===----------------------------------------------------------------------===// MlirBlock mlirRewriterBaseCreateBlockBefore(MlirRewriterBase rewriter, MlirBlock insertBefore, From c2a640c3452900a1fe5d25a24546cea8b954a8b8 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 1 Apr 2025 18:28:53 +0200 Subject: [PATCH 867/915] [mlir] Expose `simplifyAffineExpr` through python api (#133926) --- mlir/include/mlir-c/AffineExpr.h | 10 ++++++++++ mlir/lib/Bindings/Python/IRAffine.cpp | 10 ++++++++++ mlir/lib/CAPI/IR/AffineExpr.cpp | 5 +++++ 3 files changed, 25 insertions(+) diff --git a/mlir/include/mlir-c/AffineExpr.h b/mlir/include/mlir-c/AffineExpr.h index ab768eb2e..161db6266 100644 --- a/mlir/include/mlir-c/AffineExpr.h +++ b/mlir/include/mlir-c/AffineExpr.h @@ -104,6 +104,16 @@ MLIR_CAPI_EXPORTED MlirAffineExpr mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr, uint32_t numSymbols, uint32_t shift, uint32_t offset); +/// Simplify an affine expression by flattening and some amount of simple +/// analysis. This has complexity linear in the number of nodes in 'expr'. 
+/// Returns the simplified expression, which is the same as the input expression +/// if it can't be simplified. When `expr` is semi-affine, a simplified +/// semi-affine expression is constructed in the sorted order of dimension and +/// symbol positions. +MLIR_CAPI_EXPORTED MlirAffineExpr mlirSimplifyAffineExpr(MlirAffineExpr expr, + uint32_t numDims, + uint32_t numSymbols); + //===----------------------------------------------------------------------===// // Affine Dimension Expression. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRAffine.cpp b/mlir/lib/Bindings/Python/IRAffine.cpp index 3c95d29c4..50f2a4f95 100644 --- a/mlir/lib/Bindings/Python/IRAffine.cpp +++ b/mlir/lib/Bindings/Python/IRAffine.cpp @@ -599,6 +599,16 @@ void mlir::python::populateIRAffine(nb::module_ &m) { }, nb::arg("num_symbols"), nb::arg("shift"), nb::arg("offset").none() = 0) + .def_static( + "simplify_affine_expr", + [](PyAffineExpr &self, uint32_t numDims, uint32_t numSymbols) { + return PyAffineExpr( + self.getContext(), + mlirSimplifyAffineExpr(self, numDims, numSymbols)); + }, + nb::arg("expr"), nb::arg("num_dims"), nb::arg("num_symbols"), + "Simplify an affine expression by flattening and some amount of " + "simple analysis.") .def_static( "get_add", &PyAffineAddExpr::get, "Gets an affine expression containing a sum of two expressions.") diff --git a/mlir/lib/CAPI/IR/AffineExpr.cpp b/mlir/lib/CAPI/IR/AffineExpr.cpp index bc3dcd417..5a0a03b11 100644 --- a/mlir/lib/CAPI/IR/AffineExpr.cpp +++ b/mlir/lib/CAPI/IR/AffineExpr.cpp @@ -73,6 +73,11 @@ MlirAffineExpr mlirAffineExprShiftSymbols(MlirAffineExpr affineExpr, return wrap(unwrap(affineExpr).shiftSymbols(numSymbols, shift, offset)); } +MlirAffineExpr mlirSimplifyAffineExpr(MlirAffineExpr expr, uint32_t numDims, + uint32_t numSymbols) { + return wrap(simplifyAffineExpr(unwrap(expr), numDims, numSymbols)); +} + 
//===----------------------------------------------------------------------===// // Affine Dimension Expression. //===----------------------------------------------------------------------===// From 68c7d91b4bcbe48e9caa724cbe40e68d2192756b Mon Sep 17 00:00:00 2001 From: Nirvedh Meshram <96096277+nirvedhmeshram@users.noreply.github.com> Date: Tue, 1 Apr 2025 11:55:40 -0500 Subject: [PATCH 868/915] [NFC][mlir] Update generate script for conv_3d_ncdhw_fcdhw (#133927) https://github.com/llvm/llvm-project/pull/129547 changed the IR directly without updating the auto generate script. Signed-off-by: Nirvedh --- mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 040663c88..48e724d80 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -1140,7 +1140,7 @@ def conv_3d_ncdhw_fcdhw( them to the same data type as the accumulator/output. """ implements(ConvolutionOpInterface) - domain(D.n, D.od, D.oh, D.ow, D.f, D.kd, D.kh, D.kw, D.c) + domain(D.n, D.f, D.od, D.oh, D.ow, D.c, D.kd, D.kh, D.kw) O[D.n, D.f, D.od, D.oh, D.ow] += TypeFn.cast_signed( U, I[ From 05f5e65173f51b4c86a8f23c9df45c6950f25231 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Wed, 9 Apr 2025 19:28:59 -0400 Subject: [PATCH 869/915] [mlir][python] add use_name_loc_as_prefix to value.get_name() (#135052) Add `use_name_loc_as_prefix` to `value.get_name()`. 
--- mlir/lib/Bindings/Python/IRCore.cpp | 7 +++++-- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 4 ++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 5ffcf6717..b5720b7ad 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3977,11 +3977,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { kValueDunderStrDocstring) .def( "get_name", - [](PyValue &self, bool useLocalScope) { + [](PyValue &self, bool useLocalScope, bool useNameLocAsPrefix) { PyPrintAccumulator printAccum; MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); if (useLocalScope) mlirOpPrintingFlagsUseLocalScope(flags); + if (useNameLocAsPrefix) + mlirOpPrintingFlagsPrintNameLocAsPrefix(flags); MlirAsmState valueState = mlirAsmStateCreateForValue(self.get(), flags); mlirValuePrintAsOperand(self.get(), valueState, @@ -3991,7 +3993,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { mlirAsmStateDestroy(valueState); return printAccum.join(); }, - nb::arg("use_local_scope") = false) + nb::arg("use_local_scope") = false, + nb::arg("use_name_loc_as_prefix") = false) .def( "get_name", [](PyValue &self, PyAsmState &state) { diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index c60ff72ff..1c8080c5d 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -577,7 +577,7 @@ class Value: Dumps a debug representation of the object to stderr. """ @overload - def get_name(self, use_local_scope: bool = False) -> str: ... + def get_name(self, use_local_scope: bool = False, use_name_loc_as_prefix: bool = True) -> str: ... @overload def get_name(self, state: AsmState) -> str: """ @@ -2382,7 +2382,7 @@ class Operation(_OperationBase): attributes: Dict of str:Attribute. successors: List of Block for the operation's successors. regions: Number of regions to create. 
- location: A Location object (defaults to resolve from context manager). + loc: A Location object (defaults to resolve from context manager). ip: An InsertionPoint (defaults to resolve from context manager or set to False to disable insertion, even with an insertion point set in the context manager). From fffbadf86e00e31cca5cc6f80377d990901ff869 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Wed, 9 Apr 2025 20:01:38 -0400 Subject: [PATCH 870/915] [mlir][CAPI][python] expose the python bindings for linalg::isaContractionOpInterface and linalg::inferContractionDims (#134935) This PR is mainly about exposing the python bindings for` linalg::isaContractionOpInterface` and` linalg::inferContractionDims`. --------- Signed-off-by: Bangtian Liu --- mlir/include/mlir-c/Dialect/Linalg.h | 12 +++++++ mlir/lib/Bindings/Python/DialectLinalg.cpp | 41 +++++++++++++++++++++- mlir/lib/CAPI/Dialect/Linalg.cpp | 34 ++++++++++++++++++ 3 files changed, 86 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h index 0ab201e15..c57d193e6 100644 --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -22,6 +22,18 @@ extern "C" { MLIR_CAPI_EXPORTED void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp); +MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op); + +struct MlirLinalgContractionDimensions { + MlirAttribute batch; + MlirAttribute m; + MlirAttribute n; + MlirAttribute k; +}; + +MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions +mlirLinalgInferContractionDimensions(MlirOperation op); + MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg); #ifdef __cplusplus diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp index 548df4ee1..978ea8664 100644 --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -8,10 +8,25 @@ #include "mlir-c/Dialect/Linalg.h" #include "mlir-c/IR.h" 
-#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; +using namespace mlir::python::nanobind_adaptors; + +static std::optional +InferContractionDimensions(MlirOperation op) { + MlirLinalgContractionDimensions dims = + mlirLinalgInferContractionDimensions(op); + + // Detect "empty" result. This occurs when `op` is not a contraction op, + // or when `linalg::inferContractionDims` fails. + if (mlirAttributeIsNull(dims.batch) && mlirAttributeIsNull(dims.m) && + mlirAttributeIsNull(dims.n) && mlirAttributeIsNull(dims.k)) { + return std::nullopt; + } + return dims; +} static void populateDialectLinalgSubmodule(nb::module_ m) { m.def( @@ -20,6 +35,30 @@ static void populateDialectLinalgSubmodule(nb::module_ m) { nb::arg("op"), "Fill the region for `op`, which is assumed to be a builtin named Linalg " "op."); + + m.def("isa_contraction_op", &mlirLinalgIsContractionOp, + "Checks if the given operation is a Linalg contraction operation.", + nb::arg("op")); + + nb::class_(m, "ContractionDimensions") + .def_prop_ro("batch", + [](const MlirLinalgContractionDimensions &self) { + return self.batch; + }) + .def_prop_ro( + "m", + [](const MlirLinalgContractionDimensions &self) { return self.m; }) + .def_prop_ro( + "n", + [](const MlirLinalgContractionDimensions &self) { return self.n; }) + .def_prop_ro("k", [](const MlirLinalgContractionDimensions &self) { + return self.k; + }); + + m.def("infer_contraction_dimensions", &InferContractionDimensions, + "Infers contraction dimensions (batch/m/n/k) for a Linalg contraction " + "op.", + nb::arg("op")); } NB_MODULE(_mlirDialectsLinalg, m) { diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 2fb5bc651..362b89bde 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -41,4 +41,38 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) { fun(b, 
*body, op->getAttrs()); } +MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op) { + auto linalgOp = llvm::dyn_cast(unwrap(op)); + // isaContractionOpInterface handles null linalgOp internally. + return linalg::isaContractionOpInterface(linalgOp); +} + +MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions +mlirLinalgInferContractionDimensions(MlirOperation op) { + MlirLinalgContractionDimensions result{}; + auto linalgOp = dyn_cast(unwrap(op)); + if (!linalgOp) + return result; + + FailureOr maybeDims = + linalg::inferContractionDims(linalgOp); + if (failed(maybeDims)) + return result; + + linalg::ContractionDimensions contractionDims = *maybeDims; + MLIRContext *ctx = linalgOp.getContext(); + + auto toAttr = [&ctx](const SmallVector &vals) -> MlirAttribute { + return wrap( + DenseI32ArrayAttr::get(ctx, llvm::to_vector_of(vals))); + }; + + result.batch = toAttr(contractionDims.batch); + result.m = toAttr(contractionDims.m); + result.n = toAttr(contractionDims.n); + result.k = toAttr(contractionDims.k); + + return result; +} + MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect) From 67f08cbf59a6f344afdea45d6a5caa711fef1e14 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Thu, 10 Apr 2025 20:22:15 -0400 Subject: [PATCH 871/915] [mlir][CAPI][python] expose the python bindings for linalg::isaConvolutionOpInterface and linalg::inferConvolutionDims (#135253) This PR is mainly about exposing the python bindings for `linalg::isaConvolutionOpInterface` and `linalg::inferConvolutionDims`. 
--------- Signed-off-by: Bangtian Liu --- mlir/include/mlir-c/Dialect/Linalg.h | 18 ++++++- mlir/lib/Bindings/Python/DialectLinalg.cpp | 63 +++++++++++++++++++++- mlir/lib/CAPI/Dialect/Linalg.cpp | 47 +++++++++++++++- 3 files changed, 125 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h index c57d193e6..838c28090 100644 --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -22,7 +22,7 @@ extern "C" { MLIR_CAPI_EXPORTED void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp); -MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op); +MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op); struct MlirLinalgContractionDimensions { MlirAttribute batch; @@ -34,6 +34,22 @@ struct MlirLinalgContractionDimensions { MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions mlirLinalgInferContractionDimensions(MlirOperation op); +MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op); + +struct MlirLinalgConvolutionDimensions { + MlirAttribute batch; + MlirAttribute outputImage; + MlirAttribute outputChannel; + MlirAttribute filterLoop; + MlirAttribute inputChannel; + MlirAttribute depth; + MlirAttribute strides; + MlirAttribute dilations; +}; + +MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions +mlirLinalgInferConvolutionDimensions(MlirOperation op); + MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg); #ifdef __cplusplus diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp index 978ea8664..ce1102a3b 100644 --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -28,6 +28,26 @@ InferContractionDimensions(MlirOperation op) { return dims; } +static std::optional +InferConvolutionDimensions(MlirOperation op) { + MlirLinalgConvolutionDimensions dims = + mlirLinalgInferConvolutionDimensions(op); + + // Detect "empty" result. 
This occurs when `op` is not a convolution op, + // or when `linalg::inferConvolutionDims` fails. + if (mlirAttributeIsNull(dims.batch) && + mlirAttributeIsNull(dims.outputImage) && + mlirAttributeIsNull(dims.outputChannel) && + mlirAttributeIsNull(dims.filterLoop) && + mlirAttributeIsNull(dims.inputChannel) && + mlirAttributeIsNull(dims.depth) && mlirAttributeIsNull(dims.strides) && + mlirAttributeIsNull(dims.dilations)) { + return std::nullopt; + } + + return dims; +} + static void populateDialectLinalgSubmodule(nb::module_ m) { m.def( "fill_builtin_region", @@ -36,7 +56,7 @@ static void populateDialectLinalgSubmodule(nb::module_ m) { "Fill the region for `op`, which is assumed to be a builtin named Linalg " "op."); - m.def("isa_contraction_op", &mlirLinalgIsContractionOp, + m.def("isa_contraction_op", &mlirLinalgIsAContractionOp, "Checks if the given operation is a Linalg contraction operation.", nb::arg("op")); @@ -59,6 +79,47 @@ static void populateDialectLinalgSubmodule(nb::module_ m) { "Infers contraction dimensions (batch/m/n/k) for a Linalg contraction " "op.", nb::arg("op")); + + m.def("isa_convolution_op", &mlirLinalgIsAConvolutionOp, + "Checks if the given operation is a Linalg convolution operation.", + nb::arg("op")); + + nb::class_(m, "ConvolutionDimensions") + .def_prop_ro("batch", + [](const MlirLinalgConvolutionDimensions &self) { + return self.batch; + }) + .def_prop_ro("output_image", + [](const MlirLinalgConvolutionDimensions &self) { + return self.outputImage; + }) + .def_prop_ro("output_channel", + [](const MlirLinalgConvolutionDimensions &self) { + return self.outputChannel; + }) + .def_prop_ro("filter_loop", + [](const MlirLinalgConvolutionDimensions &self) { + return self.filterLoop; + }) + .def_prop_ro("input_channel", + [](const MlirLinalgConvolutionDimensions &self) { + return self.inputChannel; + }) + .def_prop_ro("depth", + [](const MlirLinalgConvolutionDimensions &self) { + return self.depth; + }) + .def_prop_ro("strides", + [](const 
MlirLinalgConvolutionDimensions &self) { + return self.strides; + }) + .def_prop_ro("dilations", + [](const MlirLinalgConvolutionDimensions &self) { + return self.dilations; + }); + + m.def("infer_convolution_dimensions", &InferConvolutionDimensions, + "Infers convolution dimensions", nb::arg("op")); } NB_MODULE(_mlirDialectsLinalg, m) { diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 362b89bde..7c456102a 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -41,7 +41,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) { fun(b, *body, op->getAttrs()); } -MLIR_CAPI_EXPORTED bool mlirLinalgIsContractionOp(MlirOperation op) { +MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op) { auto linalgOp = llvm::dyn_cast(unwrap(op)); // isaContractionOpInterface handles null linalgOp internally. return linalg::isaContractionOpInterface(linalgOp); @@ -75,4 +75,49 @@ mlirLinalgInferContractionDimensions(MlirOperation op) { return result; } +MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op) { + auto linalgOp = llvm::dyn_cast(unwrap(op)); + if (!linalgOp) + return false; + + return linalg::isaConvolutionOpInterface(linalgOp); +} + +MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions +mlirLinalgInferConvolutionDimensions(MlirOperation op) { + MlirLinalgConvolutionDimensions result{}; + auto linalgOp = llvm::dyn_cast(unwrap(op)); + if (!linalgOp) + return result; + + FailureOr maybeDims = + linalg::inferConvolutionDims(linalgOp); + if (failed(maybeDims)) + return result; + + linalg::ConvolutionDimensions dims = *maybeDims; + MLIRContext *ctx = linalgOp.getContext(); + + auto toI32Attr = + [&ctx](const SmallVector &vals) -> MlirAttribute { + return wrap(DenseI32ArrayAttr::get(ctx, llvm::to_vector_of(vals))); + }; + + auto toI64Attr = + [&ctx](const SmallVector &vals) -> MlirAttribute { + return wrap(DenseI64ArrayAttr::get(ctx, vals)); + }; + + result.batch = 
toI32Attr(dims.batch); + result.outputImage = toI32Attr(dims.outputImage); + result.outputChannel = toI32Attr(dims.outputChannel); + result.filterLoop = toI32Attr(dims.filterLoop); + result.inputChannel = toI32Attr(dims.inputChannel); + result.depth = toI32Attr(dims.depth); + result.strides = toI64Attr(dims.strides); + result.dilations = toI64Attr(dims.dilations); + + return result; +} + MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect) From 1ba23ff4e6fb8fa9775642ab1c2cf70db604bf04 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Fri, 11 Apr 2025 11:16:58 -0400 Subject: [PATCH 872/915] [MLIR][CAPI] add C API typedef to fix downstream C API usage (#135380) This PR is after #135253 and #134935 to fix the error reported by https://github.com/llvm/llvm-project/pull/135253#issuecomment-2796077024. This PR Adds typedef declarations for `MlirLinalgContractionDimensions` and `MlirLinalgConvolutionDimensions` in the C API to ensure compatibility with pure C code. I confirm that this fix resolves the reported error based on my testing. 
Signed-off-by: Bangtian Liu --- mlir/include/mlir-c/Dialect/Linalg.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h index 838c28090..4f2ee0d43 100644 --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -24,19 +24,19 @@ mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp); MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op); -struct MlirLinalgContractionDimensions { +typedef struct MlirLinalgContractionDimensions { MlirAttribute batch; MlirAttribute m; MlirAttribute n; MlirAttribute k; -}; +} MlirLinalgContractionDimensions; MLIR_CAPI_EXPORTED MlirLinalgContractionDimensions mlirLinalgInferContractionDimensions(MlirOperation op); MLIR_CAPI_EXPORTED bool mlirLinalgIsAConvolutionOp(MlirOperation op); -struct MlirLinalgConvolutionDimensions { +typedef struct MlirLinalgConvolutionDimensions { MlirAttribute batch; MlirAttribute outputImage; MlirAttribute outputChannel; @@ -45,7 +45,7 @@ struct MlirLinalgConvolutionDimensions { MlirAttribute depth; MlirAttribute strides; MlirAttribute dilations; -}; +} MlirLinalgConvolutionDimensions; MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions mlirLinalgInferConvolutionDimensions(MlirOperation op); From 87410a65d719c355dd12d9595dc6739350e0ee03 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 14 Apr 2025 15:37:14 -0400 Subject: [PATCH 873/915] [mlir][SMT] C APIs (#135501) This PR upstreams/adds the C APIs for SMT dialect (from CIRCT). 
--------- Co-authored-by: Bea Healy Co-authored-by: Martin Erhart Co-authored-by: Mike Urbach Co-authored-by: Will Dietz Co-authored-by: fzi-hielscher Co-authored-by: Fehr Mathieu Co-authored-by: Clo91eaf --- mlir/include/mlir-c/Dialect/SMT.h | 110 +++++++++++++++++++ mlir/include/mlir-c/Target/ExportSMTLIB.h | 32 ++++++ mlir/lib/CAPI/Dialect/CMakeLists.txt | 9 ++ mlir/lib/CAPI/Dialect/SMT.cpp | 123 ++++++++++++++++++++++ mlir/lib/CAPI/Target/CMakeLists.txt | 12 +++ mlir/lib/CAPI/Target/ExportSMTLIB.cpp | 27 +++++ 6 files changed, 313 insertions(+) create mode 100644 mlir/include/mlir-c/Dialect/SMT.h create mode 100644 mlir/include/mlir-c/Target/ExportSMTLIB.h create mode 100644 mlir/lib/CAPI/Dialect/SMT.cpp create mode 100644 mlir/lib/CAPI/Target/ExportSMTLIB.cpp diff --git a/mlir/include/mlir-c/Dialect/SMT.h b/mlir/include/mlir-c/Dialect/SMT.h new file mode 100644 index 000000000..d076dccce --- /dev/null +++ b/mlir/include/mlir-c/Dialect/SMT.h @@ -0,0 +1,110 @@ +//===- SMT.h - C interface for the SMT dialect --------------------*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_DIALECT_SMT_H +#define MLIR_C_DIALECT_SMT_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +//===----------------------------------------------------------------------===// +// Dialect API. +//===----------------------------------------------------------------------===// + +MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SMT, smt); + +//===----------------------------------------------------------------------===// +// Type API. +//===----------------------------------------------------------------------===// + +/// Checks if the given type is any non-func SMT value type. 
+MLIR_CAPI_EXPORTED bool smtTypeIsAnyNonFuncSMTValueType(MlirType type); + +/// Checks if the given type is any SMT value type. +MLIR_CAPI_EXPORTED bool smtTypeIsAnySMTValueType(MlirType type); + +/// Checks if the given type is a smt::ArrayType. +MLIR_CAPI_EXPORTED bool smtTypeIsAArray(MlirType type); + +/// Creates an array type with the given domain and range types. +MLIR_CAPI_EXPORTED MlirType smtTypeGetArray(MlirContext ctx, + MlirType domainType, + MlirType rangeType); + +/// Checks if the given type is a smt::BitVectorType. +MLIR_CAPI_EXPORTED bool smtTypeIsABitVector(MlirType type); + +/// Creates a smt::BitVectorType with the given width. +MLIR_CAPI_EXPORTED MlirType smtTypeGetBitVector(MlirContext ctx, int32_t width); + +/// Checks if the given type is a smt::BoolType. +MLIR_CAPI_EXPORTED bool smtTypeIsABool(MlirType type); + +/// Creates a smt::BoolType. +MLIR_CAPI_EXPORTED MlirType smtTypeGetBool(MlirContext ctx); + +/// Checks if the given type is a smt::IntType. +MLIR_CAPI_EXPORTED bool smtTypeIsAInt(MlirType type); + +/// Creates a smt::IntType. +MLIR_CAPI_EXPORTED MlirType smtTypeGetInt(MlirContext ctx); + +/// Checks if the given type is a smt::FuncType. +MLIR_CAPI_EXPORTED bool smtTypeIsASMTFunc(MlirType type); + +/// Creates a smt::FuncType with the given domain and range types. +MLIR_CAPI_EXPORTED MlirType smtTypeGetSMTFunc(MlirContext ctx, + size_t numberOfDomainTypes, + const MlirType *domainTypes, + MlirType rangeType); + +/// Checks if the given type is a smt::SortType. +MLIR_CAPI_EXPORTED bool smtTypeIsASort(MlirType type); + +/// Creates a smt::SortType with the given identifier and sort parameters. +MLIR_CAPI_EXPORTED MlirType smtTypeGetSort(MlirContext ctx, + MlirIdentifier identifier, + size_t numberOfSortParams, + const MlirType *sortParams); + +//===----------------------------------------------------------------------===// +// Attribute API. 
+//===----------------------------------------------------------------------===// + +/// Checks if the given string is a valid smt::BVCmpPredicate. +MLIR_CAPI_EXPORTED bool smtAttrCheckBVCmpPredicate(MlirContext ctx, + MlirStringRef str); + +/// Checks if the given string is a valid smt::IntPredicate. +MLIR_CAPI_EXPORTED bool smtAttrCheckIntPredicate(MlirContext ctx, + MlirStringRef str); + +/// Checks if the given attribute is a smt::SMTAttribute. +MLIR_CAPI_EXPORTED bool smtAttrIsASMTAttribute(MlirAttribute attr); + +/// Creates a smt::BitVectorAttr with the given value and width. +MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetBitVector(MlirContext ctx, + uint64_t value, + unsigned width); + +/// Creates a smt::BVCmpPredicateAttr with the given string. +MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetBVCmpPredicate(MlirContext ctx, + MlirStringRef str); + +/// Creates a smt::IntPredicateAttr with the given string. +MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetIntPredicate(MlirContext ctx, + MlirStringRef str); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_DIALECT_SMT_H diff --git a/mlir/include/mlir-c/Target/ExportSMTLIB.h b/mlir/include/mlir-c/Target/ExportSMTLIB.h new file mode 100644 index 000000000..31f411c4a --- /dev/null +++ b/mlir/include/mlir-c/Target/ExportSMTLIB.h @@ -0,0 +1,32 @@ +//===- mlir-c/Target/ExportSMTLIB.h - C API for emitting SMTLIB ---*- C -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This header declares the C interface for emitting SMTLIB from an MLIR module. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_C_EXPORTSMTLIB_H +#define MLIR_C_EXPORTSMTLIB_H + +#include "mlir-c/IR.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// Emits SMTLIB for the specified module using the provided callback and user +/// data +MLIR_CAPI_EXPORTED MlirLogicalResult mlirExportSMTLIB(MlirModule, + MlirStringCallback, + void *userData); + +#ifdef __cplusplus +} +#endif + +#endif // MLIR_C_EXPORTSMTLIB_H diff --git a/mlir/lib/CAPI/Dialect/CMakeLists.txt b/mlir/lib/CAPI/Dialect/CMakeLists.txt index ddd3d6629..bb1fdf8be 100644 --- a/mlir/lib/CAPI/Dialect/CMakeLists.txt +++ b/mlir/lib/CAPI/Dialect/CMakeLists.txt @@ -269,3 +269,12 @@ add_mlir_upstream_c_api_library(MLIRCAPIVector MLIRCAPIIR MLIRVectorDialect ) + +add_mlir_upstream_c_api_library(MLIRCAPISMT + SMT.cpp + + PARTIAL_SOURCES_INTENDED + LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRSMT +) diff --git a/mlir/lib/CAPI/Dialect/SMT.cpp b/mlir/lib/CAPI/Dialect/SMT.cpp new file mode 100644 index 000000000..3a4620df8 --- /dev/null +++ b/mlir/lib/CAPI/Dialect/SMT.cpp @@ -0,0 +1,123 @@ +//===- SMT.cpp - C interface for the SMT dialect --------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Dialect/SMT.h" +#include "mlir/CAPI/Registration.h" +#include "mlir/Dialect/SMT/IR/SMTAttributes.h" +#include "mlir/Dialect/SMT/IR/SMTDialect.h" +#include "mlir/Dialect/SMT/IR/SMTTypes.h" + +using namespace mlir; +using namespace smt; + +//===----------------------------------------------------------------------===// +// Dialect API. 
+//===----------------------------------------------------------------------===// + +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SMT, smt, mlir::smt::SMTDialect) + +//===----------------------------------------------------------------------===// +// Type API. +//===----------------------------------------------------------------------===// + +bool smtTypeIsAnyNonFuncSMTValueType(MlirType type) { + return isAnyNonFuncSMTValueType(unwrap(type)); +} + +bool smtTypeIsAnySMTValueType(MlirType type) { + return isAnySMTValueType(unwrap(type)); +} + +bool smtTypeIsAArray(MlirType type) { return isa(unwrap(type)); } + +MlirType smtTypeGetArray(MlirContext ctx, MlirType domainType, + MlirType rangeType) { + return wrap( + ArrayType::get(unwrap(ctx), unwrap(domainType), unwrap(rangeType))); +} + +bool smtTypeIsABitVector(MlirType type) { + return isa(unwrap(type)); +} + +MlirType smtTypeGetBitVector(MlirContext ctx, int32_t width) { + return wrap(BitVectorType::get(unwrap(ctx), width)); +} + +bool smtTypeIsABool(MlirType type) { return isa(unwrap(type)); } + +MlirType smtTypeGetBool(MlirContext ctx) { + return wrap(BoolType::get(unwrap(ctx))); +} + +bool smtTypeIsAInt(MlirType type) { return isa(unwrap(type)); } + +MlirType smtTypeGetInt(MlirContext ctx) { + return wrap(IntType::get(unwrap(ctx))); +} + +bool smtTypeIsASMTFunc(MlirType type) { return isa(unwrap(type)); } + +MlirType smtTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes, + const MlirType *domainTypes, MlirType rangeType) { + SmallVector domainTypesVec; + domainTypesVec.reserve(numberOfDomainTypes); + + for (size_t i = 0; i < numberOfDomainTypes; i++) + domainTypesVec.push_back(unwrap(domainTypes[i])); + + return wrap(SMTFuncType::get(unwrap(ctx), domainTypesVec, unwrap(rangeType))); +} + +bool smtTypeIsASort(MlirType type) { return isa(unwrap(type)); } + +MlirType smtTypeGetSort(MlirContext ctx, MlirIdentifier identifier, + size_t numberOfSortParams, const MlirType *sortParams) { + SmallVector sortParamsVec; 
+ sortParamsVec.reserve(numberOfSortParams); + + for (size_t i = 0; i < numberOfSortParams; i++) + sortParamsVec.push_back(unwrap(sortParams[i])); + + return wrap(SortType::get(unwrap(ctx), unwrap(identifier), sortParamsVec)); +} + +//===----------------------------------------------------------------------===// +// Attribute API. +//===----------------------------------------------------------------------===// + +bool smtAttrCheckBVCmpPredicate(MlirContext ctx, MlirStringRef str) { + return symbolizeBVCmpPredicate(unwrap(str)).has_value(); +} + +bool smtAttrCheckIntPredicate(MlirContext ctx, MlirStringRef str) { + return symbolizeIntPredicate(unwrap(str)).has_value(); +} + +bool smtAttrIsASMTAttribute(MlirAttribute attr) { + return isa(unwrap(attr)); +} + +MlirAttribute smtAttrGetBitVector(MlirContext ctx, uint64_t value, + unsigned width) { + return wrap(BitVectorAttr::get(unwrap(ctx), value, width)); +} + +MlirAttribute smtAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str) { + auto predicate = symbolizeBVCmpPredicate(unwrap(str)); + assert(predicate.has_value() && "invalid predicate"); + + return wrap(BVCmpPredicateAttr::get(unwrap(ctx), predicate.value())); +} + +MlirAttribute smtAttrGetIntPredicate(MlirContext ctx, MlirStringRef str) { + auto predicate = symbolizeIntPredicate(unwrap(str)); + assert(predicate.has_value() && "invalid predicate"); + + return wrap(IntPredicateAttr::get(unwrap(ctx), predicate.value())); +} diff --git a/mlir/lib/CAPI/Target/CMakeLists.txt b/mlir/lib/CAPI/Target/CMakeLists.txt index ea617da72..8fbb7aa95 100644 --- a/mlir/lib/CAPI/Target/CMakeLists.txt +++ b/mlir/lib/CAPI/Target/CMakeLists.txt @@ -1,6 +1,8 @@ add_mlir_upstream_c_api_library(MLIRCAPITarget LLVMIR.cpp + PARTIAL_SOURCES_INTENDED + LINK_COMPONENTS Core @@ -11,3 +13,13 @@ add_mlir_upstream_c_api_library(MLIRCAPITarget MLIRLLVMIRToLLVMTranslation MLIRSupport ) + +add_mlir_upstream_c_api_library(MLIRCAPIExportSMTLIB + ExportSMTLIB.cpp + + PARTIAL_SOURCES_INTENDED + + 
LINK_LIBS PUBLIC + MLIRCAPIIR + MLIRExportSMTLIB +) diff --git a/mlir/lib/CAPI/Target/ExportSMTLIB.cpp b/mlir/lib/CAPI/Target/ExportSMTLIB.cpp new file mode 100644 index 000000000..c9ac7ce70 --- /dev/null +++ b/mlir/lib/CAPI/Target/ExportSMTLIB.cpp @@ -0,0 +1,27 @@ +//===- ExportSMTLIB.cpp - C Interface to ExportSMTLIB ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM +// Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Implements a C Interface for export SMTLIB. +// +//===----------------------------------------------------------------------===// + +#include "mlir-c/Target/ExportSMTLIB.h" +#include "mlir/CAPI/IR.h" +#include "mlir/CAPI/Support.h" +#include "mlir/CAPI/Utils.h" +#include "mlir/Target/SMTLIB/ExportSMTLIB.h" + +using namespace mlir; + +MlirLogicalResult mlirExportSMTLIB(MlirModule module, + MlirStringCallback callback, + void *userData) { + mlir::detail::CallbackOstream stream(callback, userData); + return wrap(smt::exportSMTLIB(unwrap(module), stream)); +} From 514592782d706e459d05b1eba24d598fe6253679 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Wed, 16 Apr 2025 18:17:09 -0400 Subject: [PATCH 874/915] [mlir][SMT] add python bindings (#135674) This PR adds "rich" python bindings to SMT dialect. 
--- mlir/include/mlir-c/Dialect/SMT.h | 69 +++++++++---------- mlir/include/mlir-c/Target/ExportSMTLIB.h | 10 ++- mlir/lib/Bindings/Python/DialectSMT.cpp | 83 +++++++++++++++++++++++ mlir/lib/CAPI/Dialect/SMT.cpp | 52 +++++++------- mlir/lib/CAPI/Target/ExportSMTLIB.cpp | 21 +++++- mlir/python/CMakeLists.txt | 24 +++++++ mlir/python/mlir/dialects/SMTOps.td | 14 ++++ mlir/python/mlir/dialects/smt.py | 33 +++++++++ 8 files changed, 242 insertions(+), 64 deletions(-) create mode 100644 mlir/lib/Bindings/Python/DialectSMT.cpp create mode 100644 mlir/python/mlir/dialects/SMTOps.td create mode 100644 mlir/python/mlir/dialects/smt.py diff --git a/mlir/include/mlir-c/Dialect/SMT.h b/mlir/include/mlir-c/Dialect/SMT.h index d076dccce..0ad64746f 100644 --- a/mlir/include/mlir-c/Dialect/SMT.h +++ b/mlir/include/mlir-c/Dialect/SMT.h @@ -26,82 +26,83 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(SMT, smt); //===----------------------------------------------------------------------===// /// Checks if the given type is any non-func SMT value type. -MLIR_CAPI_EXPORTED bool smtTypeIsAnyNonFuncSMTValueType(MlirType type); +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAnyNonFuncSMTValueType(MlirType type); /// Checks if the given type is any SMT value type. -MLIR_CAPI_EXPORTED bool smtTypeIsAnySMTValueType(MlirType type); +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAnySMTValueType(MlirType type); /// Checks if the given type is a smt::ArrayType. -MLIR_CAPI_EXPORTED bool smtTypeIsAArray(MlirType type); +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAArray(MlirType type); /// Creates an array type with the given domain and range types. -MLIR_CAPI_EXPORTED MlirType smtTypeGetArray(MlirContext ctx, - MlirType domainType, - MlirType rangeType); +MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetArray(MlirContext ctx, + MlirType domainType, + MlirType rangeType); /// Checks if the given type is a smt::BitVectorType. 
-MLIR_CAPI_EXPORTED bool smtTypeIsABitVector(MlirType type); +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsABitVector(MlirType type); /// Creates a smt::BitVectorType with the given width. -MLIR_CAPI_EXPORTED MlirType smtTypeGetBitVector(MlirContext ctx, int32_t width); +MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBitVector(MlirContext ctx, + int32_t width); /// Checks if the given type is a smt::BoolType. -MLIR_CAPI_EXPORTED bool smtTypeIsABool(MlirType type); +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsABool(MlirType type); /// Creates a smt::BoolType. -MLIR_CAPI_EXPORTED MlirType smtTypeGetBool(MlirContext ctx); +MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetBool(MlirContext ctx); /// Checks if the given type is a smt::IntType. -MLIR_CAPI_EXPORTED bool smtTypeIsAInt(MlirType type); +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsAInt(MlirType type); /// Creates a smt::IntType. -MLIR_CAPI_EXPORTED MlirType smtTypeGetInt(MlirContext ctx); +MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetInt(MlirContext ctx); /// Checks if the given type is a smt::FuncType. -MLIR_CAPI_EXPORTED bool smtTypeIsASMTFunc(MlirType type); +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsASMTFunc(MlirType type); /// Creates a smt::FuncType with the given domain and range types. -MLIR_CAPI_EXPORTED MlirType smtTypeGetSMTFunc(MlirContext ctx, - size_t numberOfDomainTypes, - const MlirType *domainTypes, - MlirType rangeType); +MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetSMTFunc(MlirContext ctx, + size_t numberOfDomainTypes, + const MlirType *domainTypes, + MlirType rangeType); /// Checks if the given type is a smt::SortType. -MLIR_CAPI_EXPORTED bool smtTypeIsASort(MlirType type); +MLIR_CAPI_EXPORTED bool mlirSMTTypeIsASort(MlirType type); /// Creates a smt::SortType with the given identifier and sort parameters. 
-MLIR_CAPI_EXPORTED MlirType smtTypeGetSort(MlirContext ctx, - MlirIdentifier identifier, - size_t numberOfSortParams, - const MlirType *sortParams); +MLIR_CAPI_EXPORTED MlirType mlirSMTTypeGetSort(MlirContext ctx, + MlirIdentifier identifier, + size_t numberOfSortParams, + const MlirType *sortParams); //===----------------------------------------------------------------------===// // Attribute API. //===----------------------------------------------------------------------===// /// Checks if the given string is a valid smt::BVCmpPredicate. -MLIR_CAPI_EXPORTED bool smtAttrCheckBVCmpPredicate(MlirContext ctx, - MlirStringRef str); +MLIR_CAPI_EXPORTED bool mlirSMTAttrCheckBVCmpPredicate(MlirContext ctx, + MlirStringRef str); /// Checks if the given string is a valid smt::IntPredicate. -MLIR_CAPI_EXPORTED bool smtAttrCheckIntPredicate(MlirContext ctx, - MlirStringRef str); +MLIR_CAPI_EXPORTED bool mlirSMTAttrCheckIntPredicate(MlirContext ctx, + MlirStringRef str); /// Checks if the given attribute is a smt::SMTAttribute. -MLIR_CAPI_EXPORTED bool smtAttrIsASMTAttribute(MlirAttribute attr); +MLIR_CAPI_EXPORTED bool mlirSMTAttrIsASMTAttribute(MlirAttribute attr); /// Creates a smt::BitVectorAttr with the given value and width. -MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetBitVector(MlirContext ctx, - uint64_t value, - unsigned width); +MLIR_CAPI_EXPORTED MlirAttribute mlirSMTAttrGetBitVector(MlirContext ctx, + uint64_t value, + unsigned width); /// Creates a smt::BVCmpPredicateAttr with the given string. -MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetBVCmpPredicate(MlirContext ctx, - MlirStringRef str); +MLIR_CAPI_EXPORTED MlirAttribute +mlirSMTAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str); /// Creates a smt::IntPredicateAttr with the given string. 
-MLIR_CAPI_EXPORTED MlirAttribute smtAttrGetIntPredicate(MlirContext ctx, - MlirStringRef str); +MLIR_CAPI_EXPORTED MlirAttribute mlirSMTAttrGetIntPredicate(MlirContext ctx, + MlirStringRef str); #ifdef __cplusplus } diff --git a/mlir/include/mlir-c/Target/ExportSMTLIB.h b/mlir/include/mlir-c/Target/ExportSMTLIB.h index 31f411c4a..59beda54d 100644 --- a/mlir/include/mlir-c/Target/ExportSMTLIB.h +++ b/mlir/include/mlir-c/Target/ExportSMTLIB.h @@ -21,9 +21,13 @@ extern "C" { /// Emits SMTLIB for the specified module using the provided callback and user /// data -MLIR_CAPI_EXPORTED MlirLogicalResult mlirExportSMTLIB(MlirModule, - MlirStringCallback, - void *userData); +MLIR_CAPI_EXPORTED MlirLogicalResult +mlirTranslateModuleToSMTLIB(MlirModule, MlirStringCallback, void *userData, + bool inlineSingleUseValues, bool indentLetBody); + +MLIR_CAPI_EXPORTED MlirLogicalResult mlirTranslateOperationToSMTLIB( + MlirOperation, MlirStringCallback, void *userData, + bool inlineSingleUseValues, bool indentLetBody); #ifdef __cplusplus } diff --git a/mlir/lib/Bindings/Python/DialectSMT.cpp b/mlir/lib/Bindings/Python/DialectSMT.cpp new file mode 100644 index 000000000..4e7647729 --- /dev/null +++ b/mlir/lib/Bindings/Python/DialectSMT.cpp @@ -0,0 +1,83 @@ +//===- DialectSMT.cpp - Pybind module for SMT dialect API support ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "NanobindUtils.h" + +#include "mlir-c/Dialect/SMT.h" +#include "mlir-c/IR.h" +#include "mlir-c/Support.h" +#include "mlir-c/Target/ExportSMTLIB.h" +#include "mlir/Bindings/Python/Diagnostics.h" +#include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" + +namespace nb = nanobind; + +using namespace nanobind::literals; + +using namespace mlir; +using namespace mlir::python; +using namespace mlir::python::nanobind_adaptors; + +void populateDialectSMTSubmodule(nanobind::module_ &m) { + + auto smtBoolType = mlir_type_subclass(m, "BoolType", mlirSMTTypeIsABool) + .def_classmethod( + "get", + [](const nb::object &, MlirContext context) { + return mlirSMTTypeGetBool(context); + }, + "cls"_a, "context"_a.none() = nb::none()); + auto smtBitVectorType = + mlir_type_subclass(m, "BitVectorType", mlirSMTTypeIsABitVector) + .def_classmethod( + "get", + [](const nb::object &, int32_t width, MlirContext context) { + return mlirSMTTypeGetBitVector(context, width); + }, + "cls"_a, "width"_a, "context"_a.none() = nb::none()); + + auto exportSMTLIB = [](MlirOperation module, bool inlineSingleUseValues, + bool indentLetBody) { + mlir::python::CollectDiagnosticsToStringScope scope( + mlirOperationGetContext(module)); + PyPrintAccumulator printAccum; + MlirLogicalResult result = mlirTranslateOperationToSMTLIB( + module, printAccum.getCallback(), printAccum.getUserData(), + inlineSingleUseValues, indentLetBody); + if (mlirLogicalResultIsSuccess(result)) + return printAccum.join(); + throw nb::value_error( + ("Failed to export smtlib.\nDiagnostic message " + scope.takeMessage()) + .c_str()); + }; + + m.def( + "export_smtlib", + [&exportSMTLIB](MlirOperation module, bool inlineSingleUseValues, + bool indentLetBody) { + return exportSMTLIB(module, inlineSingleUseValues, indentLetBody); + }, + 
"module"_a, "inline_single_use_values"_a = false, + "indent_let_body"_a = false); + m.def( + "export_smtlib", + [&exportSMTLIB](MlirModule module, bool inlineSingleUseValues, + bool indentLetBody) { + return exportSMTLIB(mlirModuleGetOperation(module), + inlineSingleUseValues, indentLetBody); + }, + "module"_a, "inline_single_use_values"_a = false, + "indent_let_body"_a = false); +} + +NB_MODULE(_mlirDialectsSMT, m) { + m.doc() = "MLIR SMT Dialect"; + + populateDialectSMTSubmodule(m); +} diff --git a/mlir/lib/CAPI/Dialect/SMT.cpp b/mlir/lib/CAPI/Dialect/SMT.cpp index 3a4620df8..7e96bbb07 100644 --- a/mlir/lib/CAPI/Dialect/SMT.cpp +++ b/mlir/lib/CAPI/Dialect/SMT.cpp @@ -25,46 +25,49 @@ MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(SMT, smt, mlir::smt::SMTDialect) // Type API. //===----------------------------------------------------------------------===// -bool smtTypeIsAnyNonFuncSMTValueType(MlirType type) { +bool mlirSMTTypeIsAnyNonFuncSMTValueType(MlirType type) { return isAnyNonFuncSMTValueType(unwrap(type)); } -bool smtTypeIsAnySMTValueType(MlirType type) { +bool mlirSMTTypeIsAnySMTValueType(MlirType type) { return isAnySMTValueType(unwrap(type)); } -bool smtTypeIsAArray(MlirType type) { return isa(unwrap(type)); } +bool mlirSMTTypeIsAArray(MlirType type) { return isa(unwrap(type)); } -MlirType smtTypeGetArray(MlirContext ctx, MlirType domainType, - MlirType rangeType) { +MlirType mlirSMTTypeGetArray(MlirContext ctx, MlirType domainType, + MlirType rangeType) { return wrap( ArrayType::get(unwrap(ctx), unwrap(domainType), unwrap(rangeType))); } -bool smtTypeIsABitVector(MlirType type) { +bool mlirSMTTypeIsABitVector(MlirType type) { return isa(unwrap(type)); } -MlirType smtTypeGetBitVector(MlirContext ctx, int32_t width) { +MlirType mlirSMTTypeGetBitVector(MlirContext ctx, int32_t width) { return wrap(BitVectorType::get(unwrap(ctx), width)); } -bool smtTypeIsABool(MlirType type) { return isa(unwrap(type)); } +bool mlirSMTTypeIsABool(MlirType type) { return 
isa(unwrap(type)); } -MlirType smtTypeGetBool(MlirContext ctx) { +MlirType mlirSMTTypeGetBool(MlirContext ctx) { return wrap(BoolType::get(unwrap(ctx))); } -bool smtTypeIsAInt(MlirType type) { return isa(unwrap(type)); } +bool mlirSMTTypeIsAInt(MlirType type) { return isa(unwrap(type)); } -MlirType smtTypeGetInt(MlirContext ctx) { +MlirType mlirSMTTypeGetInt(MlirContext ctx) { return wrap(IntType::get(unwrap(ctx))); } -bool smtTypeIsASMTFunc(MlirType type) { return isa(unwrap(type)); } +bool mlirSMTTypeIsASMTFunc(MlirType type) { + return isa(unwrap(type)); +} -MlirType smtTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes, - const MlirType *domainTypes, MlirType rangeType) { +MlirType mlirSMTTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes, + const MlirType *domainTypes, + MlirType rangeType) { SmallVector domainTypesVec; domainTypesVec.reserve(numberOfDomainTypes); @@ -74,10 +77,11 @@ MlirType smtTypeGetSMTFunc(MlirContext ctx, size_t numberOfDomainTypes, return wrap(SMTFuncType::get(unwrap(ctx), domainTypesVec, unwrap(rangeType))); } -bool smtTypeIsASort(MlirType type) { return isa(unwrap(type)); } +bool mlirSMTTypeIsASort(MlirType type) { return isa(unwrap(type)); } -MlirType smtTypeGetSort(MlirContext ctx, MlirIdentifier identifier, - size_t numberOfSortParams, const MlirType *sortParams) { +MlirType mlirSMTTypeGetSort(MlirContext ctx, MlirIdentifier identifier, + size_t numberOfSortParams, + const MlirType *sortParams) { SmallVector sortParamsVec; sortParamsVec.reserve(numberOfSortParams); @@ -91,31 +95,31 @@ MlirType smtTypeGetSort(MlirContext ctx, MlirIdentifier identifier, // Attribute API. 
//===----------------------------------------------------------------------===// -bool smtAttrCheckBVCmpPredicate(MlirContext ctx, MlirStringRef str) { +bool mlirSMTAttrCheckBVCmpPredicate(MlirContext ctx, MlirStringRef str) { return symbolizeBVCmpPredicate(unwrap(str)).has_value(); } -bool smtAttrCheckIntPredicate(MlirContext ctx, MlirStringRef str) { +bool mlirSMTAttrCheckIntPredicate(MlirContext ctx, MlirStringRef str) { return symbolizeIntPredicate(unwrap(str)).has_value(); } -bool smtAttrIsASMTAttribute(MlirAttribute attr) { +bool mlirSMTAttrIsASMTAttribute(MlirAttribute attr) { return isa(unwrap(attr)); } -MlirAttribute smtAttrGetBitVector(MlirContext ctx, uint64_t value, - unsigned width) { +MlirAttribute mlirSMTAttrGetBitVector(MlirContext ctx, uint64_t value, + unsigned width) { return wrap(BitVectorAttr::get(unwrap(ctx), value, width)); } -MlirAttribute smtAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str) { +MlirAttribute mlirSMTAttrGetBVCmpPredicate(MlirContext ctx, MlirStringRef str) { auto predicate = symbolizeBVCmpPredicate(unwrap(str)); assert(predicate.has_value() && "invalid predicate"); return wrap(BVCmpPredicateAttr::get(unwrap(ctx), predicate.value())); } -MlirAttribute smtAttrGetIntPredicate(MlirContext ctx, MlirStringRef str) { +MlirAttribute mlirSMTAttrGetIntPredicate(MlirContext ctx, MlirStringRef str) { auto predicate = symbolizeIntPredicate(unwrap(str)); assert(predicate.has_value() && "invalid predicate"); diff --git a/mlir/lib/CAPI/Target/ExportSMTLIB.cpp b/mlir/lib/CAPI/Target/ExportSMTLIB.cpp index c9ac7ce70..4326f9672 100644 --- a/mlir/lib/CAPI/Target/ExportSMTLIB.cpp +++ b/mlir/lib/CAPI/Target/ExportSMTLIB.cpp @@ -19,9 +19,24 @@ using namespace mlir; -MlirLogicalResult mlirExportSMTLIB(MlirModule module, - MlirStringCallback callback, - void *userData) { +MlirLogicalResult mlirTranslateOperationToSMTLIB(MlirOperation module, + MlirStringCallback callback, + void *userData, + bool inlineSingleUseValues, + bool indentLetBody) { 
mlir::detail::CallbackOstream stream(callback, userData); + smt::SMTEmissionOptions options; + options.inlineSingleUseValues = inlineSingleUseValues; + options.indentLetBody = indentLetBody; return wrap(smt::exportSMTLIB(unwrap(module), stream)); } + +MlirLogicalResult mlirTranslateModuleToSMTLIB(MlirModule module, + MlirStringCallback callback, + void *userData, + bool inlineSingleUseValues, + bool indentLetBody) { + return mlirTranslateOperationToSMTLIB(mlirModuleGetOperation(module), + callback, userData, + inlineSingleUseValues, indentLetBody); +} diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index fb115a5f4..bbf681960 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -403,6 +403,15 @@ declare_mlir_dialect_python_bindings( "../../include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td" ) +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT MLIRPythonSources.Dialects + ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/SMTOps.td + GEN_ENUM_BINDINGS + SOURCES + dialects/smt.py + DIALECT_NAME smt) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" @@ -664,6 +673,21 @@ declare_mlir_python_extension(MLIRPythonExtension.LinalgPasses MLIRCAPILinalg ) +declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind + MODULE_NAME _mlirDialectsSMT + ADD_TO_PARENT MLIRPythonSources.Dialects.smt + ROOT_DIR "${PYTHON_SOURCE_DIR}" + PYTHON_BINDINGS_LIBRARY nanobind + SOURCES + DialectSMT.cpp + PRIVATE_LINK_LIBS + LLVMSupport + EMBED_CAPI_LINK_LIBS + MLIRCAPIIR + MLIRCAPISMT + MLIRCAPIExportSMTLIB +) + declare_mlir_python_extension(MLIRPythonExtension.SparseTensorDialectPasses MODULE_NAME _mlirSparseTensorPasses ADD_TO_PARENT MLIRPythonSources.Dialects.sparse_tensor diff --git a/mlir/python/mlir/dialects/SMTOps.td b/mlir/python/mlir/dialects/SMTOps.td new file mode 100644 index 000000000..e143f071e --- /dev/null +++ 
b/mlir/python/mlir/dialects/SMTOps.td @@ -0,0 +1,14 @@ +//===- SMTOps.td - Entry point for SMT bindings ------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef BINDINGS_PYTHON_SMT_OPS +#define BINDINGS_PYTHON_SMT_OPS + +include "mlir/Dialect/SMT/IR/SMT.td" + +#endif // BINDINGS_PYTHON_SMT_OPS diff --git a/mlir/python/mlir/dialects/smt.py b/mlir/python/mlir/dialects/smt.py new file mode 100644 index 000000000..ae7a4c41c --- /dev/null +++ b/mlir/python/mlir/dialects/smt.py @@ -0,0 +1,33 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from ._smt_ops_gen import * + +from .._mlir_libs._mlirDialectsSMT import * +from ..extras.meta import region_op + + +def bool_t(): + return BoolType.get() + + +def bv_t(width): + return BitVectorType.get(width) + + +def _solver( + inputs=None, + results=None, + loc=None, + ip=None, +): + if inputs is None: + inputs = [] + if results is None: + results = [] + + return SolverOp(results, inputs, loc=loc, ip=ip) + + +solver = region_op(_solver, terminator=YieldOp) From 85caa7f20e84de4a5d9221083ffb66f8382fbc34 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Thu, 17 Apr 2025 01:31:19 -0400 Subject: [PATCH 875/915] [mlir][python][smt] fix DialectSMT (include NanobindUtils.h) --- mlir/python/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index bbf681960..e3934fc9f 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -680,6 +680,8 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.SMT.Pybind 
PYTHON_BINDINGS_LIBRARY nanobind SOURCES DialectSMT.cpp + # Headers must be included explicitly so they are installed. + NanobindUtils.h PRIVATE_LINK_LIBS LLVMSupport EMBED_CAPI_LINK_LIBS From 476cf1c31c459a4daaeea2fd530aafb3625789ef Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Thu, 17 Apr 2025 16:52:36 -0400 Subject: [PATCH 876/915] [MLIR][CAPI][python] expose the python binding for linalgOp.getIndexingMaps (#136054) This PR is mainly about exposing the python bindings for `linalgOp.getIndexingMaps`. --------- Signed-off-by: Bangtian Liu --- mlir/include/mlir-c/Dialect/Linalg.h | 3 +++ mlir/lib/Bindings/Python/DialectLinalg.cpp | 10 ++++++++++ mlir/lib/CAPI/Dialect/Linalg.cpp | 10 ++++++++++ 3 files changed, 23 insertions(+) diff --git a/mlir/include/mlir-c/Dialect/Linalg.h b/mlir/include/mlir-c/Dialect/Linalg.h index 4f2ee0d43..339e63d66 100644 --- a/mlir/include/mlir-c/Dialect/Linalg.h +++ b/mlir/include/mlir-c/Dialect/Linalg.h @@ -50,6 +50,9 @@ typedef struct MlirLinalgConvolutionDimensions { MLIR_CAPI_EXPORTED MlirLinalgConvolutionDimensions mlirLinalgInferConvolutionDimensions(MlirOperation op); +MLIR_CAPI_EXPORTED MlirAttribute +mlirLinalgGetIndexingMapsAttribute(MlirOperation op); + MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Linalg, linalg); #ifdef __cplusplus diff --git a/mlir/lib/Bindings/Python/DialectLinalg.cpp b/mlir/lib/Bindings/Python/DialectLinalg.cpp index ce1102a3b..015502371 100644 --- a/mlir/lib/Bindings/Python/DialectLinalg.cpp +++ b/mlir/lib/Bindings/Python/DialectLinalg.cpp @@ -120,6 +120,16 @@ static void populateDialectLinalgSubmodule(nb::module_ m) { m.def("infer_convolution_dimensions", &InferConvolutionDimensions, "Infers convolution dimensions", nb::arg("op")); + + m.def( + "get_indexing_maps", + [](MlirOperation op) -> std::optional { + MlirAttribute attr = mlirLinalgGetIndexingMapsAttribute(op); + if (mlirAttributeIsNull(attr)) + return std::nullopt; + return attr; + }, + "Returns the indexing_maps attribute for a linalg op."); } 
NB_MODULE(_mlirDialectsLinalg, m) { diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 7c456102a..0c4f6e88e 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -120,4 +120,14 @@ mlirLinalgInferConvolutionDimensions(MlirOperation op) { return result; } +MLIR_CAPI_EXPORTED MlirAttribute +mlirLinalgGetIndexingMapsAttribute(MlirOperation op) { + auto linalgOp = llvm::dyn_cast(unwrap(op)); + if (!linalgOp) + return MlirAttribute{nullptr}; + + ArrayAttr attr = linalgOp.getIndexingMaps(); + return wrap(attr); +} + MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Linalg, linalg, LinalgDialect) From 911ecc180e2df925c9f910cdec12a97008c72e75 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev <185856+superbobry@users.noreply.github.com> Date: Wed, 23 Apr 2025 02:53:35 +0100 Subject: [PATCH 877/915] [MLIR] [python] Fixed the signature of `_OperationBase.get_asm` (#136676) It claimed to return an `io.StringIO` or an `io.BytesIO`, but it did in fact return `str` or `bytes`. --- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 31 +++++++++++++++--------- 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 1c8080c5d..6c5f91d75 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -45,9 +45,8 @@ from __future__ import annotations import abc import collections from collections.abc import Callable, Sequence -import io from pathlib import Path -from typing import Any, BinaryIO, ClassVar, TypeVar, overload +from typing import Any, BinaryIO, ClassVar, Literal, TypeVar, overload __all__ = [ "AffineAddExpr", @@ -196,6 +195,19 @@ class _OperationBase: Detaches the operation from its parent block. """ def erase(self) -> None: ... 
+ + @overload + def get_asm( + binary: Literal[True], + large_elements_limit: int | None = None, + enable_debug_info: bool = False, + pretty_debug_info: bool = False, + print_generic_op_form: bool = False, + use_local_scope: bool = False, + assume_verified: bool = False, + skip_regions: bool = False, + ) -> bytes: ... + @overload def get_asm( self, binary: bool = False, @@ -206,19 +218,14 @@ class _OperationBase: use_local_scope: bool = False, assume_verified: bool = False, skip_regions: bool = False, - ) -> io.BytesIO | io.StringIO: + ) -> str: """ - Gets the assembly form of the operation with all options available. + Returns the assembly form of the operation. - Args: - binary: Whether to return a bytes (True) or str (False) object. Defaults to - False. - ... others ...: See the print() method for common keyword arguments for - configuring the printout. - Returns: - Either a bytes or str object, depending on the setting of the 'binary' - argument. + See the print() method for common keyword arguments for configuring + the output. """ + def move_after(self, other: _OperationBase) -> None: """ Puts self immediately after the other operation in its parent block. From 4d031045f27b2414dd2a31e191b36d9ad3c31ea9 Mon Sep 17 00:00:00 2001 From: Rolf Morel <854835+rolfmorel@users.noreply.github.com> Date: Mon, 12 May 2025 11:34:55 +0200 Subject: [PATCH 878/915] [MLIR][Linalg][Python] Improve bindings for linalg.elementwise (#139462) Adds wrappers for ElementWiseOp, in particular to ensure appropriate default indexing maps are derived. 
--- mlir/python/mlir/dialects/linalg/__init__.py | 61 ++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index 63586a5bb..a5a659abb 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -216,6 +216,67 @@ def contract( ) +# Extend and shadow the TableGen-derived version to make sure correct default +# indexing_maps are derived (as there is no mechanism for doing so given the +# Python API bypasses the C++-builders). +class ElementwiseOp_(ElementwiseOp): + def __init__( + self, + result_tensors, + inputs, + outputs, + kind, + *, + indexing_maps=None, + loc=None, + ip=None, + ): + if indexing_maps is None: + inputs = [_get_op_result_or_value(in_) for in_ in inputs] + for in0, in1 in zip(inputs[:-1], inputs[1:]): + assert in0.type == in1.type + output = _get_op_result_or_value(outputs[0]) + assert inputs[0].type == output.type + num_args = len(inputs) + 1 + indexing_maps = [AffineMap.get_identity(output.type.rank)] * num_args + + super().__init__( + result_tensors=result_tensors, + inputs=inputs, + outputs=outputs, + kind=kind, + indexing_maps=indexing_maps, + loc=loc, + ip=ip, + ) + + +ElementwiseOp = ElementwiseOp_ + + +def elementwise( + *ins: Union[Operation, OpView, Value], + outs: Sequence[Union[Operation, OpView, Value]], + kind: Union[ElementwiseKind, Attribute], + indexing_maps: Optional[Sequence[AffineMapAttr]] = None, +): + ins = [_get_op_result_or_value(input) for input in ins] + if len(outs) != 1: + raise ValueError(f"{outs=} must have length 1.") + init = _get_op_result_or_value(outs[0]) + result_types = [init.type] if isinstance(init.type, RankedTensorType) else [] + + op = ElementwiseOp( + result_tensors=result_types, + inputs=ins, + outputs=[init], + kind=kind, + indexing_maps=indexing_maps, + ) + fill_builtin_region(op.operation) + return _get_op_result_or_op_results(op) + + def pack( 
source, dest, From 83a627f7914e5f45e13a90339f776e5667c21eee Mon Sep 17 00:00:00 2001 From: Md Asghar Ahmad Shahid Date: Mon, 12 May 2025 17:59:34 +0530 Subject: [PATCH 879/915] =?UTF-8?q?[MLIR][Linalg]=20Introduce=20transpose/?= =?UTF-8?q?broadcast=20semantic=20to=20linalg.batch=E2=80=A6=20(#130944)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …_reduce_matmul. This patch exposes broadcast and transpose semantics on 'batch_reduce_matmul'. This is the last one in continuation of other two variant of matmul ops. The broadcast and transpose semantic are as follows: Broadcast and Transpose semantics can be appiled by specifying the explicit attribute 'indexing_maps' as shown below. This is a list attribute, so must include maps for all arguments if specified. Example Transpose: ``` linalg.batch_reduce_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d0, d3, d1)>, // transpose affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)> ] ins(%arg0, %arg1 : memref<2x5x3xf32>,memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>) ``` Example Broadcast: ``` linalg.batch_reduce_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d3)>, // broadcast affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d1, d2)> ] ins(%arg0, %arg1 : memref<5xf32>, memref<2x5x7xf32>) outs(%arg2: memref<3x7xf32>) ``` Example Broadcast and Transpose: ``` linalg.batch_reduce_matmul indexing_maps = [ affine_map<(d0, d1, d2, d3) -> (d1, d3)>, // broadcast affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, // transpose affine_map<(d0, d1, d2, d3) -> (d1, d2)> ] ins(%arg0, %arg1 : memref<3x5xf32>, memref<2x7x5xf32>) outs(%arg2: memref<3x7xf32>) ``` RFCs and related PR: https://discourse.llvm.org/t/rfc-linalg-opdsl-constant-list-attribute-definition/80149 https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863 https://discourse.llvm.org/t/rfc-mlir-linalg-operation-tree/83586 
https://github.com/llvm/llvm-project/pull/115319 https://github.com/llvm/llvm-project/pull/122275 --- mlir/python/mlir/dialects/linalg/__init__.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py index a5a659abb..d387c12de 100644 --- a/mlir/python/mlir/dialects/linalg/__init__.py +++ b/mlir/python/mlir/dialects/linalg/__init__.py @@ -203,6 +203,19 @@ def batch_matmul( ) +def batch_reduce_matmul( + *ins: Union[Operation, OpView, Value], + outs: Sequence[Union[Operation, OpView, Value]], + indexing_maps: Optional[Sequence[AffineMapAttr]] = None, + cast: Optional[Union[TypeFn, Attribute]] = None, +): + return _get_op_result_or_op_results( + _create_matmul_like_op( + BatchReduceMatmulOp, *ins, outs=outs, indexing_maps=indexing_maps, cast=cast + ) + ) + + def contract( *ins: Union[Operation, OpView, Value], outs: Sequence[Union[Operation, OpView, Value]], From 4e7684c42559cc6fa72c1cf5c54624430c38a614 Mon Sep 17 00:00:00 2001 From: drazi Date: Thu, 29 May 2025 12:14:37 +0800 Subject: [PATCH 880/915] assert with more information to help debug (#132194) This PR output debug message to assertion to help debug user python code. 
Will print out more friendly information ``` > assert isinstance(arg, _cext.ir.Value), f"expects Value, got {type(arg)}" E AssertionError: expected Value, got ``` --- mlir/python/mlir/dialects/_ods_common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index d3dbdc604..a5efa057c 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -105,7 +105,7 @@ def get_op_result_or_value( elif isinstance(arg, _cext.ir.OpResultList): return arg[0] else: - assert isinstance(arg, _cext.ir.Value) + assert isinstance(arg, _cext.ir.Value), f"expects Value, got {type(arg)}" return arg @@ -147,6 +147,7 @@ def get_op_result_or_op_results( else: return op + ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value ResultValueT = _Union[ResultValueTypeTuple] VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]] From 0d04f1c9e8a22cd847f96a02be670a23142fce50 Mon Sep 17 00:00:00 2001 From: Rolf Morel <854835+rolfmorel@users.noreply.github.com> Date: Wed, 11 Jun 2025 17:33:55 +0100 Subject: [PATCH 881/915] [MLIR][Transform] apply_registered_pass op's options as a dict (#143159) Improve ApplyRegisteredPassOp's support for taking options by taking them as a dict (vs a list of string-valued key-value pairs). Values of options are provided as either static attributes or as params (which pass in attributes at interpreter runtime). In either case, the keys and value attributes are converted to strings and a single options-string, in the format used on the commandline, is constructed to pass to the `addToPipeline`-pass API. 
--- .../mlir/dialects/transform/__init__.py | 82 ++++++++++++++++++- 1 file changed, 81 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index 5b158ec6b..10a04b0cc 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -18,7 +18,12 @@ except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e -from typing import Optional, Sequence, Union, NewType +from typing import Dict, Optional, Sequence, Union, NewType + + +@register_attribute_builder("ParamOperandAttr") +def _paramOperandAttr(x: int, context) -> Attribute: + return Attribute.parse(f"#transform.param_operand", context=context) @_ods_cext.register_operation(_Dialect, replace=True) @@ -214,6 +219,81 @@ def __init__( super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip) +@_ods_cext.register_operation(_Dialect, replace=True) +class ApplyRegisteredPassOp(ApplyRegisteredPassOp): + def __init__( + self, + result: Type, + pass_name: Union[str, StringAttr], + target: Union[Operation, Value, OpView], + *, + options: Optional[ + Dict[ + Union[str, StringAttr], + Union[Attribute, Value, Operation, OpView], + ] + ] = None, + loc=None, + ip=None, + ): + options_dict = {} + dynamic_options = [] + + ParamOperandAttr = AttrBuilder.get("ParamOperandAttr") + context = (loc and loc.context) or Context.current + + cur_param_operand_idx = 0 + for key, value in options.items() if options is not None else {}: + if isinstance(key, StringAttr): + key = key.value + + if isinstance(value, (Value, Operation, OpView)): + dynamic_options.append(_get_op_result_or_value(value)) + options_dict[key] = ParamOperandAttr(cur_param_operand_idx, context) + cur_param_operand_idx += 1 + elif isinstance(value, Attribute): + options_dict[key] = value + elif isinstance(value, str): + options_dict[key] = StringAttr.get(value) + else: + raise 
TypeError(f"Unsupported option type: {type(value)}") + if len(options_dict) > 0: + print(options_dict, cur_param_operand_idx) + super().__init__( + result, + pass_name, + dynamic_options, + target=_get_op_result_or_value(target), + options=DictAttr.get(options_dict), + loc=loc, + ip=ip, + ) + + +def apply_registered_pass( + result: Type, + pass_name: Union[str, StringAttr], + target: Union[Operation, Value, OpView], + *, + options: Optional[ + Dict[ + Union[str, StringAttr], + Union[Attribute, Value, Operation, OpView], + ] + ] = None, + loc=None, + ip=None, +) -> Value: + return ApplyRegisteredPassOp( + result=result, + pass_name=pass_name, + target=target, + options=options, + loc=loc, + ip=ip, + ).result + + AnyOpTypeT = NewType("AnyOpType", AnyOpType) From a4b5c5a6459dd0c35b479085c8dd7433ca945b01 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Wed, 11 Jun 2025 21:19:52 +0100 Subject: [PATCH 882/915] [MLIR][Transform] apply_registered_op fixes: arg order & python options auto-conversion (#143779) --- .../python/mlir/dialects/transform/__init__.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index 10a04b0cc..bfe96b1b3 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -224,13 +224,13 @@ class ApplyRegisteredPassOp(ApplyRegisteredPassOp): def __init__( self, result: Type, - pass_name: Union[str, StringAttr], target: Union[Operation, Value, OpView], + pass_name: Union[str, StringAttr], *, options: Optional[ Dict[ Union[str, StringAttr], - Union[Attribute, Value, Operation, OpView], + Union[Attribute, Value, Operation, OpView, str, int, bool], ] ] = None, loc=None, @@ -253,17 +253,21 @@ def __init__( cur_param_operand_idx += 1 elif isinstance(value, Attribute): options_dict[key] = value + # The following cases auto-convert Python values to attributes. 
+ elif isinstance(value, bool): + options_dict[key] = BoolAttr.get(value) + elif isinstance(value, int): + default_int_type = IntegerType.get_signless(64, context) + options_dict[key] = IntegerAttr.get(default_int_type, value) elif isinstance(value, str): options_dict[key] = StringAttr.get(value) else: raise TypeError(f"Unsupported option type: {type(value)}") - if len(options_dict) > 0: - print(options_dict, cur_param_operand_idx) super().__init__( result, + _get_op_result_or_value(target), pass_name, dynamic_options, - target=_get_op_result_or_value(target), options=DictAttr.get(options_dict), loc=loc, ip=ip, @@ -272,13 +276,13 @@ def __init__( def apply_registered_pass( result: Type, - pass_name: Union[str, StringAttr], target: Union[Operation, Value, OpView], + pass_name: Union[str, StringAttr], *, options: Optional[ Dict[ Union[str, StringAttr], - Union[Attribute, Value, Operation, OpView], + Union[Attribute, Value, Operation, OpView, str, int, bool], ] ] = None, loc=None, From 52f38aa9b44f7eb7e1f7c5b39be1e831b3ab1a02 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Mon, 16 Jun 2025 13:40:50 +0200 Subject: [PATCH 883/915] [MLIR][Transform] apply_registered_pass: support ListOptions (#144026) Interpret an option value with multiple values, either in the form of an `ArrayAttr` (either static or passed through a param) or as the multiple attrs associated to a param, as a comma-separated list, i.e. as a ListOption on a pass. 
--- .../mlir/dialects/transform/__init__.py | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/mlir/python/mlir/dialects/transform/__init__.py b/mlir/python/mlir/dialects/transform/__init__.py index bfe96b1b3..b075919d1 100644 --- a/mlir/python/mlir/dialects/transform/__init__.py +++ b/mlir/python/mlir/dialects/transform/__init__.py @@ -219,6 +219,11 @@ def __init__( super().__init__(_get_op_results_or_values(operands), loc=loc, ip=ip) +OptionValueTypes = Union[ + Sequence["OptionValueTypes"], Attribute, Value, Operation, OpView, str, int, bool +] + + @_ods_cext.register_operation(_Dialect, replace=True) class ApplyRegisteredPassOp(ApplyRegisteredPassOp): def __init__( @@ -227,12 +232,7 @@ def __init__( target: Union[Operation, Value, OpView], pass_name: Union[str, StringAttr], *, - options: Optional[ - Dict[ - Union[str, StringAttr], - Union[Attribute, Value, Operation, OpView, str, int, bool], - ] - ] = None, + options: Optional[Dict[Union[str, StringAttr], OptionValueTypes]] = None, loc=None, ip=None, ): @@ -243,26 +243,32 @@ def __init__( context = (loc and loc.context) or Context.current cur_param_operand_idx = 0 - for key, value in options.items() if options is not None else {}: - if isinstance(key, StringAttr): - key = key.value + def option_value_to_attr(value): + nonlocal cur_param_operand_idx if isinstance(value, (Value, Operation, OpView)): dynamic_options.append(_get_op_result_or_value(value)) - options_dict[key] = ParamOperandAttr(cur_param_operand_idx, context) cur_param_operand_idx += 1 + return ParamOperandAttr(cur_param_operand_idx - 1, context) elif isinstance(value, Attribute): - options_dict[key] = value + return value # The following cases auto-convert Python values to attributes. 
elif isinstance(value, bool): - options_dict[key] = BoolAttr.get(value) + return BoolAttr.get(value) elif isinstance(value, int): default_int_type = IntegerType.get_signless(64, context) - options_dict[key] = IntegerAttr.get(default_int_type, value) + return IntegerAttr.get(default_int_type, value) elif isinstance(value, str): - options_dict[key] = StringAttr.get(value) + return StringAttr.get(value) + elif isinstance(value, Sequence): + return ArrayAttr.get([option_value_to_attr(elt) for elt in value]) else: raise TypeError(f"Unsupported option type: {type(value)}") + + for key, value in options.items() if options is not None else {}: + if isinstance(key, StringAttr): + key = key.value + options_dict[key] = option_value_to_attr(value) super().__init__( result, _get_op_result_or_value(target), @@ -279,12 +285,7 @@ def apply_registered_pass( target: Union[Operation, Value, OpView], pass_name: Union[str, StringAttr], *, - options: Optional[ - Dict[ - Union[str, StringAttr], - Union[Attribute, Value, Operation, OpView, str, int, bool], - ] - ] = None, + options: Optional[Dict[Union[str, StringAttr], OptionValueTypes]] = None, loc=None, ip=None, ) -> Value: From 927533ebb827e71ff4dead5df75b9e9bc32e0acc Mon Sep 17 00:00:00 2001 From: Sergei Lebedev <185856+superbobry@users.noreply.github.com> Date: Wed, 18 Jun 2025 14:53:20 +0100 Subject: [PATCH 884/915] [mlir] [python] Fixed the return type of `MemRefType.get_strides_and_offset` (#144523) Previously, the return type for `offset` was `list[int]`, which clearly is not right. --- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 6c5f91d75..70bca3c75 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -2119,7 +2119,7 @@ class MemRefType(ShapedType): """ @property def typeid(self) -> TypeID: ... 
- def get_strides_and_offset(self) -> tuple[list[int], list[int]]: + def get_strides_and_offset(self) -> tuple[list[int], int]: """ The strides and offset of the MemRef type. """ From aa81ade2fa4651bfec66398d9df22a0de42e22ef Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Fri, 20 Jun 2025 14:34:43 -0500 Subject: [PATCH 885/915] [mlir][python] expose operation.block (#145088) Expose `operation-getBlock()` in python. --- mlir/lib/Bindings/Python/IRCore.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index b5720b7ad..cbd35f297 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3385,6 +3385,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyOperation::createFromCapsule) .def_prop_ro("operation", [](nb::object self) { return self; }) .def_prop_ro("opview", &PyOperation::createOpView) + .def_prop_ro("block", &PyOperation::getBlock) .def_prop_ro( "successors", [](PyOperationBase &self) { From 6f9060daf01985cf371a6fd6c30fa1926e63a6a4 Mon Sep 17 00:00:00 2001 From: Kazu Hirata Date: Sat, 21 Jun 2025 08:20:49 -0700 Subject: [PATCH 886/915] [mlir] Migrate away from ArrayRef(std::nullopt) (NFC) (#145140) ArrayRef has a constructor that accepts std::nullopt. This constructor dates back to the days when we still had llvm::Optional. Since the use of std::nullopt outside the context of std::optional is kind of abuse and not intuitive to new comers, I would like to move away from the constructor and eventually remove it. This patch takes care of the mlir side of the migration, starting with straightforward places like "return std::nullopt;" and ternally expressions involving std::nullopt. 
--- mlir/include/mlir/CAPI/Wrap.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/include/mlir/CAPI/Wrap.h b/mlir/include/mlir/CAPI/Wrap.h index 5b68f417a..fd5b6e18d 100644 --- a/mlir/include/mlir/CAPI/Wrap.h +++ b/mlir/include/mlir/CAPI/Wrap.h @@ -44,7 +44,7 @@ static llvm::ArrayRef unwrapList(size_t size, CTy *first, "incompatible C and C++ types"); if (size == 0) - return std::nullopt; + return {}; assert(storage.empty() && "expected to populate storage"); storage.reserve(size); From 72f0eb5de3dfbdfb4d64444dc7c60e3bd8f58939 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 23 Jun 2025 14:49:01 -0500 Subject: [PATCH 887/915] [mlir][python] add `MLIR_BINDINGS_PYTHON_INSTALL_PREFIX` to make bindings install dir configurable (#124878) This PR parameterizes the install directory of the MLIR Python bindings in the final distribution. --- mlir/python/CMakeLists.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index e3934fc9f..ee0708124 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -356,7 +356,7 @@ declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/EmitC.td - SOURCES + SOURCES dialects/emitc.py DIALECT_NAME emitc) @@ -790,7 +790,7 @@ endif() add_mlir_python_common_capi_library(MLIRPythonCAPI INSTALL_COMPONENT MLIRPythonModules - INSTALL_DESTINATION python_packages/mlir_core/mlir/_mlir_libs + INSTALL_DESTINATION "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}/_mlir_libs" OUTPUT_DIRECTORY "${MLIR_BINARY_DIR}/python_packages/mlir_core/mlir/_mlir_libs" RELATIVE_INSTALL_ROOT "../../../.." 
DECLARED_HEADERS @@ -821,7 +821,7 @@ endif() add_mlir_python_modules(MLIRPythonModules ROOT_PREFIX "${MLIR_BINARY_DIR}/python_packages/mlir_core/mlir" - INSTALL_PREFIX "python_packages/mlir_core/mlir" + INSTALL_PREFIX "${MLIR_BINDINGS_PYTHON_INSTALL_PREFIX}" DECLARED_SOURCES MLIRPythonSources MLIRPythonExtension.RegisterEverything From 78fbf0c656e4edbe00a134e9cee15fffaa13781f Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 23 Jun 2025 18:59:03 -0500 Subject: [PATCH 888/915] [mlir][python] bind block predecessors and successors (#145116) bind `block.getSuccessor` and `block.getPredecessors`. --- mlir/include/mlir-c/IR.h | 18 ++++++ mlir/lib/Bindings/Python/IRCore.cpp | 98 ++++++++++++++++++++++++++++- mlir/lib/CAPI/IR/IR.cpp | 20 ++++++ 3 files changed, 135 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 1a8e8737f..81299c791 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -986,6 +986,24 @@ MLIR_CAPI_EXPORTED MlirValue mlirBlockGetArgument(MlirBlock block, MLIR_CAPI_EXPORTED void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, void *userData); +/// Returns the number of successor blocks of the block. +MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumSuccessors(MlirBlock block); + +/// Returns `pos`-th successor of the block. +MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetSuccessor(MlirBlock block, + intptr_t pos); + +/// Returns the number of predecessor blocks of the block. +MLIR_CAPI_EXPORTED intptr_t mlirBlockGetNumPredecessors(MlirBlock block); + +/// Returns `pos`-th predecessor of the block. +/// +/// WARNING: This getter is more expensive than the others here because +/// the impl actually iterates the use-def chain (of block operands) anew for +/// each indexed access. +MLIR_CAPI_EXPORTED MlirBlock mlirBlockGetPredecessor(MlirBlock block, + intptr_t pos); + //===----------------------------------------------------------------------===// // Value API. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index cbd35f297..d96148288 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2626,6 +2626,88 @@ class PyOpSuccessors : public Sliceable { PyOperationRef operation; }; +/// A list of block successors. Internally, these are stored as consecutive +/// elements, random access is cheap. The (returned) successor list is +/// associated with the operation and block whose successors these are, and thus +/// extends the lifetime of this operation and block. +class PyBlockSuccessors : public Sliceable { +public: + static constexpr const char *pyClassName = "BlockSuccessors"; + + PyBlockSuccessors(PyBlock block, PyOperationRef operation, + intptr_t startIndex = 0, intptr_t length = -1, + intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirBlockGetNumSuccessors(block.get()) + : length, + step), + operation(operation), block(block) {} + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + intptr_t getRawNumElements() { + block.checkValid(); + return mlirBlockGetNumSuccessors(block.get()); + } + + PyBlock getRawElement(intptr_t pos) { + MlirBlock block = mlirBlockGetSuccessor(this->block.get(), pos); + return PyBlock(operation, block); + } + + PyBlockSuccessors slice(intptr_t startIndex, intptr_t length, intptr_t step) { + return PyBlockSuccessors(block, operation, startIndex, length, step); + } + + PyOperationRef operation; + PyBlock block; +}; + +/// A list of block predecessors. The (returned) predecessor list is +/// associated with the operation and block whose predecessors these are, and +/// thus extends the lifetime of this operation and block. 
+/// +/// WARNING: This Sliceable is more expensive than the others here because +/// mlirBlockGetPredecessor actually iterates the use-def chain (of block +/// operands) anew for each indexed access. +class PyBlockPredecessors : public Sliceable { +public: + static constexpr const char *pyClassName = "BlockPredecessors"; + + PyBlockPredecessors(PyBlock block, PyOperationRef operation, + intptr_t startIndex = 0, intptr_t length = -1, + intptr_t step = 1) + : Sliceable(startIndex, + length == -1 ? mlirBlockGetNumPredecessors(block.get()) + : length, + step), + operation(operation), block(block) {} + +private: + /// Give the parent CRTP class access to hook implementations below. + friend class Sliceable; + + intptr_t getRawNumElements() { + block.checkValid(); + return mlirBlockGetNumPredecessors(block.get()); + } + + PyBlock getRawElement(intptr_t pos) { + MlirBlock block = mlirBlockGetPredecessor(this->block.get(), pos); + return PyBlock(operation, block); + } + + PyBlockPredecessors slice(intptr_t startIndex, intptr_t length, + intptr_t step) { + return PyBlockPredecessors(block, operation, startIndex, length, step); + } + + PyOperationRef operation; + PyBlock block; +}; + /// A list of operation attributes. Can be indexed by name, producing /// attributes, or by index, producing named attributes. class PyOpAttributeMap { @@ -3655,7 +3737,19 @@ void mlir::python::populateIRCore(nb::module_ &m) { }, nb::arg("operation"), "Appends an operation to this block. 
If the operation is currently " - "in another block, it will be moved."); + "in another block, it will be moved.") + .def_prop_ro( + "successors", + [](PyBlock &self) { + return PyBlockSuccessors(self, self.getParentOperation()); + }, + "Returns the list of Block successors.") + .def_prop_ro( + "predecessors", + [](PyBlock &self) { + return PyBlockPredecessors(self, self.getParentOperation()); + }, + "Returns the list of Block predecessors."); //---------------------------------------------------------------------------- // Mapping of PyInsertionPoint. @@ -4099,6 +4193,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { PyBlockArgumentList::bind(m); PyBlockIterator::bind(m); PyBlockList::bind(m); + PyBlockSuccessors::bind(m); + PyBlockPredecessors::bind(m); PyOperationIterator::bind(m); PyOperationList::bind(m); PyOpAttributeMap::bind(m); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index e0e386d55..fbc66bcf5 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -1059,6 +1059,26 @@ void mlirBlockPrint(MlirBlock block, MlirStringCallback callback, unwrap(block)->print(stream); } +intptr_t mlirBlockGetNumSuccessors(MlirBlock block) { + return static_cast(unwrap(block)->getNumSuccessors()); +} + +MlirBlock mlirBlockGetSuccessor(MlirBlock block, intptr_t pos) { + return wrap(unwrap(block)->getSuccessor(static_cast(pos))); +} + +intptr_t mlirBlockGetNumPredecessors(MlirBlock block) { + Block *b = unwrap(block); + return static_cast(std::distance(b->pred_begin(), b->pred_end())); +} + +MlirBlock mlirBlockGetPredecessor(MlirBlock block, intptr_t pos) { + Block *b = unwrap(block); + Block::pred_iterator it = b->pred_begin(); + std::advance(it, pos); + return wrap(*it); +} + //===----------------------------------------------------------------------===// // Value API. 
//===----------------------------------------------------------------------===// From 74426707ee6cb5608d7ced85358487230f424246 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Wed, 25 Jun 2025 17:39:01 +0200 Subject: [PATCH 889/915] [MLIR][Transform] expose transform.debug extension in Python (#145550) Removes the Debug... prefix on the ops in tablegen, in line with pretty much all other Transform-dialect extension ops. This means that the ops in Python look like `debug.EmitParamAsRemarkOp`/`debug.emit_param_as_remark` instead of `debug.DebugEmitParamAsRemarkOp`/`debug.debug_emit_param_as_remark`. --- mlir/python/CMakeLists.txt | 9 +++ .../dialects/TransformDebugExtensionOps.td | 19 +++++ mlir/python/mlir/dialects/transform/debug.py | 81 +++++++++++++++++++ 3 files changed, 109 insertions(+) create mode 100644 mlir/python/mlir/dialects/TransformDebugExtensionOps.td create mode 100644 mlir/python/mlir/dialects/transform/debug.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index ee0708124..b2daabb2a 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -171,6 +171,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" DIALECT_NAME transform EXTENSION_NAME transform_pdl_extension) +declare_mlir_dialect_extension_python_bindings( +ADD_TO_PARENT MLIRPythonSources.Dialects +ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/TransformDebugExtensionOps.td + SOURCES + dialects/transform/debug.py + DIALECT_NAME transform + EXTENSION_NAME transform_debug_extension) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/TransformDebugExtensionOps.td b/mlir/python/mlir/dialects/TransformDebugExtensionOps.td new file mode 100644 index 000000000..22a85d236 --- /dev/null +++ b/mlir/python/mlir/dialects/TransformDebugExtensionOps.td @@ -0,0 +1,19 @@ +//===-- TransformDebugExtensionOps.td - Binding entry point *- tablegen 
-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the generated Python bindings for the Debug extension of the +// Transform dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS +#define PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS + +include "mlir/Dialect/Transform/DebugExtension/DebugExtensionOps.td" + +#endif // PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS diff --git a/mlir/python/mlir/dialects/transform/debug.py b/mlir/python/mlir/dialects/transform/debug.py new file mode 100644 index 000000000..f7c04268d --- /dev/null +++ b/mlir/python/mlir/dialects/transform/debug.py @@ -0,0 +1,81 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. 
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Optional + +from ...ir import Attribute, Operation, Value, StringAttr +from .._transform_debug_extension_ops_gen import * +from .._transform_pdl_extension_ops_gen import _Dialect + +try: + from .._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class EmitParamAsRemarkOp(EmitParamAsRemarkOp): + def __init__( + self, + param: Attribute, + *, + anchor: Optional[Operation] = None, + message: Optional[Union[StringAttr, str]] = None, + loc=None, + ip=None, + ): + if isinstance(message, str): + message = StringAttr.get(message) + + super().__init__( + param, + anchor=anchor, + message=message, + loc=loc, + ip=ip, + ) + + +def emit_param_as_remark( + param: Attribute, + *, + anchor: Optional[Operation] = None, + message: Optional[Union[StringAttr, str]] = None, + loc=None, + ip=None, +): + return EmitParamAsRemarkOp(param, anchor=anchor, message=message, loc=loc, ip=ip) + + +@_ods_cext.register_operation(_Dialect, replace=True) +class EmitRemarkAtOp(EmitRemarkAtOp): + def __init__( + self, + at: Union[Operation, Value], + message: Optional[Union[StringAttr, str]] = None, + *, + loc=None, + ip=None, + ): + if isinstance(message, str): + message = StringAttr.get(message) + + super().__init__( + at, + message, + loc=loc, + ip=ip, + ) + + +def emit_remark_at( + at: Union[Operation, Value], + message: Optional[Union[StringAttr, str]] = None, + *, + loc=None, + ip=None, +): + return EmitRemarkAtOp(at, message, loc=loc, ip=ip) From 2fe1658d944f9281aefd6789b19bfc89bc5aeed3 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Thu, 26 Jun 2025 00:22:08 +0200 Subject: [PATCH 890/915] [MLIR][Linalg] Harden parsing Linalg named ops (#145337) This thread through proper error handling / reporting capabilities to avoid hitting 
llvm_unreachable while parsing linalg ops. Fixes #132755 Fixes #132740 Fixes #129185 --- mlir/lib/CAPI/Dialect/Linalg.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/CAPI/Dialect/Linalg.cpp b/mlir/lib/CAPI/Dialect/Linalg.cpp index 0c4f6e88e..21db18dfd 100644 --- a/mlir/lib/CAPI/Dialect/Linalg.cpp +++ b/mlir/lib/CAPI/Dialect/Linalg.cpp @@ -38,7 +38,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) { Region ®ion = op->getRegion(0); Block *body = b.createBlock(®ion, /*insertPt=*/{}, argTypes, argLocs); b.setInsertionPointToStart(body); - fun(b, *body, op->getAttrs()); + fun(b, *body, op->getAttrs(), /*emitError=*/{}); } MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op) { From eb15f0d5bf530496aedecb65b6249ce39c11cd15 Mon Sep 17 00:00:00 2001 From: Longsheng Mou Date: Mon, 7 Jul 2025 09:12:38 +0800 Subject: [PATCH 891/915] [mlir] Use `llvm::fill` instead of `std::fill`(NFC) (#146889) --- mlir/lib/Bindings/Python/IRAttributes.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 12725a0ed..8f79caf08 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -1670,7 +1670,7 @@ class PyStridedLayoutAttribute [](int64_t rank, DefaultingPyMlirContext ctx) { auto dynamic = mlirShapedTypeGetDynamicStrideOrOffset(); std::vector strides(rank); - std::fill(strides.begin(), strides.end(), dynamic); + llvm::fill(strides, dynamic); MlirAttribute attr = mlirStridedLayoutAttrGet( ctx->get(), dynamic, strides.size(), strides.data()); return PyStridedLayoutAttribute(ctx->getRef(), attr); From bdbf563aedd6a87d3235de5b9be1d44cc3ac25e9 Mon Sep 17 00:00:00 2001 From: Nikita Popov Date: Mon, 7 Jul 2025 12:47:41 +0200 Subject: [PATCH 892/915] [DenseMap] Do not align pointer sentinel values (NFC) (#146595) DenseMapInfo for pointers currently uses empty/tombstone values that 
are aligned (by assuming a very conservative alignment). However, this means that we have to work with larger immediates. This patch proposes to use the values -1 and -2 instead, without caring about pointer alignment. (Non-roundtrip) integer to pointer casts are implementation-defined in C++, but the general implementer consensus (including Clang) is that raw pointers do not carry alignment requirements, only memory accesses do. We already have lots of places that rely on this using variations on `reinterpret_cast(-1)`, so it seems odd to insist on properly aligned pointers in this one place. It is necessary to adjust a few other places after this change, which currently assume that `DenseMapInfo` returns a highly-aligned pointer. This is a small improvement for both compile-time and clang binary size. --- mlir/lib/Bindings/Python/NanobindUtils.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h index 64ea4329f..535fc2328 100644 --- a/mlir/lib/Bindings/Python/NanobindUtils.h +++ b/mlir/lib/Bindings/Python/NanobindUtils.h @@ -408,11 +408,12 @@ namespace llvm { template <> struct DenseMapInfo { static inline MlirTypeID getEmptyKey() { - auto *pointer = llvm::DenseMapInfo::getEmptyKey(); + // Shift by 3 to satisfy the TypeID alignment requirement. 
+ void *pointer = reinterpret_cast(uintptr_t(-1) << 3); return mlirTypeIDCreate(pointer); } static inline MlirTypeID getTombstoneKey() { - auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); + void *pointer = reinterpret_cast(uintptr_t(-2) << 3); return mlirTypeIDCreate(pointer); } static inline unsigned getHashValue(const MlirTypeID &val) { From 02063efb35cc89c7659963b2831db41459c47478 Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Mon, 7 Jul 2025 12:33:55 +0100 Subject: [PATCH 893/915] [MLIR][Linalg] Remove elemwise_unary and elemwise_binary (#147082) RFC: https://discourse.llvm.org/t/rfc-deprecate-linalg-elemwise-unary-and-elemwise-binary/87144 Remove the two operations and fix the tests by: * Cleaning simple operation tests of the old ops * Changing `linalg.elemwise_{u|bi}nary` with `linalg.{exp|add}` on transform tests * Changing some of the tests with `linalg.elementwise` instead, to broaden test coverage * Surgically removing the `elemwise_*` part in the Python tests * Update MLIR transform examples (text and tests) with `linalg.elementwise` instead Nothing else changed. --- .../linalg/opdsl/ops/core_named_ops.py | 31 ------------------- 1 file changed, 31 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 48e724d80..1b359da40 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -21,21 +21,6 @@ def copy( O[None] = cast(U, I[None]) -@linalg_structured_op -def elemwise_unary( - I=TensorDef(T1), - O=TensorDef(U, output=True), - fun=UnaryFnAttrDef(default=UnaryFn.exp), - cast=TypeFnAttrDef(default=TypeFn.cast_signed), -): - """Applies the unary function fun elementwise. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. 
- """ - O[None] = fun(cast(U, I[None])) - - @linalg_structured_op def exp( I=TensorDef(T1), @@ -192,22 +177,6 @@ def erf( O[None] = UnaryFn.erf(I[None]) -@linalg_structured_op -def elemwise_binary( - lhs=TensorDef(T1), - rhs=TensorDef(T2), - O=TensorDef(U, output=True), - fun=BinaryFnAttrDef(default=BinaryFn.add), - cast=TypeFnAttrDef(default=TypeFn.cast_signed), -): - """Applies the binary function fun elementwise. - - Numeric casting is performed on the input operand, promoting it to the same - data type as the accumulator/output. - """ - O[None] = fun(cast(U, lhs[None]), cast(U, rhs[None])) - - @linalg_structured_op def add( lhs=TensorDef(T1), From 9b02b803c4823db9f752ab2caebc19e485c8c1cb Mon Sep 17 00:00:00 2001 From: Nikita Popov Date: Mon, 7 Jul 2025 15:15:47 +0200 Subject: [PATCH 894/915] Revert "[DenseMap] Do not align pointer sentinel values (NFC) (#146595)" This reverts commit bdbf563aedd6a87d3235de5b9be1d44cc3ac25e9. This causes ubsan failures when the sentinel pointers are upcast using static_cast<>, which checks alignment. --- mlir/lib/Bindings/Python/NanobindUtils.h | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Bindings/Python/NanobindUtils.h b/mlir/lib/Bindings/Python/NanobindUtils.h index 535fc2328..64ea4329f 100644 --- a/mlir/lib/Bindings/Python/NanobindUtils.h +++ b/mlir/lib/Bindings/Python/NanobindUtils.h @@ -408,12 +408,11 @@ namespace llvm { template <> struct DenseMapInfo { static inline MlirTypeID getEmptyKey() { - // Shift by 3 to satisfy the TypeID alignment requirement. 
- void *pointer = reinterpret_cast(uintptr_t(-1) << 3); + auto *pointer = llvm::DenseMapInfo::getEmptyKey(); return mlirTypeIDCreate(pointer); } static inline MlirTypeID getTombstoneKey() { - void *pointer = reinterpret_cast(uintptr_t(-2) << 3); + auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); return mlirTypeIDCreate(pointer); } static inline unsigned getHashValue(const MlirTypeID &val) { From 556b741646da7cfad2310d106df0e8bf3df5687b Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Mon, 7 Jul 2025 14:57:27 -0400 Subject: [PATCH 895/915] [mlir] Add `isStatic`* size check for `ShapedType`s. NFCI. (#147085) The motivation is to avoid having to negate `isDynamic*` checks, avoid double negations, and allow for `ShapedType::isStaticDim` to be used in ADT functions without having to wrap it in a lambda performing the negation. Also add the new functions to C and Python bindings. --- mlir/include/mlir-c/BuiltinTypes.h | 19 +++++++++++++++---- mlir/lib/Bindings/Python/IRTypes.cpp | 24 ++++++++++++++++++++++++ mlir/lib/CAPI/IR/BuiltinTypes.cpp | 13 +++++++++++++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 13 +++++++++++++ 4 files changed, 65 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir-c/BuiltinTypes.h b/mlir/include/mlir-c/BuiltinTypes.h index 6875fab7b..c981bfd09 100644 --- a/mlir/include/mlir-c/BuiltinTypes.h +++ b/mlir/include/mlir-c/BuiltinTypes.h @@ -289,9 +289,12 @@ MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetRank(MlirType type); /// Checks whether the given shaped type has a static shape. MLIR_CAPI_EXPORTED bool mlirShapedTypeHasStaticShape(MlirType type); -/// Checks wither the dim-th dimension of the given shaped type is dynamic. +/// Checks whether the dim-th dimension of the given shaped type is dynamic. MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim); +/// Checks whether the dim-th dimension of the given shaped type is static. 
+MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticDim(MlirType type, intptr_t dim); + /// Returns the dim-th dimension of the given ranked shaped type. MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim); @@ -300,17 +303,25 @@ MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDimSize(MlirType type, /// in shaped types. MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicSize(int64_t size); +/// Checks whether the given shaped type dimension value is statically-sized. +MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticSize(int64_t size); + /// Returns the value indicating a dynamic size in a shaped type. Prefer -/// mlirShapedTypeIsDynamicSize to direct comparisons with this value. +/// mlirShapedTypeIsDynamicSize and mlirShapedTypeIsStaticSize to direct +/// comparisons with this value. MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicSize(void); /// Checks whether the given value is used as a placeholder for dynamic strides /// and offsets in shaped types. MLIR_CAPI_EXPORTED bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val); +/// Checks whether the given dimension value of a stride or an offset is +/// statically-sized. +MLIR_CAPI_EXPORTED bool mlirShapedTypeIsStaticStrideOrOffset(int64_t val); + /// Returns the value indicating a dynamic stride or offset in a shaped type. -/// Prefer mlirShapedTypeGetDynamicStrideOrOffset to direct comparisons with -/// this value. +/// Prefer mlirShapedTypeIsDynamicStrideOrOffset and +/// mlirShapedTypeIsStaticStrideOrOffset to direct comparisons with this value. 
MLIR_CAPI_EXPORTED int64_t mlirShapedTypeGetDynamicStrideOrOffset(void); //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 0f2719c10..b11e3f75b 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -544,6 +544,15 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { nb::arg("dim"), "Returns whether the dim-th dimension of the given shaped type is " "dynamic."); + c.def( + "is_static_dim", + [](PyShapedType &self, intptr_t dim) -> bool { + self.requireHasRank(); + return mlirShapedTypeIsStaticDim(self, dim); + }, + nb::arg("dim"), + "Returns whether the dim-th dimension of the given shaped type is " + "static."); c.def( "get_dim_size", [](PyShapedType &self, intptr_t dim) { @@ -558,6 +567,12 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { nb::arg("dim_size"), "Returns whether the given dimension size indicates a dynamic " "dimension."); + c.def_static( + "is_static_size", + [](int64_t size) -> bool { return mlirShapedTypeIsStaticSize(size); }, + nb::arg("dim_size"), + "Returns whether the given dimension size indicates a static " + "dimension."); c.def( "is_dynamic_stride_or_offset", [](PyShapedType &self, int64_t val) -> bool { @@ -567,6 +582,15 @@ void mlir::PyShapedType::bindDerived(ClassTy &c) { nb::arg("dim_size"), "Returns whether the given value is used as a placeholder for dynamic " "strides and offsets in shaped types."); + c.def( + "is_static_stride_or_offset", + [](PyShapedType &self, int64_t val) -> bool { + self.requireHasRank(); + return mlirShapedTypeIsStaticStrideOrOffset(val); + }, + nb::arg("dim_size"), + "Returns whether the given shaped type stride or offset value is " + "statically-sized."); c.def_prop_ro( "shape", [](PyShapedType &self) { diff --git a/mlir/lib/CAPI/IR/BuiltinTypes.cpp b/mlir/lib/CAPI/IR/BuiltinTypes.cpp index a080adf0f..9d8554aab 100644 --- 
a/mlir/lib/CAPI/IR/BuiltinTypes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinTypes.cpp @@ -332,6 +332,11 @@ bool mlirShapedTypeIsDynamicDim(MlirType type, intptr_t dim) { .isDynamicDim(static_cast(dim)); } +bool mlirShapedTypeIsStaticDim(MlirType type, intptr_t dim) { + return llvm::cast(unwrap(type)) + .isStaticDim(static_cast(dim)); +} + int64_t mlirShapedTypeGetDimSize(MlirType type, intptr_t dim) { return llvm::cast(unwrap(type)) .getDimSize(static_cast(dim)); @@ -343,10 +348,18 @@ bool mlirShapedTypeIsDynamicSize(int64_t size) { return ShapedType::isDynamic(size); } +bool mlirShapedTypeIsStaticSize(int64_t size) { + return ShapedType::isStatic(size); +} + bool mlirShapedTypeIsDynamicStrideOrOffset(int64_t val) { return ShapedType::isDynamic(val); } +bool mlirShapedTypeIsStaticStrideOrOffset(int64_t val) { + return ShapedType::isStatic(val); +} + int64_t mlirShapedTypeGetDynamicStrideOrOffset() { return ShapedType::kDynamic; } diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index 70bca3c75..ed476da28 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -2497,6 +2497,11 @@ class ShapedType(Type): Returns whether the given dimension size indicates a dynamic dimension. """ @staticmethod + def is_static_size(dim_size: int) -> bool: + """ + Returns whether the given dimension size indicates a static dimension. + """ + @staticmethod def isinstance(other: Type) -> bool: ... def __init__(self, cast_from_type: Type) -> None: ... def get_dim_size(self, dim: int) -> int: @@ -2507,10 +2512,18 @@ class ShapedType(Type): """ Returns whether the dim-th dimension of the given shaped type is dynamic. """ + def is_static_dim(self, dim: int) -> bool: + """ + Returns whether the dim-th dimension of the given shaped type is static. 
+ """ def is_dynamic_stride_or_offset(self, dim_size: int) -> bool: """ Returns whether the given value is used as a placeholder for dynamic strides and offsets in shaped types. """ + def is_static_stride_or_offset(self, dim_size: int) -> bool: + """ + Returns whether the given shaped type stride or offset value is statically-sized. + """ @property def element_type(self) -> Type: """ From ea7589d391f01ff9c0fef8bda87065db83e07374 Mon Sep 17 00:00:00 2001 From: Rolf Morel Date: Tue, 8 Jul 2025 12:00:34 +0200 Subject: [PATCH 896/915] [MLIR][Transform] Introduce `transform.tune.knob` op (#146732) A new transform op to represent that an attribute is to be chosen from a set of alternatives and that this choice is made available as a `!transform.param`. When a `selected` argument is provided, the op's `apply()` semantics is that of just making this selected attribute available as the result. When `selected` is not provided, `apply()` complains that nothing has resolved the non-determinism that the op is representing. 
--- mlir/python/CMakeLists.txt | 9 ++ .../dialects/TransformTuneExtensionOps.td | 19 +++++ mlir/python/mlir/dialects/transform/tune.py | 82 +++++++++++++++++++ 3 files changed, 110 insertions(+) create mode 100644 mlir/python/mlir/dialects/TransformTuneExtensionOps.td create mode 100644 mlir/python/mlir/dialects/transform/tune.py diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index b2daabb2a..7a0c95ebb 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -180,6 +180,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" DIALECT_NAME transform EXTENSION_NAME transform_debug_extension) +declare_mlir_dialect_extension_python_bindings( +ADD_TO_PARENT MLIRPythonSources.Dialects +ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/TransformTuneExtensionOps.td + SOURCES + dialects/transform/tune.py + DIALECT_NAME transform + EXTENSION_NAME transform_tune_extension) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/TransformTuneExtensionOps.td b/mlir/python/mlir/dialects/TransformTuneExtensionOps.td new file mode 100644 index 000000000..ff3047592 --- /dev/null +++ b/mlir/python/mlir/dialects/TransformTuneExtensionOps.td @@ -0,0 +1,19 @@ +//===-- TransformTuneExtensionOps.td - Binding entry point -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the generated Python bindings for the Tune extension of the +// Transform dialect. 
+// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS +#define PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS + +include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td" + +#endif // PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS diff --git a/mlir/python/mlir/dialects/transform/tune.py b/mlir/python/mlir/dialects/transform/tune.py new file mode 100644 index 000000000..f63f88a38 --- /dev/null +++ b/mlir/python/mlir/dialects/transform/tune.py @@ -0,0 +1,82 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from typing import Optional, Sequence + +from ...ir import ( + Type, + Attribute, + ArrayAttr, + StringAttr, + F64Type, + IntegerType, + IntegerAttr, + FloatAttr, + BoolAttr, +) +from .._transform_tune_extension_ops_gen import * +from .._transform_tune_extension_ops_gen import _Dialect + +try: + from .._ods_common import _cext as _ods_cext +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Union + + +@_ods_cext.register_operation(_Dialect, replace=True) +class KnobOp(KnobOp): + def __init__( + self, + result: Type, # !transform.any_param or !transform.param + name: Union[StringAttr, str], + options: Union[ + ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute + ], + *, + selected: Optional[Attribute] = None, + loc=None, + ip=None, + ): + if isinstance(name, str): + name = StringAttr.get(name) + + def map_to_attr(value): + if isinstance(value, bool): + return BoolAttr.get(value) + if isinstance(value, int): + return IntegerAttr.get(IntegerType.get_signless(64), value) + if isinstance(value, float): + return FloatAttr.get(F64Type.get(), value) + if isinstance(value, str): + return StringAttr.get(value) + assert isinstance(value, 
Attribute) + return value + + if isinstance(options, Sequence) and not isinstance(options, ArrayAttr): + options = ArrayAttr.get([map_to_attr(opt) for opt in options]) + + super().__init__( + result, + name, + options, + selected=selected and map_to_attr(selected), + loc=loc, + ip=ip, + ) + + +def knob( + result: Type, # !transform.any_param or !transform.param + name: Union[StringAttr, str], + options: Union[ + ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute + ], + *, + selected: Optional[Attribute] = None, + loc=None, + ip=None, +): + return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip) From c532a58631ee478ec54aa68617ef1d961c780510 Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Fri, 11 Jul 2025 09:23:39 +0200 Subject: [PATCH 897/915] [mlir] Fix TransformTuneExtensionOps.td include guards --- mlir/python/mlir/dialects/TransformTuneExtensionOps.td | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/python/mlir/dialects/TransformTuneExtensionOps.td b/mlir/python/mlir/dialects/TransformTuneExtensionOps.td index ff3047592..c622c31e2 100644 --- a/mlir/python/mlir/dialects/TransformTuneExtensionOps.td +++ b/mlir/python/mlir/dialects/TransformTuneExtensionOps.td @@ -11,9 +11,9 @@ // //===----------------------------------------------------------------------===// -#ifndef PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS -#define PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS +#ifndef PYTHON_BINDINGS_TRANSFORM_TUNE_EXTENSION_OPS +#define PYTHON_BINDINGS_TRANSFORM_TUNE_EXTENSION_OPS include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td" -#endif // PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS +#endif // PYTHON_BINDINGS_TRANSFORM_TUNE_EXTENSION_OPS From 43ddebc36eea29b15c47c75fc44f3c7ceef924ed Mon Sep 17 00:00:00 2001 From: Nicholas Junge Date: Tue, 15 Jul 2025 16:58:10 +0200 Subject: [PATCH 898/915] [mlir][py] Mark all type caster `from_{cpp,python}` methods as noexcept (#143866) This is 
mentioned as a "must" in https://nanobind.readthedocs.io/en/latest/porting.html#type-casters when implementing type casters. While most of the existing `from_cpp` methods were already marked noexcept, many of the `from_python` methods were not. This commit adds the missing noexcept declarations to all type casters found in `NanobindAdaptors.h`. --------- Co-authored-by: Maksim Levental --- .../mlir/Bindings/Python/NanobindAdaptors.h | 114 +++++++++--------- 1 file changed, 56 insertions(+), 58 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h index 2dd35c097..8dcf91e58 100644 --- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h @@ -20,6 +20,7 @@ #define MLIR_BINDINGS_PYTHON_NANOBINDADAPTORS_H #include +#include #include "mlir-c/Diagnostics.h" #include "mlir-c/IR.h" @@ -43,18 +44,14 @@ namespace detail { /// with a raw handle (unowned). The returned object's lifetime may not extend /// beyond the apiObject handle without explicitly having its refcount increased /// (i.e. on return). 
-static nanobind::object mlirApiObjectToCapsule(nanobind::handle apiObject) { +static std::optional +mlirApiObjectToCapsule(nanobind::handle apiObject) { if (PyCapsule_CheckExact(apiObject.ptr())) return nanobind::borrow(apiObject); nanobind::object api = nanobind::getattr(apiObject, MLIR_PYTHON_CAPI_PTR_ATTR, nanobind::none()); - if (api.is_none()) { - std::string repr = nanobind::cast(nanobind::repr(apiObject)); - throw nanobind::type_error( - (llvm::Twine("Expected an MLIR object (got ") + repr + ").") - .str() - .c_str()); - } + if (api.is_none()) + return {}; return api; } @@ -67,12 +64,9 @@ static nanobind::object mlirApiObjectToCapsule(nanobind::handle apiObject) { template <> struct type_caster { NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap")) - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToAffineMap(capsule.ptr()); - if (mlirAffineMapIsNull(value)) { - return false; - } + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) + value = mlirPythonCapsuleToAffineMap(capsule->ptr()); return !mlirAffineMapIsNull(value); } static handle from_cpp(MlirAffineMap v, rv_policy, @@ -90,9 +84,9 @@ struct type_caster { template <> struct type_caster { NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute")) - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToAttribute(capsule.ptr()); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) + value = mlirPythonCapsuleToAttribute(capsule->ptr()); return !mlirAttributeIsNull(value); } static handle from_cpp(MlirAttribute v, rv_policy, @@ -111,9 +105,9 @@ struct type_caster { template <> struct type_caster { NB_TYPE_CASTER(MlirBlock, 
const_name("MlirBlock")) - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToBlock(capsule.ptr()); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) + value = mlirPythonCapsuleToBlock(capsule->ptr()); return !mlirBlockIsNull(value); } }; @@ -122,7 +116,7 @@ struct type_caster { template <> struct type_caster { NB_TYPE_CASTER(MlirContext, const_name("MlirContext")) - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (src.is_none()) { // Gets the current thread-bound context. // TODO: This raises an error of "No current context" currently. @@ -132,8 +126,8 @@ struct type_caster { .attr("Context") .attr("current"); } - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToContext(capsule.ptr()); + std::optional capsule = mlirApiObjectToCapsule(src); + value = mlirPythonCapsuleToContext(capsule->ptr()); return !mlirContextIsNull(value); } }; @@ -142,9 +136,9 @@ struct type_caster { template <> struct type_caster { NB_TYPE_CASTER(MlirDialectRegistry, const_name("MlirDialectRegistry")) - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToDialectRegistry(capsule.ptr()); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) + value = mlirPythonCapsuleToDialectRegistry(capsule->ptr()); return !mlirDialectRegistryIsNull(value); } static handle from_cpp(MlirDialectRegistry v, rv_policy, @@ -162,15 +156,15 @@ struct type_caster { template <> struct type_caster { NB_TYPE_CASTER(MlirLocation, const_name("MlirLocation")) - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) 
{ + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { if (src.is_none()) { // Gets the current thread-bound context. src = nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr("Location") .attr("current"); } - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToLocation(capsule.ptr()); + if (auto capsule = mlirApiObjectToCapsule(src)) + value = mlirPythonCapsuleToLocation(capsule->ptr()); return !mlirLocationIsNull(value); } static handle from_cpp(MlirLocation v, rv_policy, @@ -188,9 +182,9 @@ struct type_caster { template <> struct type_caster { NB_TYPE_CASTER(MlirModule, const_name("MlirModule")) - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToModule(capsule.ptr()); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) + value = mlirPythonCapsuleToModule(capsule->ptr()); return !mlirModuleIsNull(value); } static handle from_cpp(MlirModule v, rv_policy, @@ -209,12 +203,13 @@ template <> struct type_caster { NB_TYPE_CASTER(MlirFrozenRewritePatternSet, const_name("MlirFrozenRewritePatternSet")) - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule.ptr()); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) + value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule->ptr()); return value.ptr != nullptr; } - static handle from_cpp(MlirFrozenRewritePatternSet v, rv_policy, handle) { + static handle from_cpp(MlirFrozenRewritePatternSet v, rv_policy, + handle) noexcept { nanobind::object capsule = nanobind::steal( mlirPythonFrozenRewritePatternSetToCapsule(v)); return 
nanobind::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("rewrite")) @@ -228,9 +223,9 @@ struct type_caster { template <> struct type_caster { NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation")) - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToOperation(capsule.ptr()); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) + value = mlirPythonCapsuleToOperation(capsule->ptr()); return !mlirOperationIsNull(value); } static handle from_cpp(MlirOperation v, rv_policy, @@ -250,9 +245,9 @@ struct type_caster { template <> struct type_caster { NB_TYPE_CASTER(MlirValue, const_name("MlirValue")) - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToValue(capsule.ptr()); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) + value = mlirPythonCapsuleToValue(capsule->ptr()); return !mlirValueIsNull(value); } static handle from_cpp(MlirValue v, rv_policy, @@ -273,9 +268,9 @@ struct type_caster { template <> struct type_caster { NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager")) - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToPassManager(capsule.ptr()); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) + value = mlirPythonCapsuleToPassManager(capsule->ptr()); return !mlirPassManagerIsNull(value); } }; @@ -284,9 +279,9 @@ struct type_caster { template <> struct type_caster { NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID")) - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object 
capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToTypeID(capsule.ptr()); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) + value = mlirPythonCapsuleToTypeID(capsule->ptr()); return !mlirTypeIDIsNull(value); } static handle from_cpp(MlirTypeID v, rv_policy, @@ -306,9 +301,9 @@ struct type_caster { template <> struct type_caster { NB_TYPE_CASTER(MlirType, const_name("MlirType")) - bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) { - nanobind::object capsule = mlirApiObjectToCapsule(src); - value = mlirPythonCapsuleToType(capsule.ptr()); + bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { + if (auto capsule = mlirApiObjectToCapsule(src)) + value = mlirPythonCapsuleToType(capsule->ptr()); return !mlirTypeIsNull(value); } static handle from_cpp(MlirType t, rv_policy, @@ -462,9 +457,10 @@ class mlir_attribute_subclass : public pure_subclass { nanobind::object newCf = nanobind::cpp_function( [superCls, isaFunction, captureTypeName]( nanobind::object cls, nanobind::object otherAttribute) { - MlirAttribute rawAttribute = - nanobind::cast(otherAttribute); - if (!isaFunction(rawAttribute)) { + MlirAttribute rawAttribute; + if (!nanobind::try_cast(otherAttribute, + rawAttribute) || + !isaFunction(rawAttribute)) { auto origRepr = nanobind::cast(nanobind::repr(otherAttribute)); throw std::invalid_argument( @@ -543,8 +539,9 @@ class mlir_type_subclass : public pure_subclass { nanobind::object newCf = nanobind::cpp_function( [superCls, isaFunction, captureTypeName](nanobind::object cls, nanobind::object otherType) { - MlirType rawType = nanobind::cast(otherType); - if (!isaFunction(rawType)) { + MlirType rawType; + if (!nanobind::try_cast(otherType, rawType) || + !isaFunction(rawType)) { auto origRepr = nanobind::cast(nanobind::repr(otherType)); throw std::invalid_argument((llvm::Twine("Cannot cast type to ") + @@ -625,8 +622,9 @@ class 
mlir_value_subclass : public pure_subclass { nanobind::object newCf = nanobind::cpp_function( [superCls, isaFunction, captureValueName](nanobind::object cls, nanobind::object otherValue) { - MlirValue rawValue = nanobind::cast(otherValue); - if (!isaFunction(rawValue)) { + MlirValue rawValue; + if (!nanobind::try_cast(otherValue, rawValue) || + !isaFunction(rawValue)) { auto origRepr = nanobind::cast(nanobind::repr(otherValue)); throw std::invalid_argument((llvm::Twine("Cannot cast value to ") + From c8c967bdeadb2ca65bf0380bc1b1f3c791a8358c Mon Sep 17 00:00:00 2001 From: Jordan Rupprecht Date: Tue, 15 Jul 2025 15:32:26 -0500 Subject: [PATCH 899/915] [mlir][py] Fix nanobind uninitialized values (#148944) After #143866, we no longer always write to `value`, causing it to be uninitialized. This can lead to mysterious crashes, e.g. in `python_test.py` / `testCustomAttribute` when we attempt to evaluate `TestAttr(42)`, it does not set `value`, but `mlirAttributeIsNull(value)` happens to return false for garbage memory, and we end up trying to interpret it as a function instead of skipping it. Fix this by only reading `value` if it has been assigned. If it hasn't, `return false` seems the right choice for all these methods, i.e. indicate that `from_python` failed. 
--- .../mlir/Bindings/Python/NanobindAdaptors.h | 72 ++++++++++++------- 1 file changed, 48 insertions(+), 24 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h index 8dcf91e58..1428d5ccf 100644 --- a/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/NanobindAdaptors.h @@ -65,9 +65,11 @@ template <> struct type_caster { NB_TYPE_CASTER(MlirAffineMap, const_name("MlirAffineMap")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { - if (auto capsule = mlirApiObjectToCapsule(src)) + if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToAffineMap(capsule->ptr()); - return !mlirAffineMapIsNull(value); + return !mlirAffineMapIsNull(value); + } + return false; } static handle from_cpp(MlirAffineMap v, rv_policy, cleanup_list *cleanup) noexcept { @@ -85,9 +87,11 @@ template <> struct type_caster { NB_TYPE_CASTER(MlirAttribute, const_name("MlirAttribute")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { - if (auto capsule = mlirApiObjectToCapsule(src)) + if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToAttribute(capsule->ptr()); - return !mlirAttributeIsNull(value); + return !mlirAttributeIsNull(value); + } + return false; } static handle from_cpp(MlirAttribute v, rv_policy, cleanup_list *cleanup) noexcept { @@ -106,9 +110,11 @@ template <> struct type_caster { NB_TYPE_CASTER(MlirBlock, const_name("MlirBlock")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { - if (auto capsule = mlirApiObjectToCapsule(src)) + if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToBlock(capsule->ptr()); - return !mlirBlockIsNull(value); + return !mlirBlockIsNull(value); + } + return false; } }; @@ -137,9 +143,11 @@ template <> struct type_caster { NB_TYPE_CASTER(MlirDialectRegistry, 
const_name("MlirDialectRegistry")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { - if (auto capsule = mlirApiObjectToCapsule(src)) + if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToDialectRegistry(capsule->ptr()); - return !mlirDialectRegistryIsNull(value); + return !mlirDialectRegistryIsNull(value); + } + return false; } static handle from_cpp(MlirDialectRegistry v, rv_policy, cleanup_list *cleanup) noexcept { @@ -163,9 +171,11 @@ struct type_caster { .attr("Location") .attr("current"); } - if (auto capsule = mlirApiObjectToCapsule(src)) + if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToLocation(capsule->ptr()); - return !mlirLocationIsNull(value); + return !mlirLocationIsNull(value); + } + return false; } static handle from_cpp(MlirLocation v, rv_policy, cleanup_list *cleanup) noexcept { @@ -183,9 +193,11 @@ template <> struct type_caster { NB_TYPE_CASTER(MlirModule, const_name("MlirModule")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { - if (auto capsule = mlirApiObjectToCapsule(src)) + if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToModule(capsule->ptr()); - return !mlirModuleIsNull(value); + return !mlirModuleIsNull(value); + } + return false; } static handle from_cpp(MlirModule v, rv_policy, cleanup_list *cleanup) noexcept { @@ -204,9 +216,11 @@ struct type_caster { NB_TYPE_CASTER(MlirFrozenRewritePatternSet, const_name("MlirFrozenRewritePatternSet")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { - if (auto capsule = mlirApiObjectToCapsule(src)) + if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToFrozenRewritePatternSet(capsule->ptr()); - return value.ptr != nullptr; + return value.ptr != nullptr; + } + return false; } static handle from_cpp(MlirFrozenRewritePatternSet v, rv_policy, handle) noexcept { @@ -224,9 +238,11 @@ template <> struct type_caster 
{ NB_TYPE_CASTER(MlirOperation, const_name("MlirOperation")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { - if (auto capsule = mlirApiObjectToCapsule(src)) + if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToOperation(capsule->ptr()); - return !mlirOperationIsNull(value); + return !mlirOperationIsNull(value); + } + return false; } static handle from_cpp(MlirOperation v, rv_policy, cleanup_list *cleanup) noexcept { @@ -246,9 +262,11 @@ template <> struct type_caster { NB_TYPE_CASTER(MlirValue, const_name("MlirValue")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { - if (auto capsule = mlirApiObjectToCapsule(src)) + if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToValue(capsule->ptr()); - return !mlirValueIsNull(value); + return !mlirValueIsNull(value); + } + return false; } static handle from_cpp(MlirValue v, rv_policy, cleanup_list *cleanup) noexcept { @@ -269,9 +287,11 @@ template <> struct type_caster { NB_TYPE_CASTER(MlirPassManager, const_name("MlirPassManager")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { - if (auto capsule = mlirApiObjectToCapsule(src)) + if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToPassManager(capsule->ptr()); - return !mlirPassManagerIsNull(value); + return !mlirPassManagerIsNull(value); + } + return false; } }; @@ -280,9 +300,11 @@ template <> struct type_caster { NB_TYPE_CASTER(MlirTypeID, const_name("MlirTypeID")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { - if (auto capsule = mlirApiObjectToCapsule(src)) + if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToTypeID(capsule->ptr()); - return !mlirTypeIDIsNull(value); + return !mlirTypeIDIsNull(value); + } + return false; } static handle from_cpp(MlirTypeID v, rv_policy, cleanup_list *cleanup) noexcept { @@ -302,9 +324,11 @@ template <> struct 
type_caster { NB_TYPE_CASTER(MlirType, const_name("MlirType")) bool from_python(handle src, uint8_t flags, cleanup_list *cleanup) noexcept { - if (auto capsule = mlirApiObjectToCapsule(src)) + if (auto capsule = mlirApiObjectToCapsule(src)) { value = mlirPythonCapsuleToType(capsule->ptr()); - return !mlirTypeIsNull(value); + return !mlirTypeIsNull(value); + } + return false; } static handle from_cpp(MlirType t, rv_policy, cleanup_list *cleanup) noexcept { From 452099bbc4fc6bbe443e9ff01d698ba767b0f7d0 Mon Sep 17 00:00:00 2001 From: Martin Erhart Date: Wed, 16 Jul 2025 15:45:15 +0100 Subject: [PATCH 900/915] [mlir] Add Python bindings to enable default passmanager timing (#149087) --- mlir/include/mlir-c/Pass.h | 4 ++++ mlir/lib/Bindings/Python/Pass.cpp | 6 ++++++ mlir/lib/CAPI/IR/Pass.cpp | 4 ++++ 3 files changed, 14 insertions(+) diff --git a/mlir/include/mlir-c/Pass.h b/mlir/include/mlir-c/Pass.h index 8fd8e9956..0d2e19ee7 100644 --- a/mlir/include/mlir-c/Pass.h +++ b/mlir/include/mlir-c/Pass.h @@ -88,6 +88,10 @@ MLIR_CAPI_EXPORTED void mlirPassManagerEnableIRPrinting( MLIR_CAPI_EXPORTED void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable); +/// Enable pass timing. +MLIR_CAPI_EXPORTED void +mlirPassManagerEnableTiming(MlirPassManager passManager); + /// Nest an OpPassManager under the top-level PassManager, the nested /// passmanager will only run on operations matching the provided name. /// The returned OpPassManager will be destroyed when the parent is destroyed. 
diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 858c3bd57..8d84864b9 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -112,6 +112,12 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { mlirPassManagerEnableVerifier(passManager.get(), enable); }, "enable"_a, "Enable / disable verify-each.") + .def( + "enable_timing", + [](PyPassManager &passManager) { + mlirPassManagerEnableTiming(passManager.get()); + }, + "Enable pass timing.") .def_static( "parse", [](const std::string &pipeline, DefaultingPyMlirContext context) { diff --git a/mlir/lib/CAPI/IR/Pass.cpp b/mlir/lib/CAPI/IR/Pass.cpp index 883b7e8bb..3c499c3e4 100644 --- a/mlir/lib/CAPI/IR/Pass.cpp +++ b/mlir/lib/CAPI/IR/Pass.cpp @@ -75,6 +75,10 @@ void mlirPassManagerEnableVerifier(MlirPassManager passManager, bool enable) { unwrap(passManager)->enableVerifier(enable); } +void mlirPassManagerEnableTiming(MlirPassManager passManager) { + unwrap(passManager)->enableTiming(); +} + MlirOpPassManager mlirPassManagerGetNestedUnder(MlirPassManager passManager, MlirStringRef operationName) { return wrap(&unwrap(passManager)->nest(unwrap(operationName))); From f178c8ce28b67a810f3839a3a645a377fcfe9334 Mon Sep 17 00:00:00 2001 From: Akshay Khadse Date: Thu, 17 Jul 2025 12:57:04 -0400 Subject: [PATCH 901/915] [MLIR][Python] Support eliding large resource strings in PassManager (#149187) - Introduces a `large_resource_limit` parameter across Python bindings, enabling the eliding of resource strings exceeding a specified character limit during IR printing. - To maintain backward compatibilty, when using `operation.print()` API, if `large_resource_limit` is None and the `large_elements_limit` is set, the later will be used to elide the resource string as well. This change was introduced by https://github.com/llvm/llvm-project/pull/125738. 
- For printing using pass manager, the `large_resource_limit` and `large_elements_limit` are completely independent of each other. --- mlir/lib/Bindings/Python/IRCore.cpp | 22 ++++++++++++++----- mlir/lib/Bindings/Python/IRModule.h | 14 ++++++------ mlir/lib/Bindings/Python/Pass.cpp | 12 ++++++++-- mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 7 ++++++ .../mlir/_mlir_libs/_mlir/passmanager.pyi | 1 + 5 files changed, 41 insertions(+), 15 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index d96148288..7b790e90e 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -97,6 +97,10 @@ static const char kOperationPrintDocstring[] = binary: Whether to write bytes (True) or str (False). Defaults to False. large_elements_limit: Whether to elide elements attributes above this number of elements. Defaults to None (no limit). + large_resource_limit: Whether to elide resource attributes above this + number of characters. Defaults to None (no limit). If large_elements_limit + is set and this is None, the behavior will be to use large_elements_limit + as large_resource_limit. enable_debug_info: Whether to print debug/location information. Defaults to False. 
pretty_debug_info: Whether to format debug information for easier reading @@ -1303,6 +1307,7 @@ void PyOperation::checkValid() const { } void PyOperationBase::print(std::optional largeElementsLimit, + std::optional largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, @@ -1314,10 +1319,10 @@ void PyOperationBase::print(std::optional largeElementsLimit, fileObject = nb::module_::import_("sys").attr("stdout"); MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); - if (largeElementsLimit) { + if (largeElementsLimit) mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); - mlirOpPrintingFlagsElideLargeResourceString(flags, *largeElementsLimit); - } + if (largeResourceLimit) + mlirOpPrintingFlagsElideLargeResourceString(flags, *largeResourceLimit); if (enableDebugInfo) mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, /*prettyForm=*/prettyDebugInfo); @@ -1405,6 +1410,7 @@ void PyOperationBase::walk( nb::object PyOperationBase::getAsm(bool binary, std::optional largeElementsLimit, + std::optional largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, @@ -1416,6 +1422,7 @@ nb::object PyOperationBase::getAsm(bool binary, fileObject = nb::module_::import_("io").attr("StringIO")(); } print(/*largeElementsLimit=*/largeElementsLimit, + /*largeResourceLimit=*/largeResourceLimit, /*enableDebugInfo=*/enableDebugInfo, /*prettyDebugInfo=*/prettyDebugInfo, /*printGenericOpForm=*/printGenericOpForm, @@ -3348,6 +3355,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { [](PyOperationBase &self) { return self.getAsm(/*binary=*/false, /*largeElementsLimit=*/std::nullopt, + /*largeResourceLimit=*/std::nullopt, /*enableDebugInfo=*/false, /*prettyDebugInfo=*/false, /*printGenericOpForm=*/false, @@ -3363,11 +3371,12 @@ void 
mlir::python::populateIRCore(nb::module_ &m) { nb::arg("state"), nb::arg("file").none() = nb::none(), nb::arg("binary") = false, kOperationPrintStateDocstring) .def("print", - nb::overload_cast, bool, bool, bool, bool, - bool, bool, nb::object, bool, bool>( - &PyOperationBase::print), + nb::overload_cast, std::optional, + bool, bool, bool, bool, bool, bool, nb::object, + bool, bool>(&PyOperationBase::print), // Careful: Lots of arguments must match up with print method. nb::arg("large_elements_limit").none() = nb::none(), + nb::arg("large_resource_limit").none() = nb::none(), nb::arg("enable_debug_info") = false, nb::arg("pretty_debug_info") = false, nb::arg("print_generic_op_form") = false, @@ -3383,6 +3392,7 @@ void mlir::python::populateIRCore(nb::module_ &m) { // Careful: Lots of arguments must match up with get_asm method. nb::arg("binary") = false, nb::arg("large_elements_limit").none() = nb::none(), + nb::arg("large_resource_limit").none() = nb::none(), nb::arg("enable_debug_info") = false, nb::arg("pretty_debug_info") = false, nb::arg("print_generic_op_form") = false, diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 9befcce72..0fdd2d1a7 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -599,18 +599,18 @@ class PyOperationBase { public: virtual ~PyOperationBase() = default; /// Implements the bound 'print' method and helps with others. 
- void print(std::optional largeElementsLimit, bool enableDebugInfo, + void print(std::optional largeElementsLimit, + std::optional largeResourceLimit, bool enableDebugInfo, bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, bool useNameLocAsPrefix, bool assumeVerified, nanobind::object fileObject, bool binary, bool skipRegions); void print(PyAsmState &state, nanobind::object fileObject, bool binary); - nanobind::object getAsm(bool binary, - std::optional largeElementsLimit, - bool enableDebugInfo, bool prettyDebugInfo, - bool printGenericOpForm, bool useLocalScope, - bool useNameLocAsPrefix, bool assumeVerified, - bool skipRegions); + nanobind::object + getAsm(bool binary, std::optional largeElementsLimit, + std::optional largeResourceLimit, bool enableDebugInfo, + bool prettyDebugInfo, bool printGenericOpForm, bool useLocalScope, + bool useNameLocAsPrefix, bool assumeVerified, bool skipRegions); // Implement the bound 'writeBytecode' method. void writeBytecode(const nanobind::object &fileObject, diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 8d84864b9..20017e25b 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -78,12 +78,19 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { [](PyPassManager &passManager, bool printBeforeAll, bool printAfterAll, bool printModuleScope, bool printAfterChange, bool printAfterFailure, std::optional largeElementsLimit, - bool enableDebugInfo, bool printGenericOpForm, + std::optional largeResourceLimit, bool enableDebugInfo, + bool printGenericOpForm, std::optional optionalTreePrintingPath) { MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); - if (largeElementsLimit) + if (largeElementsLimit) { mlirOpPrintingFlagsElideLargeElementsAttrs(flags, *largeElementsLimit); + mlirOpPrintingFlagsElideLargeResourceString(flags, + *largeElementsLimit); + } + if (largeResourceLimit) + 
mlirOpPrintingFlagsElideLargeResourceString(flags, + *largeResourceLimit); if (enableDebugInfo) mlirOpPrintingFlagsEnableDebugInfo(flags, /*enable=*/true, /*prettyForm=*/false); @@ -103,6 +110,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { "print_module_scope"_a = false, "print_after_change"_a = false, "print_after_failure"_a = false, "large_elements_limit"_a.none() = nb::none(), + "large_resource_limit"_a.none() = nb::none(), "enable_debug_info"_a = false, "print_generic_op_form"_a = false, "tree_printing_dir_path"_a.none() = nb::none(), "Enable IR printing, default as mlir-print-ir-after-all.") diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index ed476da28..be71737e4 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -200,6 +200,7 @@ class _OperationBase: def get_asm( binary: Literal[True], large_elements_limit: int | None = None, + large_resource_limit: int | None = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, @@ -212,6 +213,7 @@ class _OperationBase: self, binary: bool = False, large_elements_limit: int | None = None, + large_resource_limit: int | None = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, @@ -253,6 +255,7 @@ class _OperationBase: def print( self, large_elements_limit: int | None = None, + large_resource_limit: int | None = None, enable_debug_info: bool = False, pretty_debug_info: bool = False, print_generic_op_form: bool = False, @@ -270,6 +273,10 @@ class _OperationBase: binary: Whether to write bytes (True) or str (False). Defaults to False. large_elements_limit: Whether to elide elements attributes above this number of elements. Defaults to None (no limit). + large_resource_limit: Whether to elide resource strings above this + number of characters. Defaults to None (no limit). 
If large_elements_limit + is set and this is None, the behavior will be to use large_elements_limit + as large_resource_limit. enable_debug_info: Whether to print debug/location information. Defaults to False. pretty_debug_info: Whether to format debug information for easier reading diff --git a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi index 0d2eaffe1..1010dadda 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/passmanager.pyi @@ -23,6 +23,7 @@ class PassManager: print_after_change: bool = False, print_after_failure: bool = False, large_elements_limit: int | None = None, + large_resource_limit: int | None = None, enable_debug_info: bool = False, print_generic_op_form: bool = False, tree_printing_dir_path: str | None = None, From 3447fdc21b21a4785f7613f30055d084bfd30184 Mon Sep 17 00:00:00 2001 From: Colin De Vlieghere Date: Fri, 18 Jul 2025 16:53:11 -0700 Subject: [PATCH 902/915] [MLIR][SCF] Add dedicated Python bindings for ForallOp (#149416) This patch specializes the Python bindings for ForallOp and InParallelOp, similar to the existing one for ForOp. These bindings create the regions and blocks properly and expose some additional helpers. 
--- mlir/python/mlir/dialects/scf.py | 119 ++++++++++++++++++++++++++++++- 1 file changed, 118 insertions(+), 1 deletion(-) diff --git a/mlir/python/mlir/dialects/scf.py b/mlir/python/mlir/dialects/scf.py index 2d0047b76..678ceeeba 100644 --- a/mlir/python/mlir/dialects/scf.py +++ b/mlir/python/mlir/dialects/scf.py @@ -17,7 +17,7 @@ except ImportError as e: raise RuntimeError("Error loading imports from extension module") from e -from typing import Optional, Sequence, Union +from typing import List, Optional, Sequence, Tuple, Union @_ods_cext.register_operation(_Dialect, replace=True) @@ -71,6 +71,123 @@ def inner_iter_args(self): return self.body.arguments[1:] +def _dispatch_index_op_fold_results( + ofrs: Sequence[Union[Operation, OpView, Value, int]], +) -> Tuple[List[Value], List[int]]: + """`mlir::dispatchIndexOpFoldResults`""" + dynamic_vals = [] + static_vals = [] + for ofr in ofrs: + if isinstance(ofr, (Operation, OpView, Value)): + val = _get_op_result_or_value(ofr) + dynamic_vals.append(val) + static_vals.append(ShapedType.get_dynamic_size()) + else: + static_vals.append(ofr) + return dynamic_vals, static_vals + + +@_ods_cext.register_operation(_Dialect, replace=True) +class ForallOp(ForallOp): + """Specialization for the SCF forall op class.""" + + def __init__( + self, + lower_bounds: Sequence[Union[Operation, OpView, Value, int]], + upper_bounds: Sequence[Union[Operation, OpView, Value, int]], + steps: Sequence[Union[Value, int]], + shared_outs: Optional[Union[Operation, OpView, Sequence[Value]]] = None, + *, + mapping=None, + loc=None, + ip=None, + ): + """Creates an SCF `forall` operation. + + - `lower_bounds` are the values to use as lower bounds of the loop. + - `upper_bounds` are the values to use as upper bounds of the loop. + - `steps` are the values to use as loop steps. + - `shared_outs` is a list of additional loop-carried arguments or an operation + producing them as results. 
+ """ + assert ( + len(lower_bounds) == len(upper_bounds) == len(steps) + ), "Mismatch in length of lower bounds, upper bounds, and steps" + if shared_outs is None: + shared_outs = [] + shared_outs = _get_op_results_or_values(shared_outs) + + dynamic_lbs, static_lbs = _dispatch_index_op_fold_results(lower_bounds) + dynamic_ubs, static_ubs = _dispatch_index_op_fold_results(upper_bounds) + dynamic_steps, static_steps = _dispatch_index_op_fold_results(steps) + + results = [arg.type for arg in shared_outs] + super().__init__( + results, + dynamic_lbs, + dynamic_ubs, + dynamic_steps, + static_lbs, + static_ubs, + static_steps, + shared_outs, + mapping=mapping, + loc=loc, + ip=ip, + ) + rank = len(static_lbs) + iv_types = [IndexType.get()] * rank + self.regions[0].blocks.append(*iv_types, *results) + + @property + def body(self) -> Block: + """Returns the body (block) of the loop.""" + return self.regions[0].blocks[0] + + @property + def rank(self) -> int: + """Returns the number of induction variables the loop has.""" + return len(self.staticLowerBound) + + @property + def induction_variables(self) -> BlockArgumentList: + """Returns the induction variables usable within the loop.""" + return self.body.arguments[: self.rank] + + @property + def inner_iter_args(self) -> BlockArgumentList: + """Returns the loop-carried arguments usable within the loop. + + To obtain the loop-carried operands, use `iter_args`. + """ + return self.body.arguments[self.rank :] + + def terminator(self) -> InParallelOp: + """ + Returns the loop terminator if it exists. + Otherwise, creates a new one. 
+ """ + ops = self.body.operations + with InsertionPoint(self.body): + if not ops: + return InParallelOp() + last = ops[len(ops) - 1] + return last if isinstance(last, InParallelOp) else InParallelOp() + + +@_ods_cext.register_operation(_Dialect, replace=True) +class InParallelOp(InParallelOp): + """Specialization of the SCF forall.in_parallel op class.""" + + def __init__(self, loc=None, ip=None): + super().__init__(loc=loc, ip=ip) + self.region.blocks.append() + + @property + def block(self) -> Block: + return self.region.blocks[0] + + @_ods_cext.register_operation(_Dialect, replace=True) class IfOp(IfOp): """Specialization for the SCF if op class.""" From 36b655fb98c83263107d1f8f460b9049ee1bc209 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Wed, 23 Jul 2025 12:33:42 -0500 Subject: [PATCH 903/915] [mlir][python,CAPI] expose Op::isBeforeInBlock (#150271) --- mlir/include/mlir-c/IR.h | 7 +++++++ mlir/lib/Bindings/Python/IRCore.cpp | 15 +++++++++++++++ mlir/lib/Bindings/Python/IRModule.h | 7 +++++++ mlir/lib/CAPI/IR/IR.cpp | 4 ++++ 4 files changed, 33 insertions(+) diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 81299c791..71c7d4378 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -813,6 +813,13 @@ MLIR_CAPI_EXPORTED void mlirOperationMoveAfter(MlirOperation op, MLIR_CAPI_EXPORTED void mlirOperationMoveBefore(MlirOperation op, MlirOperation other); +/// Given an operation 'other' that is within the same parent block, return +/// whether the current operation is before 'other' in the operation list +/// of the parent block. +/// Note: This function has an average complexity of O(1), but worst case may +/// take O(N) where N is the number of operations within the parent block. +MLIR_CAPI_EXPORTED bool mlirOperationIsBeforeInBlock(MlirOperation op, + MlirOperation other); /// Operation walk result. 
typedef enum MlirWalkResult { MlirWalkResultAdvance, diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 7b790e90e..5feed95f9 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1454,6 +1454,14 @@ void PyOperationBase::moveBefore(PyOperationBase &other) { operation.parentKeepAlive = otherOp.parentKeepAlive; } +bool PyOperationBase::isBeforeInBlock(PyOperationBase &other) { + PyOperation &operation = getOperation(); + PyOperation &otherOp = other.getOperation(); + operation.checkValid(); + otherOp.checkValid(); + return mlirOperationIsBeforeInBlock(operation, otherOp); +} + bool PyOperationBase::verify() { PyOperation &op = getOperation(); PyMlirContext::ErrorCapture errors(op.getContext()); @@ -3409,6 +3417,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { .def("move_before", &PyOperationBase::moveBefore, nb::arg("other"), "Puts self immediately before the other operation in its parent " "block.") + .def("is_before_in_block", &PyOperationBase::isBeforeInBlock, + nb::arg("other"), + "Given an operation 'other' that is within the same parent block, " + "return " + "whether the current operation is before 'other' in the operation " + "list " + "of the parent block.") .def( "clone", [](PyOperationBase &self, nb::object ip) { diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 0fdd2d1a7..9c22dea15 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -624,6 +624,13 @@ class PyOperationBase { void moveAfter(PyOperationBase &other); void moveBefore(PyOperationBase &other); + /// Given an operation 'other' that is within the same parent block, return + /// whether the current operation is before 'other' in the operation list + /// of the parent block. + /// Note: This function has an average complexity of O(1), but worst case may + /// take O(N) where N is the number of operations within the parent block. 
+ bool isBeforeInBlock(PyOperationBase &other); + /// Verify the operation. Throws `MLIRError` if verification fails, and /// returns `true` otherwise. bool verify(); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index fbc66bcf5..8491553da 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -850,6 +850,10 @@ void mlirOperationMoveBefore(MlirOperation op, MlirOperation other) { return unwrap(op)->moveBefore(unwrap(other)); } +bool mlirOperationIsBeforeInBlock(MlirOperation op, MlirOperation other) { + return unwrap(op)->isBeforeInBlock(unwrap(other)); +} + static mlir::WalkResult unwrap(MlirWalkResult result) { switch (result) { case MlirWalkResultAdvance: From 39e7ca59d50f9826438ab16608d332e51b002b6b Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Fri, 25 Jul 2025 07:05:30 -0500 Subject: [PATCH 904/915] [mlir][python] fix PyDenseResourceElementsAttribute finalizer (#150561) This PR melds https://github.com/llvm/llvm-project/pull/150137 and https://github.com/llvm/llvm-project/pull/149414 *and* partially reverts https://github.com/llvm/llvm-project/pull/124832. The summary is the `PyDenseResourceElementsAttribute` finalizer/deleter has/had two problems 1. wasn't threadsafe (can be called from a different thread than that which currently holds the GIL) 2. can be called while the interpreter is "not initialized" https://github.com/llvm/llvm-project/pull/124832 for some reason decides to re-initialize the interpreter to avoid case 2 and runs afoul of the fact that `Py_IsInitialized` can be false during the finalization of the interpreter itself (e.g., at the end of a script). I don't know why this decision was made (I missed the PR) but I believe we should never be calling [Py_Initialize](https://docs.python.org/3/c-api/init.html#c.Py_Initialize): > In an application \*\*\*\***embedding Python**\*\*\*\*, this should be called before using any other Python/C API functions **but we aren't embedding Python**! 
So therefore we will only be in case 2 when the interpreter is being finalized and in that case we should just leak the buffer. Note, [lldb](https://github.com/llvm/llvm-project/blob/548ca9e97673a168023a616d311d901ca04b29a3/lldb/source/Plugins/ScriptInterpreter/Python/PythonDataObjects.cpp#L81-L93) does a similar sort of thing for its finalizers. Co-authored-by: Anton Korobeynikov Co-authored-by: Max Manainen Co-authored-by: Anton Korobeynikov Co-authored-by: Max Manainen --- mlir/lib/Bindings/Python/IRAttributes.cpp | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRAttributes.cpp b/mlir/lib/Bindings/Python/IRAttributes.cpp index 8f79caf08..db84ee1fc 100644 --- a/mlir/lib/Bindings/Python/IRAttributes.cpp +++ b/mlir/lib/Bindings/Python/IRAttributes.cpp @@ -16,8 +16,8 @@ #include "NanobindUtils.h" #include "mlir-c/BuiltinAttributes.h" #include "mlir-c/BuiltinTypes.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/Support/raw_ostream.h" @@ -1428,6 +1428,12 @@ class PyDenseIntElementsAttribute } }; +// Check if the python version is less than 3.13. Py_IsFinalizing is a part +// of stable ABI since 3.13 and before it was available as _Py_IsFinalizing. +#if PY_VERSION_HEX < 0x030d0000 +#define Py_IsFinalizing _Py_IsFinalizing +#endif + class PyDenseResourceElementsAttribute : public PyConcreteAttribute { public: @@ -1474,8 +1480,9 @@ class PyDenseResourceElementsAttribute // The userData is a Py_buffer* that the deleter owns. 
auto deleter = [](void *userData, const void *data, size_t size, size_t align) { - if (!Py_IsInitialized()) - Py_Initialize(); + if (Py_IsFinalizing()) + return; + assert(Py_IsInitialized() && "expected interpreter to be initialized"); Py_buffer *ownedView = static_cast(userData); nb::gil_scoped_acquire gil; PyBuffer_Release(ownedView); From b67670c8450f52e8fd87cac87958dde2bd1993a2 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 29 Jul 2025 12:21:52 +0200 Subject: [PATCH 905/915] [mlir][core] Move `InitAll***` implementation into static library. (#150805) `InitAll***` functions are used by `opt`-style tools to init all MLIR dialects/passes/extensions. Currently they are implemented as inline functions and include essentially the entire MLIR header tree. Each file which includes this header (~10 currently) takes 10+ sec and multiple GB of ram to compile (tested with clang-19), which limits amount of parallel compiler jobs which can be run. Also, flang just includes this file into one of its headers. Move the actual registration code to the static library, so it's compiled only once. Discourse thread https://discourse.llvm.org/t/rfc-moving-initall-implementation-into-static-library/87559 --- mlir/lib/CAPI/RegisterEverything/CMakeLists.txt | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt index 8b9a39558..ccda668ec 100644 --- a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt +++ b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt @@ -1,19 +1,16 @@ # Dialect registration. 
-get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) add_mlir_upstream_c_api_library(MLIRCAPIRegisterEverything RegisterEverything.cpp LINK_LIBS PUBLIC - ${dialect_libs} ${translation_libs} - ${conversion_libs} - ${extension_libs} MLIRBuiltinToLLVMIRTranslation MLIRCAPIIR - MLIRLLVMToLLVMIRTranslation MLIRCAPITransforms + MLIRLLVMToLLVMIRTranslation + MLIRRegisterAllDialects + MLIRRegisterAllExtensions + MLIRRegisterAllPasses ) From 272a21e6302896ac85058fb563e4ee77116cbe11 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Tue, 29 Jul 2025 12:26:47 +0200 Subject: [PATCH 906/915] Revert "[mlir][core] Move `InitAll***` implementation into static library." (#151118) Reverts llvm/llvm-project#150805 Some bots are failing. --- mlir/lib/CAPI/RegisterEverything/CMakeLists.txt | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt index ccda668ec..8b9a39558 100644 --- a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt +++ b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt @@ -1,16 +1,19 @@ # Dialect registration. 
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS) +get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) +get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) add_mlir_upstream_c_api_library(MLIRCAPIRegisterEverything RegisterEverything.cpp LINK_LIBS PUBLIC + ${dialect_libs} ${translation_libs} + ${conversion_libs} + ${extension_libs} MLIRBuiltinToLLVMIRTranslation MLIRCAPIIR - MLIRCAPITransforms MLIRLLVMToLLVMIRTranslation - MLIRRegisterAllDialects - MLIRRegisterAllExtensions - MLIRRegisterAllPasses + MLIRCAPITransforms ) From 04f9312e82e498f1d7b5b3230855e22884569f4d Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Tue, 29 Jul 2025 17:15:33 +0200 Subject: [PATCH 907/915] [mlir] Reland `Move InitAll*** implementation into static library` (#151150) Reland https://github.com/llvm/llvm-project/pull/150805 Shared libs build was broken. Add `${dialect_libs}` and `${conversion_libs}` to `MLIRRegisterAllExtensions` because it depends on `registerConvert***ToLLVMInterface` functions. --- mlir/lib/CAPI/RegisterEverything/CMakeLists.txt | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt index 8b9a39558..ccda668ec 100644 --- a/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt +++ b/mlir/lib/CAPI/RegisterEverything/CMakeLists.txt @@ -1,19 +1,16 @@ # Dialect registration. 
-get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) get_property(translation_libs GLOBAL PROPERTY MLIR_TRANSLATION_LIBS) -get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) -get_property(extension_libs GLOBAL PROPERTY MLIR_EXTENSION_LIBS) add_mlir_upstream_c_api_library(MLIRCAPIRegisterEverything RegisterEverything.cpp LINK_LIBS PUBLIC - ${dialect_libs} ${translation_libs} - ${conversion_libs} - ${extension_libs} MLIRBuiltinToLLVMIRTranslation MLIRCAPIIR - MLIRLLVMToLLVMIRTranslation MLIRCAPITransforms + MLIRLLVMToLLVMIRTranslation + MLIRRegisterAllDialects + MLIRRegisterAllExtensions + MLIRRegisterAllPasses ) From 1c23921a58dae60eda252025cf7a64e47921d5a5 Mon Sep 17 00:00:00 2001 From: Renato Golin Date: Fri, 8 Aug 2025 22:20:27 +0100 Subject: [PATCH 908/915] [MLIR][Linalg] Remove matmul_transpose variants (#147961) Removes the `(batch_)matmul_transpose_{a|b}` variants from OpDSL and replace it with `matmul affine_maps [...]` whenever appropriate. This is in line with the [plan](https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863), and can be done since #104783 merged. See: https://discourse.llvm.org/t/deprecate-batch-matmul-transpose-a-b-linalg-operations/87245 Issues investigated: * pad transform tests that could use `matmul` instead, so change to that. * ArmSME test using transpose actually needed it, so changed to `matmul` + affine maps. Arm tests validated by @banach-space (thanks!!). 
--- .../linalg/opdsl/ops/core_named_ops.py | 93 ------------------- 1 file changed, 93 deletions(-) diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py index 1b359da40..fd4a5a848 100644 --- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py +++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py @@ -373,42 +373,6 @@ def quantized_matmul( ) -@linalg_structured_op -def matmul_transpose_a( - A=TensorDef(T1, S.K, S.N), - B=TensorDef(T2, S.K, S.M), - C=TensorDef(U, S.M, S.N, output=True), - cast=TypeFnAttrDef(default=TypeFn.cast_signed), -): - """Performs a matrix multiplication of two 2D inputs with lhs operand - transposed. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.m, D.n] += cast(U, A[D.k, D.m]) * cast(U, B[D.k, D.n]) - - -@linalg_structured_op -def matmul_transpose_b( - A=TensorDef(T1, S.M, S.K), - B=TensorDef(T2, S.N, S.K), - C=TensorDef(U, S.M, S.N, output=True), - cast=TypeFnAttrDef(default=TypeFn.cast_signed), -): - """Performs a matrix multiplication of two 2D inputs with rhs operand - transposed. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. 
- """ - domain(D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.n, D.k]) - - @linalg_structured_op def mmt4d( lhs=TensorDef(TV.LhsType, S.M, S.K, S.M0, S.K0), @@ -453,44 +417,6 @@ def batch_mmt4d( ) * TypeFn.cast_signed(TV.AccumType, rhs[D.b, D.n, D.k, D.n0, D.k0]) -@linalg_structured_op -def batch_matmul_transpose_a( - A=TensorDef(T1, Batch, S.K, S.M), - B=TensorDef(T2, Batch, S.K, S.N), - C=TensorDef(U, Batch, S.M, S.N, output=True), -): - """Performs a batched matrix multiplication of two 3D inputs where lhs operand - has its non-batch dimensions transposed. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.b, D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.k, D.m]) * TypeFn.cast_signed( - U, B[D.b, D.k, D.n] - ) - - -@linalg_structured_op -def batch_matmul_transpose_b( - A=TensorDef(T1, Batch, S.M, S.K), - B=TensorDef(T2, Batch, S.N, S.K), - C=TensorDef(U, Batch, S.M, S.N, output=True), -): - """Performs a batched matrix multiplication of two 3D inputs where rhs operand - has its non-batch dimensions transposed. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. 
- """ - domain(D.b, D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.b, D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( - U, B[D.b, D.n, D.k] - ) - - @linalg_structured_op def quantized_batch_matmul( A=TensorDef(T1, Batch, S.M, S.K), @@ -512,25 +438,6 @@ def quantized_batch_matmul( ) * (TypeFn.cast_signed(U, B[D.b, D.k, D.n]) - TypeFn.cast_signed(U, BZp)) -@linalg_structured_op -def batch_reduce_matmul( - A=TensorDef(T1, Batch, S.M, S.K), - B=TensorDef(T2, Batch, S.K, S.N), - C=TensorDef(U, S.M, S.N, output=True), -): - """Performs a batch-reduce matrix multiplication of two 3D inputs. - The partial multiplication results are reduced into a 2D output. - - Numeric casting is performed on the operands to the inner multiply, promoting - them to the same data type as the accumulator/output. - """ - domain(D.b, D.m, D.n, D.k) - implements(ContractionOpInterface) - C[D.m, D.n] += TypeFn.cast_signed(U, A[D.b, D.m, D.k]) * TypeFn.cast_signed( - U, B[D.b, D.k, D.n] - ) - - @linalg_structured_op def matvec( A=TensorDef(T1, S.M, S.N), y=TensorDef(T2, S.N), x=TensorDef(U, S.M, output=True) From d1f22c8be41a07c72fcc306cc8b7eeb85620e1c0 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Mon, 11 Aug 2025 12:21:59 -0500 Subject: [PATCH 909/915] [mlir][python] expose isAttached (#153045) --- mlir/lib/Bindings/Python/IRCore.cpp | 8 ++++++++ mlir/python/mlir/_mlir_libs/_mlir/ir.pyi | 7 +++++++ 2 files changed, 15 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 5feed95f9..ee88aa475 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -3442,6 +3442,14 @@ void mlir::python::populateIRCore(nb::module_ &m) { return operation.createOpView(); }, "Detaches the operation from its parent block.") + .def_prop_ro( + "attached", + [](PyOperationBase &self) { + PyOperation &operation = self.getOperation(); + operation.checkValid(); + return 
operation.isAttached(); + }, + "Reports if the operation is attached to its parent block.") .def("erase", [](PyOperationBase &self) { self.getOperation().erase(); }) .def("walk", &PyOperationBase::walk, nb::arg("callback"), nb::arg("walk_order") = MlirWalkPostOrder); diff --git a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi index be71737e4..dcae3dd74 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/ir.pyi @@ -194,6 +194,13 @@ class _OperationBase: """ Detaches the operation from its parent block. """ + + @property + def attached(self) -> bool: + """ + Reports if the operation is attached to its parent block. + """ + def erase(self) -> None: ... @overload From 5e8bc18f855cb08f80f5c1ee180070bea0552bf7 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Tue, 12 Aug 2025 17:59:59 -0400 Subject: [PATCH 910/915] [mlir][python] automatic location inference (#151246) This PR implements "automatic" location inference in the bindings. The way it works is it walks the frame stack collecting source locations (Python captures these in the frame itself). It is inspired by JAX's [implementation](https://github.com/jax-ml/jax/blob/523ddcfbcad005deab5a7d542df4c706f5ee5e9c/jax/_src/interpreters/mlir.py#L462) but moves the frame stack traversal into the bindings for better performance. The system supports registering "included" and "excluded" filenames; frames originating from functions in included filenames **will not** be filtered and frames originating from functions in excluded filenames **will** be filtered (in that order). This allows excluding all the generated `*_ops_gen.py` files. The system is also "toggleable" and off by default to save people who have their own systems (such as JAX) from the added cost. Note, the system stores the entire stacktrace (subject to `locTracebackFramesLimit`) in the `Location` using specifically a `CallSiteLoc`. 
This can be useful for profiling tools (flamegraphs etc.). Shoutout to the folks at JAX for coming up with a good system. --------- Co-authored-by: Jacques Pienaar --- mlir/lib/Bindings/Python/Globals.h | 39 ++++++++ mlir/lib/Bindings/Python/IRCore.cpp | 122 +++++++++++++++++++---- mlir/lib/Bindings/Python/IRModule.cpp | 70 ++++++++++++- mlir/lib/Bindings/Python/IRModule.h | 15 +-- mlir/lib/Bindings/Python/MainModule.cpp | 23 ++++- mlir/python/mlir/dialects/_ods_common.py | 8 +- 6 files changed, 239 insertions(+), 38 deletions(-) diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 826a34a53..71a051cb3 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -10,15 +10,19 @@ #define MLIR_BINDINGS_PYTHON_GLOBALS_H #include +#include #include +#include #include #include "NanobindUtils.h" #include "mlir-c/IR.h" #include "mlir/CAPI/Support.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSet.h" +#include "llvm/Support/Regex.h" namespace mlir { namespace python { @@ -114,6 +118,39 @@ class PyGlobals { std::optional lookupOperationClass(llvm::StringRef operationName); + class TracebackLoc { + public: + bool locTracebacksEnabled(); + + void setLocTracebacksEnabled(bool value); + + size_t locTracebackFramesLimit(); + + void setLocTracebackFramesLimit(size_t value); + + void registerTracebackFileInclusion(const std::string &file); + + void registerTracebackFileExclusion(const std::string &file); + + bool isUserTracebackFilename(llvm::StringRef file); + + static constexpr size_t kMaxFrames = 512; + + private: + nanobind::ft_mutex mutex; + bool locTracebackEnabled_ = false; + size_t locTracebackFramesLimit_ = 10; + std::unordered_set userTracebackIncludeFiles; + std::unordered_set userTracebackExcludeFiles; + std::regex userTracebackIncludeRegex; + bool rebuildUserTracebackIncludeRegex = false; + std::regex 
userTracebackExcludeRegex; + bool rebuildUserTracebackExcludeRegex = false; + llvm::StringMap isUserTracebackFilenameCache; + }; + + TracebackLoc &getTracebackLoc() { return tracebackLoc; } + private: static PyGlobals *instance; @@ -134,6 +171,8 @@ class PyGlobals { /// Set of dialect namespaces that we have attempted to import implementation /// modules for. llvm::StringSet<> loadedDialectModules; + + TracebackLoc tracebackLoc; }; } // namespace python diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index ee88aa475..1aec7ffe8 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -20,11 +20,8 @@ #include "nanobind/nanobind.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/raw_ostream.h" #include -#include -#include namespace nb = nanobind; using namespace nb::literals; @@ -1523,7 +1520,7 @@ nb::object PyOperation::create(std::string_view name, llvm::ArrayRef operands, std::optional attributes, std::optional> successors, - int regions, DefaultingPyLocation location, + int regions, PyLocation &location, const nb::object &maybeIp, bool inferType) { llvm::SmallVector mlirResults; llvm::SmallVector mlirSuccessors; @@ -1627,7 +1624,7 @@ nb::object PyOperation::create(std::string_view name, if (!operation.ptr) throw nb::value_error("Operation creation failed"); PyOperationRef created = - PyOperation::createDetached(location->getContext(), operation); + PyOperation::createDetached(location.getContext(), operation); maybeInsertOperation(created, maybeIp); return created.getObject(); @@ -1937,9 +1934,9 @@ nb::object PyOpView::buildGeneric( std::optional resultTypeList, nb::list operandList, std::optional attributes, std::optional> successors, - std::optional regions, DefaultingPyLocation location, + std::optional regions, PyLocation &location, const nb::object &maybeIp) { - PyMlirContextRef context = location->getContext(); + PyMlirContextRef context = 
location.getContext(); // Class level operation construction metadata. // Operand and result segment specs are either none, which does no @@ -2789,6 +2786,90 @@ class PyOpAttributeMap { PyOperationRef operation; }; +MlirLocation tracebackToLocation(MlirContext ctx) { + size_t framesLimit = + PyGlobals::get().getTracebackLoc().locTracebackFramesLimit(); + // Use a thread_local here to avoid requiring a large amount of space. + thread_local std::array + frames; + size_t count = 0; + + nb::gil_scoped_acquire acquire; + PyThreadState *tstate = PyThreadState_GET(); + PyFrameObject *next; + PyFrameObject *pyFrame = PyThreadState_GetFrame(tstate); + // In the increment expression: + // 1. get the next prev frame; + // 2. decrement the ref count on the current frame (in order that it can get + // gc'd, along with any objects in its closure and etc); + // 3. set current = next. + for (; pyFrame != nullptr && count < framesLimit; + next = PyFrame_GetBack(pyFrame), Py_XDECREF(pyFrame), pyFrame = next) { + PyCodeObject *code = PyFrame_GetCode(pyFrame); + auto fileNameStr = + nb::cast(nb::borrow(code->co_filename)); + llvm::StringRef fileName(fileNameStr); + if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName)) + continue; + +#if PY_VERSION_HEX < 0x030b00f0 + std::string name = + nb::cast(nb::borrow(code->co_name)); + llvm::StringRef funcName(name); + int startLine = PyFrame_GetLineNumber(pyFrame); + MlirLocation loc = + mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0); +#else + // co_qualname and PyCode_Addr2Location added in py3.11 + std::string name = + nb::cast(nb::borrow(code->co_qualname)); + llvm::StringRef funcName(name); + int startLine, startCol, endLine, endCol; + int lasti = PyFrame_GetLasti(pyFrame); + if (!PyCode_Addr2Location(code, lasti, &startLine, &startCol, &endLine, + &endCol)) { + throw nb::python_error(); + } + MlirLocation loc = mlirLocationFileLineColRangeGet( + ctx, wrap(fileName), startLine, startCol, endLine, endCol); 
+#endif + + frames[count] = mlirLocationNameGet(ctx, wrap(funcName), loc); + ++count; + } + // When the loop breaks (after the last iter), current frame (if non-null) + // is leaked without this. + Py_XDECREF(pyFrame); + + if (count == 0) + return mlirLocationUnknownGet(ctx); + + MlirLocation callee = frames[0]; + assert(!mlirLocationIsNull(callee) && "expected non-null callee location"); + if (count == 1) + return callee; + + MlirLocation caller = frames[count - 1]; + assert(!mlirLocationIsNull(caller) && "expected non-null caller location"); + for (int i = count - 2; i >= 1; i--) + caller = mlirLocationCallSiteGet(frames[i], caller); + + return mlirLocationCallSiteGet(callee, caller); +} + +PyLocation +maybeGetTracebackLocation(const std::optional &location) { + if (location.has_value()) + return location.value(); + if (!PyGlobals::get().getTracebackLoc().locTracebacksEnabled()) + return DefaultingPyLocation::resolve(); + + PyMlirContext &ctx = DefaultingPyMlirContext::resolve(); + MlirLocation mlirLoc = tracebackToLocation(ctx.get()); + PyMlirContextRef ref = PyMlirContext::forContext(ctx.get()); + return {ref, mlirLoc}; +} + } // namespace //------------------------------------------------------------------------------ @@ -3052,10 +3133,10 @@ void mlir::python::populateIRCore(nb::module_ &m) { .def("__eq__", [](PyLocation &self, nb::object other) { return false; }) .def_prop_ro_static( "current", - [](nb::object & /*class*/) { + [](nb::object & /*class*/) -> std::optional { auto *loc = PyThreadContextEntry::getDefaultLocation(); if (!loc) - throw nb::value_error("No current Location"); + return std::nullopt; return loc; }, "Gets the Location bound to the current thread or raises ValueError") @@ -3240,8 +3321,9 @@ void mlir::python::populateIRCore(nb::module_ &m) { kModuleParseDocstring) .def_static( "create", - [](DefaultingPyLocation loc) { - MlirModule module = mlirModuleCreateEmpty(loc); + [](const std::optional &loc) { + PyLocation pyLoc = 
maybeGetTracebackLocation(loc); + MlirModule module = mlirModuleCreateEmpty(pyLoc.get()); return PyModule::forModule(module).releaseObject(); }, nb::arg("loc").none() = nb::none(), "Creates an empty module") @@ -3462,8 +3544,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { std::optional> operands, std::optional attributes, std::optional> successors, int regions, - DefaultingPyLocation location, const nb::object &maybeIp, - bool inferType) { + const std::optional &location, + const nb::object &maybeIp, bool inferType) { // Unpack/validate operands. llvm::SmallVector mlirOperands; if (operands) { @@ -3475,8 +3557,9 @@ void mlir::python::populateIRCore(nb::module_ &m) { } } + PyLocation pyLoc = maybeGetTracebackLocation(location); return PyOperation::create(name, results, mlirOperands, attributes, - successors, regions, location, maybeIp, + successors, regions, pyLoc, maybeIp, inferType); }, nb::arg("name"), nb::arg("results").none() = nb::none(), @@ -3520,12 +3603,14 @@ void mlir::python::populateIRCore(nb::module_ &m) { std::optional resultTypeList, nb::list operandList, std::optional attributes, std::optional> successors, - std::optional regions, DefaultingPyLocation location, + std::optional regions, + const std::optional &location, const nb::object &maybeIp) { + PyLocation pyLoc = maybeGetTracebackLocation(location); new (self) PyOpView(PyOpView::buildGeneric( name, opRegionSpec, operandSegmentSpecObj, resultSegmentSpecObj, resultTypeList, operandList, - attributes, successors, regions, location, maybeIp)); + attributes, successors, regions, pyLoc, maybeIp)); }, nb::arg("name"), nb::arg("opRegionSpec"), nb::arg("operandSegmentSpecObj").none() = nb::none(), @@ -3559,17 +3644,18 @@ void mlir::python::populateIRCore(nb::module_ &m) { [](nb::handle cls, std::optional resultTypeList, nb::list operandList, std::optional attributes, std::optional> successors, - std::optional regions, DefaultingPyLocation location, + std::optional regions, std::optional location, 
const nb::object &maybeIp) { std::string name = nb::cast(cls.attr("OPERATION_NAME")); std::tuple opRegionSpec = nb::cast>(cls.attr("_ODS_REGIONS")); nb::object operandSegmentSpec = cls.attr("_ODS_OPERAND_SEGMENTS"); nb::object resultSegmentSpec = cls.attr("_ODS_RESULT_SEGMENTS"); + PyLocation pyLoc = maybeGetTracebackLocation(location); return PyOpView::buildGeneric(name, opRegionSpec, operandSegmentSpec, resultSegmentSpec, resultTypeList, operandList, attributes, successors, - regions, location, maybeIp); + regions, pyLoc, maybeIp); }, nb::arg("cls"), nb::arg("results").none() = nb::none(), nb::arg("operands").none() = nb::none(), diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index e600f1bbd..0de2f1711 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -13,9 +13,9 @@ #include "Globals.h" #include "NanobindUtils.h" +#include "mlir-c/Bindings/Python/Interop.h" #include "mlir-c/Support.h" #include "mlir/Bindings/Python/Nanobind.h" -#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind. namespace nb = nanobind; using namespace mlir; @@ -197,3 +197,71 @@ PyGlobals::lookupOperationClass(llvm::StringRef operationName) { // Not found and loading did not yield a registration. 
return std::nullopt; } + +bool PyGlobals::TracebackLoc::locTracebacksEnabled() { + nanobind::ft_lock_guard lock(mutex); + return locTracebackEnabled_; +} + +void PyGlobals::TracebackLoc::setLocTracebacksEnabled(bool value) { + nanobind::ft_lock_guard lock(mutex); + locTracebackEnabled_ = value; +} + +size_t PyGlobals::TracebackLoc::locTracebackFramesLimit() { + nanobind::ft_lock_guard lock(mutex); + return locTracebackFramesLimit_; +} + +void PyGlobals::TracebackLoc::setLocTracebackFramesLimit(size_t value) { + nanobind::ft_lock_guard lock(mutex); + locTracebackFramesLimit_ = std::min(value, kMaxFrames); +} + +void PyGlobals::TracebackLoc::registerTracebackFileInclusion( + const std::string &file) { + nanobind::ft_lock_guard lock(mutex); + auto reg = "^" + llvm::Regex::escape(file); + if (userTracebackIncludeFiles.insert(reg).second) + rebuildUserTracebackIncludeRegex = true; + if (userTracebackExcludeFiles.count(reg)) { + if (userTracebackExcludeFiles.erase(reg)) + rebuildUserTracebackExcludeRegex = true; + } +} + +void PyGlobals::TracebackLoc::registerTracebackFileExclusion( + const std::string &file) { + nanobind::ft_lock_guard lock(mutex); + auto reg = "^" + llvm::Regex::escape(file); + if (userTracebackExcludeFiles.insert(reg).second) + rebuildUserTracebackExcludeRegex = true; + if (userTracebackIncludeFiles.count(reg)) { + if (userTracebackIncludeFiles.erase(reg)) + rebuildUserTracebackIncludeRegex = true; + } +} + +bool PyGlobals::TracebackLoc::isUserTracebackFilename( + const llvm::StringRef file) { + nanobind::ft_lock_guard lock(mutex); + if (rebuildUserTracebackIncludeRegex) { + userTracebackIncludeRegex.assign( + llvm::join(userTracebackIncludeFiles, "|")); + rebuildUserTracebackIncludeRegex = false; + isUserTracebackFilenameCache.clear(); + } + if (rebuildUserTracebackExcludeRegex) { + userTracebackExcludeRegex.assign( + llvm::join(userTracebackExcludeFiles, "|")); + rebuildUserTracebackExcludeRegex = false; + isUserTracebackFilenameCache.clear(); + } + 
if (!isUserTracebackFilenameCache.contains(file)) { + std::string fileStr = file.str(); + bool include = std::regex_search(fileStr, userTracebackIncludeRegex); + bool exclude = std::regex_search(fileStr, userTracebackExcludeRegex); + isUserTracebackFilenameCache[file] = include || !exclude; + } + return isUserTracebackFilenameCache[file]; +} diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 9c22dea15..fa16ae3ce 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -192,16 +192,6 @@ class PyMlirContext { PyMlirContext(const PyMlirContext &) = delete; PyMlirContext(PyMlirContext &&) = delete; - /// For the case of a python __init__ (nanobind::init) method, pybind11 is - /// quite strict about needing to return a pointer that is not yet associated - /// to an nanobind::object. Since the forContext() method acts like a pool, - /// possibly returning a recycled context, it does not satisfy this need. The - /// usual way in python to accomplish such a thing is to override __new__, but - /// that is also not supported by pybind11. Instead, we use this entry - /// point which always constructs a fresh context (which cannot alias an - /// existing one because it is fresh). - static PyMlirContext *createNewContextForInit(); - /// Returns a context reference for the singleton PyMlirContext wrapper for /// the given context. static PyMlirContextRef forContext(MlirContext context); @@ -722,8 +712,7 @@ class PyOperation : public PyOperationBase, public BaseContextObject { llvm::ArrayRef operands, std::optional attributes, std::optional> successors, int regions, - DefaultingPyLocation location, const nanobind::object &ip, - bool inferType); + PyLocation &location, const nanobind::object &ip, bool inferType); /// Creates an OpView suitable for this operation. 
nanobind::object createOpView(); @@ -781,7 +770,7 @@ class PyOpView : public PyOperationBase { nanobind::list operandList, std::optional attributes, std::optional> successors, - std::optional regions, DefaultingPyLocation location, + std::optional regions, PyLocation &location, const nanobind::object &maybeIp); /// Construct an instance of a class deriving from OpView, bypassing its diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 6f4943100..278847e7a 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -6,7 +6,6 @@ // //===----------------------------------------------------------------------===// - #include "Globals.h" #include "IRModule.h" #include "NanobindUtils.h" @@ -44,7 +43,27 @@ NB_MODULE(_mlir, m) { .def("_register_operation_impl", &PyGlobals::registerOperationImpl, "operation_name"_a, "operation_class"_a, nb::kw_only(), "replace"_a = false, - "Testing hook for directly registering an operation"); + "Testing hook for directly registering an operation") + .def("loc_tracebacks_enabled", + [](PyGlobals &self) { + return self.getTracebackLoc().locTracebacksEnabled(); + }) + .def("set_loc_tracebacks_enabled", + [](PyGlobals &self, bool enabled) { + self.getTracebackLoc().setLocTracebacksEnabled(enabled); + }) + .def("set_loc_tracebacks_frame_limit", + [](PyGlobals &self, int n) { + self.getTracebackLoc().setLocTracebackFramesLimit(n); + }) + .def("register_traceback_file_inclusion", + [](PyGlobals &self, const std::string &filename) { + self.getTracebackLoc().registerTracebackFileInclusion(filename); + }) + .def("register_traceback_file_exclusion", + [](PyGlobals &self, const std::string &filename) { + self.getTracebackLoc().registerTracebackFileExclusion(filename); + }); // Aside from making the globals accessible to python, having python manage // it is necessary to make sure it is destroyed (and releases its python diff --git 
a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index a5efa057c..10abd06ff 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -78,12 +78,12 @@ def equally_sized_accessor( def get_default_loc_context(location=None): """ Returns a context in which the defaulted location is created. If the location - is None, takes the current location from the stack, raises ValueError if there - is no location on the stack. + is None, takes the current location from the stack. """ if location is None: - # Location.current raises ValueError if there is no current location. - return _cext.ir.Location.current.context + if _cext.ir.Location.current: + return _cext.ir.Location.current.context + return None return location.context From 088b3db346f6a365bbdb997d9d9854064ddeb318 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Tue, 12 Aug 2025 21:16:04 -0400 Subject: [PATCH 911/915] [mlir][python] fix PyThreadState_GetFrame (#153325) `PyThreadState_GetFrame` wasn't added until 3.9 (fixes currently failing rocm builder) --- mlir/lib/Bindings/Python/IRCore.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 1aec7ffe8..03b04ffbe 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2786,6 +2786,14 @@ class PyOpAttributeMap { PyOperationRef operation; }; +// bpo-40429 added PyThreadState_GetFrame() to Python 3.9.0b1 +#if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION) +static inline PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) { + assert(tstate != _Py_NULL); + return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame)); +} +#endif + MlirLocation tracebackToLocation(MlirContext ctx) { size_t framesLimit = PyGlobals::get().getTracebackLoc().locTracebackFramesLimit(); From 4ab3e121a331773c447f7535877f475a08daaeab Mon Sep 17 00:00:00 2001 From: Maksim 
Levental Date: Tue, 12 Aug 2025 22:29:23 -0400 Subject: [PATCH 912/915] [mlir][python] fix PyThreadState_GetFrame again (#153333) add more APIs missing from 3.8 (fix rocm builder) --- mlir/lib/Bindings/Python/IRCore.cpp | 70 ++++++++++++++++++++++++++--- 1 file changed, 64 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 03b04ffbe..390cdc542 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2786,13 +2786,71 @@ class PyOpAttributeMap { PyOperationRef operation; }; -// bpo-40429 added PyThreadState_GetFrame() to Python 3.9.0b1 +// see +// https://raw.githubusercontent.com/python/pythoncapi_compat/master/pythoncapi_compat.h + +#ifndef _Py_CAST +#define _Py_CAST(type, expr) ((type)(expr)) +#endif + +// Static inline functions should use _Py_NULL rather than using directly NULL +// to prevent C++ compiler warnings. On C23 and newer and on C++11 and newer, +// _Py_NULL is defined as nullptr. 
+#ifndef _Py_NULL +#if (defined(__STDC_VERSION__) && __STDC_VERSION__ > 201710L) || \ + (defined(__cplusplus) && __cplusplus >= 201103) +#define _Py_NULL nullptr +#else +#define _Py_NULL NULL +#endif +#endif + +// Python 3.10.0a3 +#if PY_VERSION_HEX < 0x030A00A3 + +// bpo-42262 added Py_XNewRef() +#if !defined(Py_XNewRef) +PyObject *_Py_XNewRef(PyObject *obj) { + Py_XINCREF(obj); + return obj; +} +#define Py_XNewRef(obj) _Py_XNewRef(_PyObject_CAST(obj)) +#endif + +// bpo-42262 added Py_NewRef() +#if !defined(Py_NewRef) +PyObject *_Py_NewRef(PyObject *obj) { + Py_INCREF(obj); + return obj; +} +#define Py_NewRef(obj) _Py_NewRef(_PyObject_CAST(obj)) +#endif + +#endif // Python 3.10.0a3 + +// Python 3.9.0b1 #if PY_VERSION_HEX < 0x030900B1 && !defined(PYPY_VERSION) -static inline PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) { - assert(tstate != _Py_NULL); + +// bpo-40429 added PyThreadState_GetFrame() +PyFrameObject *PyThreadState_GetFrame(PyThreadState *tstate) { + assert(tstate != _Py_NULL && "expected tstate != _Py_NULL"); return _Py_CAST(PyFrameObject *, Py_XNewRef(tstate->frame)); } -#endif + +// bpo-40421 added PyFrame_GetBack() +PyFrameObject *PyFrame_GetBack(PyFrameObject *frame) { + assert(frame != _Py_NULL && "expected frame != _Py_NULL"); + return _Py_CAST(PyFrameObject *, Py_XNewRef(frame->f_back)); +} + +// bpo-40421 added PyFrame_GetCode() +PyCodeObject *PyFrame_GetCode(PyFrameObject *frame) { + assert(frame != _Py_NULL && "expected frame != _Py_NULL"); + assert(frame->f_code != _Py_NULL && "expected frame->f_code != _Py_NULL"); + return _Py_CAST(PyCodeObject *, Py_NewRef(frame->f_code)); +} + +#endif // Python 3.9.0b1 MlirLocation tracebackToLocation(MlirContext ctx) { size_t framesLimit = @@ -2820,7 +2878,8 @@ MlirLocation tracebackToLocation(MlirContext ctx) { if (!PyGlobals::get().getTracebackLoc().isUserTracebackFilename(fileName)) continue; -#if PY_VERSION_HEX < 0x030b00f0 + // co_qualname and PyCode_Addr2Location added in py3.11 +#if 
PY_VERSION_HEX < 0x030B00F0 std::string name = nb::cast(nb::borrow(code->co_name)); llvm::StringRef funcName(name); @@ -2828,7 +2887,6 @@ MlirLocation tracebackToLocation(MlirContext ctx) { MlirLocation loc = mlirLocationFileLineColGet(ctx, wrap(fileName), startLine, 0); #else - // co_qualname and PyCode_Addr2Location added in py3.11 std::string name = nb::cast(nb::borrow(code->co_qualname)); llvm::StringRef funcName(name); From cd36d1e7677515bb45f7bcd2c59fb6c6360721e3 Mon Sep 17 00:00:00 2001 From: Shenghang Tsai Date: Wed, 13 Aug 2025 21:22:01 +0800 Subject: [PATCH 913/915] [MLIR] Split ExecutionEngine Initialization out of ctor into an explicit method call (#153373) This PR introduces a mechanism to defer JIT engine initialization, enabling registration of required symbols before global constructor execution. ## Problem Modules containing `gpu.module` generate global constructors (e.g., kernel load/unload) that execute *during* engine creation. This can force premature symbol resolution, causing failures when: - Symbols are registered via `mlirExecutionEngineRegisterSymbol` *after* creation - Global constructors exist (even if not directly using unresolved symbols, e.g., an external function declaration) - GPU modules introduce mandatory binary loading logic ## Usage ```c // Create engine without initialization MlirExecutionEngine jit = mlirExecutionEngineCreate(...); // Register required symbols mlirExecutionEngineRegisterSymbol(jit, ...); // Explicitly initialize (runs global constructors) mlirExecutionEngineInitialize(jit); ``` --------- Co-authored-by: Mehdi Amini --- mlir/include/mlir-c/ExecutionEngine.h | 7 +++++++ mlir/lib/Bindings/Python/ExecutionEngineModule.cpp | 13 ++++++++++++- mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp | 9 ++++++--- 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h index 99cddc5c2..1a58d6853 100644 --- 
a/mlir/include/mlir-c/ExecutionEngine.h +++ b/mlir/include/mlir-c/ExecutionEngine.h @@ -46,6 +46,13 @@ MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreate( MlirModule op, int optLevel, int numPaths, const MlirStringRef *sharedLibPaths, bool enableObjectDump); +/// Initialize the ExecutionEngine. Global constructors specified by +/// `llvm.mlir.global_ctors` will be run. One common scenario is that kernel +/// binary compiled from `gpu.module` gets loaded during initialization. Make +/// sure all symbols are resolvable before initialization by calling +/// `mlirExecutionEngineRegisterSymbol` or including shared libraries. +MLIR_CAPI_EXPORTED void mlirExecutionEngineInitialize(MlirExecutionEngine jit); + /// Destroy an ExecutionEngine instance. MLIR_CAPI_EXPORTED void mlirExecutionEngineDestroy(MlirExecutionEngine jit); diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp index 81dada355..4f7a4a628 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -7,8 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir-c/ExecutionEngine.h" -#include "mlir/Bindings/Python/NanobindAdaptors.h" #include "mlir/Bindings/Python/Nanobind.h" +#include "mlir/Bindings/Python/NanobindAdaptors.h" namespace nb = nanobind; using namespace mlir; @@ -124,6 +124,17 @@ NB_MODULE(_mlirExecutionEngine, m) { }, nb::arg("name"), nb::arg("callback"), "Register `callback` as the runtime symbol `name`.") + .def( + "initialize", + [](PyExecutionEngine &executionEngine) { + mlirExecutionEngineInitialize(executionEngine.get()); + }, + "Initialize the ExecutionEngine. Global constructors specified by " + "`llvm.mlir.global_ctors` will be run. One common scenario is that " + "kernel binary compiled from `gpu.module` gets loaded during " + "initialization. 
Make sure all symbols are resolvable before " + "initialization by calling `raw_register_runtime` or including " + "shared libraries.") .def( "dump_to_object_file", [](PyExecutionEngine &executionEngine, const std::string &fileName) { diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index 306cebd23..2dbb993b1 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -68,6 +68,10 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, return wrap(jitOrError->release()); } +extern "C" void mlirExecutionEngineInitialize(MlirExecutionEngine jit) { + unwrap(jit)->initialize(); +} + extern "C" void mlirExecutionEngineDestroy(MlirExecutionEngine jit) { delete (unwrap(jit)); } @@ -106,9 +110,8 @@ extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit, void *sym) { unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) { llvm::orc::SymbolMap symbolMap; - symbolMap[interner(unwrap(name))] = - { llvm::orc::ExecutorAddr::fromPtr(sym), - llvm::JITSymbolFlags::Exported }; + symbolMap[interner(unwrap(name))] = {llvm::orc::ExecutorAddr::fromPtr(sym), + llvm::JITSymbolFlags::Exported}; return symbolMap; }); } From 7b0f329a472892a71ff97bed2fba4e26839eac27 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Wed, 13 Aug 2025 21:43:04 +0200 Subject: [PATCH 914/915] Revert "[MLIR] Split ExecutionEngine Initialization out of ctor into an explicit method call" (#153477) Reverts llvm/llvm-project#153373 Sanitizer bot is broken --- mlir/include/mlir-c/ExecutionEngine.h | 7 ------- mlir/lib/Bindings/Python/ExecutionEngineModule.cpp | 13 +------------ mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp | 9 +++------ 3 files changed, 4 insertions(+), 25 deletions(-) diff --git a/mlir/include/mlir-c/ExecutionEngine.h b/mlir/include/mlir-c/ExecutionEngine.h index 1a58d6853..99cddc5c2 100644 --- 
a/mlir/include/mlir-c/ExecutionEngine.h +++ b/mlir/include/mlir-c/ExecutionEngine.h @@ -46,13 +46,6 @@ MLIR_CAPI_EXPORTED MlirExecutionEngine mlirExecutionEngineCreate( MlirModule op, int optLevel, int numPaths, const MlirStringRef *sharedLibPaths, bool enableObjectDump); -/// Initialize the ExecutionEngine. Global constructors specified by -/// `llvm.mlir.global_ctors` will be run. One common scenario is that kernel -/// binary compiled from `gpu.module` gets loaded during initialization. Make -/// sure all symbols are resolvable before initialization by calling -/// `mlirExecutionEngineRegisterSymbol` or including shared libraries. -MLIR_CAPI_EXPORTED void mlirExecutionEngineInitialize(MlirExecutionEngine jit); - /// Destroy an ExecutionEngine instance. MLIR_CAPI_EXPORTED void mlirExecutionEngineDestroy(MlirExecutionEngine jit); diff --git a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp index 4f7a4a628..81dada355 100644 --- a/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp +++ b/mlir/lib/Bindings/Python/ExecutionEngineModule.cpp @@ -7,8 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir-c/ExecutionEngine.h" -#include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" +#include "mlir/Bindings/Python/Nanobind.h" namespace nb = nanobind; using namespace mlir; @@ -124,17 +124,6 @@ NB_MODULE(_mlirExecutionEngine, m) { }, nb::arg("name"), nb::arg("callback"), "Register `callback` as the runtime symbol `name`.") - .def( - "initialize", - [](PyExecutionEngine &executionEngine) { - mlirExecutionEngineInitialize(executionEngine.get()); - }, - "Initialize the ExecutionEngine. Global constructors specified by " - "`llvm.mlir.global_ctors` will be run. One common scenario is that " - "kernel binary compiled from `gpu.module` gets loaded during " - "initialization. 
Make sure all symbols are resolvable before " - "initialization by calling `raw_register_runtime` or including " - "shared libraries.") .def( "dump_to_object_file", [](PyExecutionEngine &executionEngine, const std::string &fileName) { diff --git a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp index 2dbb993b1..306cebd23 100644 --- a/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp @@ -68,10 +68,6 @@ mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths, return wrap(jitOrError->release()); } -extern "C" void mlirExecutionEngineInitialize(MlirExecutionEngine jit) { - unwrap(jit)->initialize(); -} - extern "C" void mlirExecutionEngineDestroy(MlirExecutionEngine jit) { delete (unwrap(jit)); } @@ -110,8 +106,9 @@ extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit, void *sym) { unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) { llvm::orc::SymbolMap symbolMap; - symbolMap[interner(unwrap(name))] = {llvm::orc::ExecutorAddr::fromPtr(sym), - llvm::JITSymbolFlags::Exported}; + symbolMap[interner(unwrap(name))] = + { llvm::orc::ExecutorAddr::fromPtr(sym), + llvm::JITSymbolFlags::Exported }; return symbolMap; }); } From 22c366893034dee083da200ce940bc6e7b9b9d56 Mon Sep 17 00:00:00 2001 From: makslevental Date: Thu, 14 Aug 2025 15:01:59 -0400 Subject: [PATCH 915/915] create fork --- .github/workflows/integrate_llvm.yml | 112 +++++++++++++++ .gitignore | 2 + LICENSE | 201 +++++++++++++++++++++++++++ README.md | 44 ++++++ emreg.py | 11 ++ filter-llvm.sh | 23 +++ 6 files changed, 393 insertions(+) create mode 100644 .github/workflows/integrate_llvm.yml create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 README.md create mode 100644 emreg.py create mode 100755 filter-llvm.sh diff --git a/.github/workflows/integrate_llvm.yml b/.github/workflows/integrate_llvm.yml new file mode 100644 index 
000000000..85ece9095 --- /dev/null +++ b/.github/workflows/integrate_llvm.yml @@ -0,0 +1,112 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Copyright (c) 2024. + +name: Auto Integrate LLVM +on: + workflow_dispatch: + pull_request: + push: + branches: + - main + schedule: + # At minute 0 past hour 1. (see https://crontab.guru) + - cron: '00 01 * * *' + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). The workflow name is prepended to avoid conflicts between + # different workflows. + group: ${{ github.workflow }}-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + update-dep: + + name: "Integrate LLVM and send PR" + + runs-on: ubuntu-latest + + permissions: + contents: write + pull-requests: write + + steps: + - name: "Check out repository" + uses: actions/checkout@v4.2.2 + + - name: Cache LLVM clone + id: cache-llvm + uses: actions/cache@v4 + with: + path: /tmp/llvm-project + key: cache-llvm-project + + - name: "Get filtered llvm-project" + shell: bash + id: get-llvm-project + run: | + + sudo apt install git-filter-repo + + HERE=$(pwd) + + pushd /tmp + echo "cache-hit ${{ steps.cache-llvm.outputs.cache-hit }}" + # https://github.com/actions/cache/issues/1566 + if [ "${{ steps.cache-llvm.outputs.cache-hit }}" == "" ]; then + git clone https://github.com/llvm/llvm-project.git -v + fi + + pushd llvm-project + + git pull origin main + echo "LLVM_SHA_SHORT=$(git rev-parse --short HEAD)" >> $GITHUB_OUTPUT + bash $HERE/filter-llvm.sh + + popd + popd + + - name: "Rebase on top of llvm-project" + shell: bash + id: rebase-llvm-project + run: | + + git config user.email "github-actions[bot]@users.noreply.github.com" + git config user.name "github-actions[bot]" + 
git pull /tmp/llvm-project main --rebase + + - name: Generate token + uses: actions/create-github-app-token@v1 + id: generate-token + with: + app-id: ${{ secrets.BUMP_LLVM_CREATE_PR_APP_ID }} + private-key: ${{ secrets.BUMP_LLVM_CREATE_PR_APP_PRIVATE_KEY }} + + - name: "Create Pull Request" + id: cpr + uses: peter-evans/create-pull-request@v7 + with: + token: ${{ steps.generate-token.outputs.token }} + commit-message: "[LLVM] Integrate to ${{ steps.get-llvm-project.outputs.LLVM_SHA_SHORT }}" + title: "[LLVM] Integrate to ${{ steps.get-llvm-project.outputs.LLVM_SHA_SHORT }}" + body: "Integrate LLVM to https://github.com/llvm/llvm-project/commit/${{ steps.get-llvm-project.outputs.LLVM_SHA_SHORT }}" + author: 'github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>' + base: main + branch: update-llvm + delete-branch: true + + - name: Enable auto-merge + if: steps.cpr.outputs.pull-request-operation == 'created' + uses: peter-evans/enable-pull-request-automerge@v3 + with: + token: ${{ steps.generate-token.outputs.token }} + pull-request-number: ${{ steps.cpr.outputs.pull-request-number }} + + - name: Auto approve + if: steps.cpr.outputs.pull-request-operation == 'created' + run: gh pr review --approve "${{ steps.cpr.outputs.pull-request-number }}" + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..1b17bee17 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +.idea +llvm-project \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..261eeb9e9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. 
+ + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of 
the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/README.md b/README.md new file mode 100644 index 000000000..d73ad500a --- /dev/null +++ b/README.md @@ -0,0 +1,44 @@ +# TL;DR: + +In this repo: + +``` +$ git clone git@github.com:llvm/llvm-project.git +$ cd llvm-project + +IFS= read -r -d '' x [-!#-'*+/-9=?A-Z^-~]+(\.[-!#-'*+/-9=?A-Z^-~]+)*|\"([]!#-[^-~ \t]|(\\[\t -~]))+\")@(?P[-!#-'*+/-9=?A-Z^-~]+(\.[-!#-'*+/-9=?A-Z^-~]+)*|\[[\t -Z^-~]*])"); +rr = re.compile(r"@(?P[a-z\d](?:[a-z\d]|-(?=[a-z\d])){0,38})(?=\b)"); + +message = message.decode("utf-8"); +message = r.sub(r"\g$$$\g", message); +message = rr.sub(r"\g", message); +message = message.replace("$$$", "@"); + +return message.encode("utf-8") diff --git a/filter-llvm.sh b/filter-llvm.sh new file mode 100755 index 000000000..bda32d49b --- /dev/null +++ b/filter-llvm.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash + +#rm -rf .git +#rm -rf mlir +#git init +#git checkout -b main +#git remote add -f origin /tmp/llvm-project +#git pull origin main + +IFS= read -r -d '' x